Commit 122d7146 authored by zhuwenwen's avatar zhuwenwen
Browse files

update the lm-head layout of llama and qwen series models

parent 53910677
......@@ -455,7 +455,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_nn = os.environ.get('LM_NN') == '1'
# self.use_lm_nn = os.environ.get('LM_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
......@@ -575,7 +575,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
......@@ -584,10 +583,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
for layername, weight in params_dict.items():
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
......
......@@ -482,7 +482,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
......@@ -493,6 +492,12 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if "lm_head.weight" in layername and weight.shape[1] >= 3584:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment