Commit 70deb5ef authored by 王敏
Browse files

[feat]适配0.7.2版本的deepseek_v3 nn layout

parent 99d49945
# SPDX-License-Identifier: Apache-2.0
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple
......@@ -181,6 +180,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O):
......@@ -301,10 +302,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
if self.use_llama_nn and self.kv_b_proj.quant_method is None:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
assert kv_b_proj_weight.shape == (
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
self.kv_lora_rank), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
......@@ -319,8 +331,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
if self.use_llama_nn and self.q_proj.quant_method is None:
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj)\
.view(-1, self.num_heads, self.qk_head_dim)
else:
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
......@@ -376,8 +392,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
if self.use_llama_nn and self.o_proj.quant_method is None:
W_O = get_and_maybe_dequant_weights(self.o_proj).T\
.view(-1, self.num_heads, self.v_head_dim)
else:
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
......
......@@ -821,26 +821,18 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj"
"shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight"
]
if not self.use_mla:
lay_key_words.extend([
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
])
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment