Commit 0b6a38bf authored by zhuwenwen
Browse files

Merge branch 'v0.7.2-dev' of http://10.6.10.68/dcutoolkit/deeplearing/vllm into v0.7.2-dev

parents 49bfe4cb 76c182d9
# SPDX-License-Identifier: Apache-2.0 import os
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple from typing import Any, Dict, Generic, List, Optional, Tuple
...@@ -181,6 +180,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -181,6 +180,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O): if is_fp8(self.W_UV_O):
...@@ -301,10 +302,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -301,10 +302,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
assert self.o_proj.weight.dtype == weight_dtype assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype assert self.q_proj.weight.dtype == weight_dtype
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T if self.use_llama_nn and self.kv_b_proj.quant_method is None:
assert kv_b_proj_weight.shape == ( kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
self.kv_lora_rank, assert kv_b_proj_weight.shape == (
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
self.kv_lora_rank), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, " f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, " f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, " f"{self.num_heads=}, "
...@@ -319,8 +331,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -319,8 +331,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK, W_UV = kv_b_proj_weight.split( W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ if self.use_llama_nn and self.q_proj.quant_method is None:
.view(-1, self.num_heads, self.qk_head_dim) q_proj_weight = get_and_maybe_dequant_weights(self.q_proj)\
.view(-1, self.num_heads, self.qk_head_dim)
else:
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
# can be W_Q or W_UQ depending q_lora_rank, the former if # can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend # q_lora_rank is None, the latter otherwise. From the Attention backend
...@@ -376,8 +392,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -376,8 +392,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
else: else:
self.W_Q_UK = W_Q_UK.to(act_dtype) self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\ if self.use_llama_nn and self.o_proj.quant_method is None:
.view(-1, self.num_heads, self.v_head_dim) W_O = get_and_maybe_dequant_weights(self.o_proj).T\
.view(-1, self.num_heads, self.v_head_dim)
else:
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous() .flatten(start_dim=0, end_dim=1).contiguous()
......
...@@ -821,26 +821,18 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -821,26 +821,18 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj", "mlp.down_proj",
"shared_experts.gate_up_proj", "shared_experts.gate_up_proj",
"shared_experts.down_proj" "shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight"
] ]
if not self.use_mla:
lay_key_words.extend([
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
])
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername in loaded_params: for layername in loaded_params:
weight = params_dict[layername] weight = params_dict[layername]
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment