Unverified commit 0afd6832 authored by Xun Sun, committed by GitHub

Update Mooncake EP's a2a interface (#12391)

parent 6f858930
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import logging
 from dataclasses import dataclass
-from typing import NamedTuple, Optional, Tuple
+from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple
 
 from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
@@ -18,6 +18,9 @@ from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.utils import get_int_env_var
 
+if TYPE_CHECKING:
+    from sglang.srt.single_batch_overlap import CombineOverlapArgs
+
 try:
     from mooncake.mooncake_ep_buffer import Buffer
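The new import sits behind `typing.TYPE_CHECKING`, so `CombineOverlapArgs` is available for annotations while `sglang.srt.single_batch_overlap` is never imported at runtime, the standard way to avoid a circular or costly import. A minimal sketch of the pattern, with a hypothetical `heavy_module` standing in for the real module:

```python
from __future__ import annotations  # annotations are stored as strings, not evaluated

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Executed only by static type checkers (mypy, pyright); skipped at runtime.
    from heavy_module import CombineOverlapArgs  # hypothetical stand-in


def combine_a(overlap_args: Optional[CombineOverlapArgs] = None) -> None:
    # Valid at runtime even though heavy_module was never imported,
    # because the annotation above is never evaluated as code.
    ...
```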
@@ -234,13 +237,14 @@ class _MooncakeEPDispatcherImpl:
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
             topk_ids,
             topk_weights,
         )
-        return hidden_states, event, hook
+        return hidden_states, event, hook, overlap_args
 
     def combine_b(self, hidden_states, event, hook):
         hook() if self.return_recv_hook else event.current_stream_wait()
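Since the impl-level `combine_a` now returns `overlap_args` as a fourth element, whatever stores and later consumes the intermediate state has to account for the extra value. A hedged sketch of that handoff, assuming the store-then-unpack pattern the dispatcher uses for its intermediate state (`impl` and the unpacking site are illustrative):

```python
# Phase A: kick off the combine and stash the intermediate state (now 4 values).
inner_state = impl.combine_a(hidden_states, topk_ids, topk_weights, overlap_args)

# Phase B: unpack all four values; impl.combine_b itself still takes three.
hidden_states, event, hook, overlap_args = inner_state
impl.combine_b(hidden_states, event, hook)
```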
@@ -342,23 +346,27 @@ class MooncakeEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)
 
-    def combine(self, *args, **kwargs) -> Tuple:
-        self.combine_a(*args, **kwargs)
+    def combine(
+        self,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
+    ) -> Tuple:
+        self.combine_a(combine_input, overlap_args)
         ret = self.combine_b()
         return ret
 
     def combine_a(
         self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        overlap_args: Optional = None,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
+        hidden_states, topk_ids, topk_weights = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
             topk_ids=topk_ids,
             topk_weights=topk_weights,
+            overlap_args=overlap_args,
         )
         self._combine_intermediate_state = inner_state
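At the call site, the public `combine` now takes a single `CombineInput` plus optional overlap arguments instead of three positional tensors; the tuple-unpacking above shows `CombineInput` must behave like a `(hidden_states, topk_ids, topk_weights)` triple. A hedged usage sketch; the tensor shapes, the keyword construction of `CombineInput`, and the `dispatcher` object are assumptions:

```python
import torch

# Assumed: CombineInput is iterable as (hidden_states, topk_ids, topk_weights),
# matching the unpacking in combine_a above.
combine_input = CombineInput(
    hidden_states=torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16),
    topk_ids=topk_ids,
    topk_weights=topk_weights,
)

ret = dispatcher.combine(combine_input)                # overlap disabled (default)
ret = dispatcher.combine(combine_input, overlap_args)  # with single-batch overlap
```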
@@ -962,9 +962,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         moe_runner_backend = get_moe_runner_backend()
         if moe_runner_backend.is_auto():
-            if (
-                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-                and get_moe_a2a_backend().is_deepep()
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and (
+                get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
             ):
                 moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
             else:
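The condition change means the auto runner selection now routes both the DeepEP and the Mooncake all-to-all backends to the DeepGEMM runner whenever JIT DeepGEMM is enabled, instead of DeepEP alone. A condensed restatement of that logic as a sketch (the else branch is elided in the diff, so `fallback_backend` here is hypothetical):

```python
a2a_backend = get_moe_a2a_backend()

if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and (
    a2a_backend.is_deepep() or a2a_backend.is_mooncake()  # Mooncake newly included
):
    moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
else:
    moe_runner_backend = fallback_backend  # hypothetical; elided in the diff
```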