Commit 76c182d9 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev_wm' into 'v0.7.2-dev'

[feat]适配0.7.2版本的deepseek_v3 nn layout

See merge request dcutoolkit/deeplearing/vllm!63
parents 99d49945 70deb5ef
# SPDX-License-Identifier: Apache-2.0
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple
......@@ -181,6 +180,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O):
......@@ -301,6 +302,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
if self.use_llama_nn and self.kv_b_proj.quant_method is None:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
assert kv_b_proj_weight.shape == (
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
self.kv_lora_rank), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
......@@ -319,6 +331,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if self.use_llama_nn and self.q_proj.quant_method is None:
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj)\
.view(-1, self.num_heads, self.qk_head_dim)
else:
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
......@@ -376,6 +392,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
if self.use_llama_nn and self.o_proj.quant_method is None:
W_O = get_and_maybe_dequant_weights(self.o_proj).T\
.view(-1, self.num_heads, self.v_head_dim)
else:
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
......
......@@ -821,26 +821,18 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj"
]
if not self.use_mla:
lay_key_words.extend([
"shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
])
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment