Commit 77ae0f0d authored by zhuwenwen's avatar zhuwenwen
Browse files

update lm_head layout of chatglm

parent 2ff1c360
...@@ -697,7 +697,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -697,7 +697,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"self_attention.dense.weight", "self_attention.dense.weight",
"mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight", "mlp.dense_4h_to_h.weight",
"lm_head.weight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
...@@ -708,6 +707,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -708,6 +707,12 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
qkv_bias_words = "|".join(lay_qkv_bias_words) qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
if "lm_head.weight" in layername and weight.shape[1] == 4096:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)): if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32) weight.data = pad_weight(weight.data, 32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment