Commit a8b2d878 authored by zhuwenwen's avatar zhuwenwen
Browse files

support deepseek_v2 nn layout

parent 211835ef
...@@ -80,7 +80,7 @@ def get_model_architecture( ...@@ -80,7 +80,7 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV3ForCausalLM'] 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
...@@ -843,17 +843,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -843,17 +843,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if self.use_llama_nn and self.quant_method is None: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"self_attn.q_a_proj.weight", "self_attn.q_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight", "self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight", "self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj", "mlp.down_proj",
"mlp.gate.weight",
"shared_experts.gate_up_proj", "shared_experts.gate_up_proj",
"shared_experts.down_proj", "shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight" "lm_head.weight"
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment