"vscode:/vscode.git/clone" did not exist on "afb4429b4f13e744b1630b6c5a09156e5b1ececc"
Commit 8781f412 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev_MOE_add_RQ_push1' into 'v0.15.1-dev'

perf: DS V2模型MOE部分增加rmsQuant

See merge request dcutoolkit/deeplearing/vllm!492
parents a56e3da7 168ceef7
...@@ -667,7 +667,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -667,7 +667,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla=True, use_mla=True,
use_sparse=use_sparse, use_sparse=use_sparse,
) )
if ( if (
cache_config is not None cache_config is not None
and cache_config.enable_prefix_caching and cache_config.enable_prefix_caching
......
...@@ -1507,6 +1507,8 @@ def outplace_fused_experts( ...@@ -1507,6 +1507,8 @@ def outplace_fused_experts(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
return fused_experts_impl( return fused_experts_impl(
hidden_states, hidden_states,
...@@ -1536,6 +1538,8 @@ def outplace_fused_experts( ...@@ -1536,6 +1538,8 @@ def outplace_fused_experts(
w1_bias, w1_bias,
w2_bias, w2_bias,
use_nn_moe, use_nn_moe,
i_q=i_q,
i_s=i_s,
) )
...@@ -1614,7 +1618,7 @@ def fused_experts( ...@@ -1614,7 +1618,7 @@ def fused_experts(
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None # TODO:wjl i_s: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
...@@ -1646,6 +1650,8 @@ def fused_experts( ...@@ -1646,6 +1650,8 @@ def fused_experts(
w1_bias=quant_config.w1_bias, w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias, w2_bias=quant_config.w2_bias,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
i_q=i_q,
i_s=i_s,
) )
...@@ -1705,6 +1711,8 @@ def fused_experts_impl( ...@@ -1705,6 +1711,8 @@ def fused_experts_impl(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. # Check constraints.
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
...@@ -1851,7 +1859,9 @@ def fused_experts_impl( ...@@ -1851,7 +1859,9 @@ def fused_experts_impl(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False use_nn_moe=False,
i_q=i_q,
i_s=i_s
) )
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
......
...@@ -1976,7 +1976,7 @@ class FusedMoE(CustomOp): ...@@ -1976,7 +1976,7 @@ class FusedMoE(CustomOp):
# because matrix multiply maybe modify the hidden_states. # because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream: if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s))
# NOTE: Similar with DP, PCP also needs dispatch and combine. For # NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe # simplicity, AgRsAll2All was added separately for PCP here. Maybe
...@@ -2014,7 +2014,6 @@ class FusedMoE(CustomOp): ...@@ -2014,7 +2014,6 @@ class FusedMoE(CustomOp):
self.capture(topk_ids) self.capture(topk_ids)
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=x, # The type signture of this is wrong due to the hack. x=x, # The type signture of this is wrong due to the hack.
......
...@@ -370,7 +370,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -370,7 +370,8 @@ class DeepseekV2MoE(nn.Module):
if self.experts.is_internal_router: if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class # In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts( fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states hidden_states=hidden_states, router_logits=hidden_states,
iqis=iqis
) )
else: else:
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
...@@ -1087,7 +1088,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1087,7 +1088,7 @@ class DeepseekV2DecoderLayer(nn.Module):
) )
new_resi = residual new_resi = residual
hidden_states = self.mlp(hidden_states, hidden_states = self.mlp(hidden_states,
# iqis=(_i_q, _i_s) # TODO:wjl iqis=(_i_q, _i_s)
) )
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment