Commit ee9541af authored by xuxzh1's avatar xuxzh1 🎱
Browse files

update tensor_parallel.py

parent 72501097
......@@ -65,7 +65,10 @@ class TensorParallelHead(SuperLayer):
else:
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
if config.model_type == "baichuan":
weight = F.normalize(weight)
return TensorParallelHead(
get_linear(weight, bias=None),
process_group=weights.process_group,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment