"docs/vscode:/vscode.git/clone" did not exist on "b16fda62b70af82d11f6637810f40f79ff713a82"
Commit 89d1dd57 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Models]support blas and moe nn layout of deepseek-v3

parent 53076d70
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import os
import re
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -56,6 +58,7 @@ from .interfaces import SupportsPP ...@@ -56,6 +58,7 @@ from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
...@@ -675,6 +678,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -675,6 +678,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self.sampler = get_sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -807,6 +815,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -807,6 +815,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment