Commit 2350c778 authored by wujl5, committed by wangmin6
Browse files

perf: DS V2模型MLA中增加rmsQuant

parent 3824b261
......@@ -379,9 +379,6 @@ def fused_rmsquant_fake(
dtype=torch.float32)
return output, scales
# from torch.library import Library
# customer_lib = Library("customer_", "FRAGMENT")
direct_register_custom_op(
op_name="fused_rmsquant_customer_impl",
op_func=fused_rmsquant_impl,
......
......@@ -139,6 +139,14 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
if envs.USE_FUSED_RMS_QUANT:
qa_iq, qa_is, _ = self.q_a_layernorm(x=q_c,
residual=None,
quant_dtype=torch.int8,
update_input=False)
q = self.q_b_proj(q_c, iqis=(qa_iq, qa_is))[0]
else:
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
......
......@@ -811,6 +811,9 @@ class DeepseekV2MLAAttention(nn.Module):
)
if self.q_lora_rank is not None:
if envs.USE_FUSED_RMS_QUANT:
self.q_a_layernorm = FusedRMSNormQuant(self.q_lora_rank, eps=config.rms_norm_eps)
else:
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
self.q_lora_rank,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment