Commit 76c182d9 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev_wm' into 'v0.7.2-dev'

[feat]适配0.7.2版本的deepseek_v3 nn layout

See merge request dcutoolkit/deeplearing/vllm!63
parents 99d49945 70deb5ef
# SPDX-License-Identifier: Apache-2.0 import os
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple from typing import Any, Dict, Generic, List, Optional, Tuple
...@@ -181,6 +180,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -181,6 +180,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O): if is_fp8(self.W_UV_O):
...@@ -301,6 +302,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -301,6 +302,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
assert self.o_proj.weight.dtype == weight_dtype assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype assert self.q_proj.weight.dtype == weight_dtype
if self.use_llama_nn and self.kv_b_proj.quant_method is None:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
assert kv_b_proj_weight.shape == (
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
self.kv_lora_rank), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == ( assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.kv_lora_rank,
...@@ -319,6 +331,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -319,6 +331,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK, W_UV = kv_b_proj_weight.split( W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if self.use_llama_nn and self.q_proj.quant_method is None:
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj)\
.view(-1, self.num_heads, self.qk_head_dim)
else:
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim) .view(-1, self.num_heads, self.qk_head_dim)
...@@ -376,6 +392,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -376,6 +392,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
else: else:
self.W_Q_UK = W_Q_UK.to(act_dtype) self.W_Q_UK = W_Q_UK.to(act_dtype)
if self.use_llama_nn and self.o_proj.quant_method is None:
W_O = get_and_maybe_dequant_weights(self.o_proj).T\
.view(-1, self.num_heads, self.v_head_dim)
else:
W_O = get_and_maybe_dequant_weights(self.o_proj)\ W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim) .view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
......
...@@ -821,26 +821,18 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -821,26 +821,18 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj", "mlp.down_proj",
"shared_experts.gate_up_proj", "shared_experts.gate_up_proj",
"shared_experts.down_proj" "shared_experts.down_proj",
]
if not self.use_mla:
lay_key_words.extend([
"self_attn.q_proj.weight", "self_attn.q_proj.weight",
"self_attn.q_b_proj.weight", "self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight", "self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
]) "lm_head.weight"
]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername in loaded_params: for layername in loaded_params:
weight = params_dict[layername] weight = params_dict[layername]
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment