Commit 168ceef7 authored by wujl5's avatar wujl5
Browse files

perf: DS V2模型MOE部分增加rmsQuant

parent a56e3da7
......@@ -667,7 +667,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla=True,
use_sparse=use_sparse,
)
if (
cache_config is not None
and cache_config.enable_prefix_caching
......
......@@ -1507,6 +1507,8 @@ def outplace_fused_experts(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
......@@ -1536,6 +1538,8 @@ def outplace_fused_experts(
w1_bias,
w2_bias,
use_nn_moe,
i_q=i_q,
i_s=i_s,
)
......@@ -1614,7 +1618,7 @@ def fused_experts(
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None # TODO:wjl
i_s: torch.Tensor | None = None
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......@@ -1646,6 +1650,8 @@ def fused_experts(
w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias,
use_nn_moe=use_nn_moe,
i_q=i_q,
i_s=i_s,
)
......@@ -1705,6 +1711,8 @@ def fused_experts_impl(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor:
# Check constraints.
num_tokens = hidden_states.size(0)
......@@ -1851,7 +1859,9 @@ def fused_experts_impl(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=False
use_nn_moe=False,
i_q=i_q,
i_s=i_s
)
elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states,
......
......@@ -1976,7 +1976,7 @@ class FusedMoE(CustomOp):
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)
shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s))
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
......@@ -2014,7 +2014,6 @@ class FusedMoE(CustomOp):
self.capture(topk_ids)
if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
......
......@@ -370,7 +370,8 @@ class DeepseekV2MoE(nn.Module):
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
hidden_states=hidden_states, router_logits=hidden_states,
iqis=iqis
)
else:
# router_logits: (num_tokens, n_experts)
......@@ -1087,7 +1088,7 @@ class DeepseekV2DecoderLayer(nn.Module):
)
new_resi = residual
hidden_states = self.mlp(hidden_states,
# iqis=(_i_q, _i_s) # TODO:wjl
iqis=(_i_q, _i_s)
)
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment