Commit 795ce518 authored by zhuwenwen
Browse files

fix index error of baichuan2

parent 1be9a629
......@@ -398,25 +398,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attn.W_pack.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
if self.use_llama_nn:
lay_key_words = [
"self_attn.W_pack.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment