Commit 99d49945 authored by zhuwenwen
Browse files

update model layout

parent 92d43fd5
......@@ -79,8 +79,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# 'Qwen2VLForConditionalGeneration'
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV3ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
......
......@@ -499,7 +499,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# lay_qkv_words = ["self_attn.W_pack.weight"]
# qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
# if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
......@@ -526,7 +527,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
......
......@@ -419,7 +419,8 @@ class BloomForCausalLM(nn.Module, SupportsPP):
# lay_qkv_bias_words = ["self_attention.query_key_value.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
......
......@@ -691,7 +691,8 @@ class ChatGLMModel(nn.Module):
# lay_qkv_bias_words = ["self_attention.query_key_value.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] == 4096:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
......
......@@ -832,7 +832,8 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
])
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
......
......@@ -562,7 +562,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
# lay_qkv_words = ["self_attention.query_key_value.weight"]
# qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
# if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
......
......@@ -522,7 +522,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
......
......@@ -1109,7 +1109,8 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
# lay_qkv_bias_words = ["attn.c_attn.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
......@@ -1139,7 +1140,8 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
......@@ -1182,7 +1184,8 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
......
......@@ -537,7 +537,8 @@ class Qwen2Model(nn.Module):
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
......
......@@ -575,7 +575,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
# lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
......
......@@ -685,7 +685,8 @@ class Qwen2VisionTransformer(nn.Module):
# lay_qkv_bias_words = ["attn.qkv.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
for layername in loaded_params:
weight = params_dict[layername]
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment