Unverified commit e028af69, authored by Xun Sun, committed by GitHub
Browse files

Fix mooncake dispatcher (#11908)

parent 80b2b320
...@@ -86,7 +86,7 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher: ...@@ -86,7 +86,7 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
a2a_backend = get_moe_a2a_backend() a2a_backend = get_moe_a2a_backend()
if a2a_backend.is_none(): if a2a_backend.is_none():
return StandardDispatcher(moe_runner_config) return StandardDispatcher(moe_runner_config)
elif a2a_backend.is_deepep(): elif a2a_backend.is_deepep() or a2a_backend.is_mooncake():
return MaybeTboDeepEPDispatcher( return MaybeTboDeepEPDispatcher(
group=get_tp_group().device_group, group=get_tp_group().device_group,
router_topk=moe_runner_config.top_k, router_topk=moe_runner_config.top_k,
......
...@@ -36,7 +36,7 @@ class MooncakeDispatchOutput(NamedTuple): ...@@ -36,7 +36,7 @@ class MooncakeDispatchOutput(NamedTuple):
"""Mooncake EP dispatch output.""" """Mooncake EP dispatch output."""
hidden_states: torch.Tensor hidden_states: torch.Tensor
hidden_states_scale: torch.Tensor hidden_states_scale: Optional[torch.Tensor]
topk_ids: torch.Tensor topk_ids: torch.Tensor
topk_weights: torch.Tensor topk_weights: torch.Tensor
masked_m: torch.Tensor masked_m: torch.Tensor
...@@ -205,8 +205,14 @@ class _MooncakeEPDispatcherImpl: ...@@ -205,8 +205,14 @@ class _MooncakeEPDispatcherImpl:
masked_m masked_m
) )
if isinstance(hidden_states, tuple):
hidden_states, hidden_states_scale = hidden_states
else:
hidden_states_scale = None
return MooncakeDispatchOutput( return MooncakeDispatchOutput(
hidden_states, hidden_states,
hidden_states_scale,
topk_ids, topk_ids,
topk_weights, topk_weights,
masked_m, masked_m,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment