Commit fca0956a authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev_lightop_moe_sum_mul_add' into 'v0.15.1-dev'

feat(moe): 修复 shared_output 透传被覆盖并兼容 torch.compile 启动路径

See merge request dcutoolkit/deeplearing/vllm!517
parents 1ea9a3f0 eb933fe1
......@@ -2005,7 +2005,11 @@ class FusedMoE(CustomOp):
# Run the shared experts before the matrix multiply,
# because the matrix multiply may modify hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
if (
has_separate_shared_experts
and not use_shared_experts_stream
and shared_output is None
):
assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output = self.shared_experts(hidden_states, iqis=(i_q, i_s))
......@@ -2073,7 +2077,7 @@ class FusedMoE(CustomOp):
if has_separate_shared_experts:
assert self.shared_experts is not None
if use_shared_experts_stream:
if use_shared_experts_stream and shared_output is None:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
......
......@@ -67,7 +67,9 @@ class SharedFusedMoE(FusedMoE):
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
if self._shared_experts is not None:
if shared_output is not None:
shared_out = shared_output
elif self._shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
assert iqis[0] is not None
assert iqis[1] is not None
......
......@@ -377,16 +377,18 @@ class DeepseekV2MoE(nn.Module):
and self.shared_experts is not None
and not needs_post_moe_combine
):
router_logits, _ = self.gate(hidden_states)
shared_output = self.shared_experts(hidden_states, iqis=iqis)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class.
router_logits = hidden_states
else:
router_logits, _ = self.gate(hidden_states)
routed_scaling_factor = (
1.0 if self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor
)
self.experts.use_overlapped = False
self.experts._shared_experts = None
# Marlin W16A16 fused reduce consumes the precomputed
# shared_output and routed_scaling_factor directly.
# Keep shared-expert path intact and only fuse routed scale + add
# in the downstream MoE kernel.
_, final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment