Commit 92d43fd5 authored by zhuwenwen
Browse files

update llama and qwen2 layout

parent 8f8d56c3
......@@ -477,7 +477,9 @@ class LlamaModel(nn.Module):
# lay_qkv_words = ["self_attn.qkv_proj.weight"]
# qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
# for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 4096:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
......@@ -512,7 +514,8 @@ class LlamaModel(nn.Module):
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
......
......@@ -455,7 +455,9 @@ class Qwen2Model(nn.Module):
# lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
# for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 3584:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
......@@ -491,7 +493,8 @@ class Qwen2Model(nn.Module):
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment