Commit 644ecdd9 authored by wangkx1's avatar wangkx1
Browse files

Update tensor_parallel.py

parent 55e2c496
...@@ -66,6 +66,9 @@ class TensorParallelHead(SuperLayer): ...@@ -66,6 +66,9 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight") weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False should_gather = False
if config.model_type == "baichuan":
weight = F.normalize(weight)
return TensorParallelHead( return TensorParallelHead(
get_linear(weight, bias=None), get_linear(weight, bias=None),
process_group=weights.process_group, process_group=weights.process_group,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment