Commit c5888d31 authored by zhuwenwen's avatar zhuwenwen
Browse files

update mtp layout

parent 7b681c9e
......@@ -90,7 +90,7 @@ def get_model_architecture(
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures = ['QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM']
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
# SPDX-License-Identifier: Apache-2.0
import os
import re
from typing import Iterable, List, Optional, Set, Tuple
import torch
......@@ -21,6 +23,7 @@ from vllm.sequence import IntermediateTensors
from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
from .utils import maybe_prefix
from vllm import _custom_ops as ops
class SharedHead(nn.Module):
......@@ -155,11 +158,20 @@ class DeepSeekMTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_method = None
if quant_config is not None:
self.quant_method = quant_config.get_name()
os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0'
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.sampler = get_sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward(
self,
......@@ -262,6 +274,38 @@ class DeepSeekMTP(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.eh_proj.weight",
"self_attn.q_proj.weight",
"self_attn.q_a_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"mlp.gate.weight",
"shared_experts.gate_up_proj.weight",
"shared_experts.down_proj.weight",
"shared_head.head.weight",
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
......
......@@ -860,7 +860,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight.data=weight.data.reshape(ori_shape[1],-1)
if self.config.quantization_config["quant_method"] == "awq" and not envs.VLLM_USE_TRITON_AWQ:
if hasattr(self.config, "quantization_config") and self.config.quantization_config["quant_method"] == "awq" and not envs.VLLM_USE_TRITON_AWQ:
lay_key_words = [
"self_attn.q_a_proj.qweight",
"self_attn.q_b_proj.qweight",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment