Commit b22a4a14 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev_MLA_add_RQ_Push' into 'v0.15.1-dev'

perf: DS V2模型MLA中增加rmsQuant

See merge request dcutoolkit/deeplearing/vllm!487
parents fd831864 2350c778
...@@ -379,9 +379,6 @@ def fused_rmsquant_fake( ...@@ -379,9 +379,6 @@ def fused_rmsquant_fake(
dtype=torch.float32) dtype=torch.float32)
return output, scales return output, scales
# from torch.library import Library
# customer_lib = Library("customer_", "FRAGMENT")
direct_register_custom_op( direct_register_custom_op(
op_name="fused_rmsquant_customer_impl", op_name="fused_rmsquant_customer_impl",
op_func=fused_rmsquant_impl, op_func=fused_rmsquant_impl,
......
...@@ -139,8 +139,16 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -139,8 +139,16 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1, dim=-1,
) )
q_c = self.q_a_layernorm(q_c) if envs.USE_FUSED_RMS_QUANT:
q = self.q_b_proj(q_c)[0] qa_iq, qa_is, _ = self.q_a_layernorm(x=q_c,
residual=None,
quant_dtype=torch.int8,
update_input=False)
q = self.q_b_proj(q_c, iqis=(qa_iq, qa_is))[0]
else:
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else: else:
assert self.kv_a_proj_with_mqa is not None, ( assert self.kv_a_proj_with_mqa is not None, (
"kv_a_proj_with_mqa is required when q_lora_rank is None" "kv_a_proj_with_mqa is required when q_lora_rank is None"
......
...@@ -811,7 +811,10 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -811,7 +811,10 @@ class DeepseekV2MLAAttention(nn.Module):
) )
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) if envs.USE_FUSED_RMS_QUANT:
self.q_a_layernorm = FusedRMSNormQuant(self.q_lora_rank, eps=config.rms_norm_eps)
else:
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear( self.q_b_proj = ColumnParallelLinear(
self.q_lora_rank, self.q_lora_rank,
self.num_heads * self.qk_head_dim, self.num_heads * self.qk_head_dim,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment