Unverified Commit 0afd6832 authored by Xun Sun, committed by GitHub

Update Mooncake EP's a2a interface (#12391)

parent 6f858930
@@ -2,7 +2,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
-from typing import NamedTuple, Optional, Tuple
+from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple
 from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
@@ -18,6 +18,9 @@ from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.utils import get_int_env_var
+if TYPE_CHECKING:
+    from sglang.srt.single_batch_overlap import CombineOverlapArgs
 try:
     from mooncake.mooncake_ep_buffer import Buffer
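
The TYPE_CHECKING guard above makes CombineOverlapArgs available for annotations without importing sglang.srt.single_batch_overlap at runtime, which avoids import cost and potential circular imports. A minimal standalone sketch of that pattern, with an illustrative function standing in for the dispatcher method:

# Sketch only: `combine_a` below is a stand-in, not the sglang method.
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime.
    from sglang.srt.single_batch_overlap import CombineOverlapArgs


def combine_a(overlap_args: Optional[CombineOverlapArgs] = None) -> None:
    # With `from __future__ import annotations`, the annotation above is
    # stored as a string, so the guarded import is all that's needed.
    pass


combine_a()  # runs even without sglang installed
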
@@ -234,13 +237,14 @@ class _MooncakeEPDispatcherImpl:
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
             topk_ids,
             topk_weights,
         )
-        return hidden_states, event, hook
+        return hidden_states, event, hook, overlap_args

     def combine_b(self, hidden_states, event, hook):
         hook() if self.return_recv_hook else event.current_stream_wait()
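
At the impl level, combine_a now accepts overlap_args and appends it to the tuple it returns, so the value travels with the intermediate state between the two phases. A rough standalone sketch of that a/b split, using stand-in types rather than the sglang implementation:

from typing import Any, Callable, Optional, Tuple


class _Event:
    # Stand-in for the communication event returned by the combine kernel.
    def current_stream_wait(self) -> None:
        pass


def combine_a(hidden_states: Any, overlap_args: Optional[Any] = None) -> Tuple:
    # Phase A: launch the all-to-all and return everything phase B needs,
    # with overlap_args tacked onto the end.
    event, hook = _Event(), None
    return hidden_states, event, hook, overlap_args


def combine_b(hidden_states: Any, event: _Event, hook: Optional[Callable[[], None]]) -> Any:
    # Phase B: wait via the hook if one was returned, otherwise on the event.
    hook() if hook is not None else event.current_stream_wait()
    return hidden_states


state = combine_a(hidden_states="h", overlap_args=None)
out = combine_b(*state[:3])  # overlap_args (state[3]) is consumed elsewhere
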
@@ -342,23 +346,27 @@ class MooncakeEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)

-    def combine(self, *args, **kwargs) -> Tuple:
-        self.combine_a(*args, **kwargs)
+    def combine(
+        self,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
+    ) -> Tuple:
+        self.combine_a(combine_input, overlap_args)
         ret = self.combine_b()
         return ret

     def combine_a(
         self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        overlap_args: Optional = None,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
+        hidden_states, topk_ids, topk_weights = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
             topk_ids=topk_ids,
             topk_weights=topk_weights,
+            overlap_args=overlap_args,
         )
         self._combine_intermediate_state = inner_state
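
The public combine entry point now takes a bundled CombineInput plus an optional CombineOverlapArgs instead of *args, **kwargs. A hedged sketch of the new call shape, using hypothetical stand-ins for the sglang types (the real CombineInput may carry more fields or be a different container):

from typing import NamedTuple, Optional, Tuple

import torch


class CombineInput(NamedTuple):
    # Hypothetical stand-in mirroring the three tensors the dispatcher unpacks.
    hidden_states: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor


def combine(combine_input: CombineInput, overlap_args: Optional[object] = None) -> Tuple:
    # Same unpacking step the patched combine_a performs.
    hidden_states, topk_ids, topk_weights = combine_input
    return hidden_states, topk_ids, topk_weights


# Callers now bundle the tensors rather than passing them positionally.
combine(
    CombineInput(
        hidden_states=torch.randn(4, 8),
        topk_ids=torch.zeros(4, 2, dtype=torch.int64),
        topk_weights=torch.ones(4, 2),
    ),
    overlap_args=None,
)
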
@@ -962,9 +962,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         moe_runner_backend = get_moe_runner_backend()
         if moe_runner_backend.is_auto():
-            if (
-                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-                and get_moe_a2a_backend().is_deepep()
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and (
+                get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
             ):
                 moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
             else:
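
With the widened condition, the DeepGEMM MoE runner is auto-selected for the Mooncake a2a backend as well as DeepEP whenever JIT DeepGEMM is enabled. A simplified restatement with string stand-ins (the real code uses sglang's backend enums and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM):

from typing import Optional


def pick_runner_backend(a2a_backend: str, enable_jit_deepgemm: bool) -> Optional[str]:
    # After this patch, either the DeepEP or the Mooncake a2a backend
    # qualifies for the DeepGEMM runner.
    if enable_jit_deepgemm and a2a_backend in ("deepep", "mooncake"):
        return "deep_gemm"
    return None  # the fallback branch is truncated in this hunk


assert pick_runner_backend("mooncake", True) == "deep_gemm"
assert pick_runner_backend("deepep", False) is None
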