Commit 89d1dd57 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Models]support blas and moe nn layout of deepseek-v3

parent 53076d70
......@@ -24,6 +24,8 @@
"""Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import os
import re
import torch
from torch import nn
from transformers import PretrainedConfig
......@@ -56,6 +58,7 @@ from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm import _custom_ops as ops
class DeepseekV2MLP(nn.Module):
......@@ -675,6 +678,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -807,6 +815,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment