Commit fca0956a authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev_lightop_moe_sum_mul_add' into 'v0.15.1-dev'

feat(moe): 修复 shared_output 透传被覆盖并兼容 torch.compile 启动路径

See merge request dcutoolkit/deeplearing/vllm!517
parents 1ea9a3f0 eb933fe1
...@@ -2005,7 +2005,11 @@ class FusedMoE(CustomOp): ...@@ -2005,7 +2005,11 @@ class FusedMoE(CustomOp):
# Run shared experts before matrix multiply. # Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states. # because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream: if (
has_separate_shared_experts
and not use_shared_experts_stream
and shared_output is None
):
assert self.shared_experts is not None assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None: if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s)) shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s))
...@@ -2073,7 +2077,7 @@ class FusedMoE(CustomOp): ...@@ -2073,7 +2077,7 @@ class FusedMoE(CustomOp):
if has_separate_shared_experts: if has_separate_shared_experts:
assert self.shared_experts is not None assert self.shared_experts is not None
if use_shared_experts_stream: if use_shared_experts_stream and shared_output is None:
# Run shared experts in parallel on a separate stream # Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the # NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is # sync end point immediately after it is done. This is
......
...@@ -67,7 +67,9 @@ class SharedFusedMoE(FusedMoE): ...@@ -67,7 +67,9 @@ class SharedFusedMoE(FusedMoE):
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None: if shared_output is not None:
shared_out = shared_output
elif self._shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT and iqis is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
assert iqis[0] is not None assert iqis[0] is not None
assert iqis[1] is not None assert iqis[1] is not None
......
...@@ -377,16 +377,18 @@ class DeepseekV2MoE(nn.Module): ...@@ -377,16 +377,18 @@ class DeepseekV2MoE(nn.Module):
and self.shared_experts is not None and self.shared_experts is not None
and not needs_post_moe_combine and not needs_post_moe_combine
): ):
router_logits, _ = self.gate(hidden_states)
shared_output = self.shared_experts(hidden_states, iqis=iqis) shared_output = self.shared_experts(hidden_states, iqis=iqis)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class.
router_logits = hidden_states
else:
router_logits, _ = self.gate(hidden_states)
routed_scaling_factor = ( routed_scaling_factor = (
1.0 if self.is_rocm_aiter_moe_enabled 1.0 if self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor else self.routed_scaling_factor
) )
self.experts.use_overlapped = False # Keep shared-expert path intact and only fuse routed scale + add
self.experts._shared_experts = None # in the downstream MoE kernel.
# Marlin W16A16 fused reduce consumes the precomputed
# shared_output and routed_scaling_factor directly.
_, final_hidden_states = self.experts( _, final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment