Unverified commit b23fb786, authored by bnellnm and committed by GitHub
Browse files

[Bugfix] Fix for 24530. Fix naive all2all shared expert overlap. (#24538)

parent 561f38dc
......@@ -1755,9 +1755,6 @@ class FusedMoE(CustomOp):
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_config.use_flashinfer_cutlass_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
......@@ -1768,6 +1765,10 @@ class FusedMoE(CustomOp):
else:
shared_output = None
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
......@@ -1800,8 +1801,9 @@ class FusedMoE(CustomOp):
final_hidden_states,
)
def reduce_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
......@@ -1810,10 +1812,11 @@ class FusedMoE(CustomOp):
return states
if self.shared_experts is None:
assert not isinstance(final_hidden_states, tuple)
return reduce_output(final_hidden_states)
else:
return (
reduce_output(final_hidden_states[0]),
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment