Unverified Commit 9c138a04 authored by Cheng Wan, committed by GitHub

[3/N] MoE Refactor: Simplify DeepEP Output (#8421)

parent c8f549d9
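
Editor's note: in short, this change replaces the eight-element positional tuple that DeepEP dispatch used to return (several slots of which were None depending on the mode) with typed per-mode outputs, and lets the experts layer own its dispatcher. The sketch below is illustrative only; field names follow the diff, the tensors are toy values.

```python
from typing import List, NamedTuple
import torch

# Old convention: dispatch() returned 8 positional values, several of them None
# depending on the mode, and every model forward() re-threaded all of them.
old_style = (torch.zeros(2, 4), torch.zeros(2, 2, dtype=torch.int64), torch.ones(2, 2),
             None, [1, 1], None, None, None)

# New convention: one typed object per dispatch mode, read by name.
class DeepEPNormalOutput(NamedTuple):
    hidden_states: torch.Tensor
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    num_recv_tokens_per_expert: List[int]

new_style = DeepEPNormalOutput(torch.zeros(2, 4), torch.zeros(2, 2, dtype=torch.int64),
                               torch.ones(2, 2), [1, 1])
assert new_style.num_recv_tokens_per_expert == [1, 1]
```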
from __future__ import annotations
import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
next_power_of_2,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
)
_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
@@ -797,6 +806,24 @@ class DeepEPMoE(EPMoE):
"alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
)
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=deepep_mode,
async_finish=True, # TODO
return_recv_hook=True,
)
if self.deepep_mode.enable_low_latency():
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -837,37 +864,128 @@ class DeepEPMoE(EPMoE):
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
num_recv_tokens_per_expert: List[int],
forward_batch: ForwardBatch,
):
dispatch_output = self.dispatch(
hidden_states, topk_idx, topk_weights, forward_batch
)
hidden_states = self.moe_impl(dispatch_output)
hidden_states = self.combine(
hidden_states,
dispatch_output.topk_idx,
dispatch_output.topk_weights,
forward_batch,
)
return hidden_states
def dispatch(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
):
return self.deepep_dispatcher.dispatch(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
def moe_impl(self, dispatch_output: DispatchOutput):
if _use_aiter:
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
resolved_deepep_mode = self.deepep_mode.resolve(
forward_batch.is_extend_in_batch
)
if resolved_deepep_mode == DeepEPMode.normal:
return self.forward_aiter(dispatch_output)
if dispatch_output.format.is_deepep_normal():
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
return self.forward_deepgemm_contiguous(
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
)
return self.forward_deepgemm_contiguous(dispatch_output)
else:
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
elif resolved_deepep_mode == DeepEPMode.low_latency:
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
return self.forward_normal(dispatch_output)
elif dispatch_output.format.is_deepep_ll():
return self.forward_deepgemm_masked(dispatch_output)
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
def forward_normal(
def combine(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
):
return self.deepep_dispatcher.combine(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
def _prepare_for_normal(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
topk_idx: torch.Tensor,
):
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_permute_triton_kernel,
deepep_run_moe_deep_preprocess,
)
if hidden_states.shape[0] == 0:
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,),
device=hidden_states.device,
dtype=torch.int64,
)
return reorder_topk_ids, seg_indptr, hidden_states
else:
if _use_aiter:
# skip permutation here as aiter fused_moe has fused inside
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,),
device=hidden_states.device,
dtype=torch.int64,
)
return reorder_topk_ids, seg_indptr, hidden_states
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
topk_idx, self.num_experts
)
num_total_tokens = reorder_topk_ids.numel()
gateup_input = torch.empty(
(int(num_total_tokens), hidden_states.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# PreReorder
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
self.src2dst,
topk_idx,
None,
self.router_topk,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
return reorder_topk_ids, seg_indptr, gateup_input
def forward_normal(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, topk_idx = (
dispatch_output.hidden_states,
dispatch_output.topk_idx,
)
reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
hidden_states, topk_idx
)
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
@@ -983,10 +1101,13 @@ class DeepEPMoE(EPMoE):
def forward_aiter(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states,
dispatch_output.topk_idx,
dispatch_output.topk_weights,
)
if hidden_states.shape[0] == 0:
return hidden_states
# in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +1135,11 @@ class DeepEPMoE(EPMoE):
def forward_deepgemm_contiguous(
self,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
topk_idx,
topk_weights,
num_recv_tokens_per_expert: List[int],
dispatch_output: DeepEPNormalOutput,
):
hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
assert self.quant_method is not None
assert self.activation == "silu"
@@ -1138,10 +1259,9 @@ class DeepEPMoE(EPMoE):
def forward_deepgemm_masked(
self,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
masked_m: torch.Tensor,
expected_m: int,
dispatch_output: DeepEPLLOutput,
):
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.activation == "silu"
......
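
Editor's note: the new `_prepare_for_normal` path above reorders tokens by expert before the grouped GEMM via `deepep_run_moe_deep_preprocess` and a Triton permute kernel. The snippet below is a plain-PyTorch sketch of the quantities it produces (`reorder_topk_ids`, `src2dst`, `seg_indptr`), under the assumption that the preprocessing is essentially a stable sort of routed slots by expert; `naive_moe_preprocess` is a hypothetical stand-in, not the real kernel, and it ignores FP8 dtypes and kernel fusion.

```python
import torch

def naive_moe_preprocess(topk_idx: torch.Tensor, num_experts: int):
    """Group routed (token, expert) slots by expert and build a per-expert indptr."""
    flat = topk_idx.flatten()
    expert_of_slot = flat[flat >= 0]                   # DeepEP marks dropped slots with -1
    order = torch.argsort(expert_of_slot, stable=True)
    reorder_topk_ids = expert_of_slot[order]           # expert id of each reordered row
    src2dst = torch.empty_like(order)
    src2dst[order] = torch.arange(order.numel())       # destination row of each valid source slot
    counts = torch.bincount(reorder_topk_ids, minlength=num_experts)
    seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
    seg_indptr[1:] = torch.cumsum(counts, dim=0)       # expert e owns rows [seg_indptr[e], seg_indptr[e+1])
    return reorder_topk_ids, src2dst, seg_indptr

topk_idx = torch.tensor([[0, 2], [2, 1], [0, -1]])     # 3 tokens routed to top-2 experts
reorder_topk_ids, src2dst, seg_indptr = naive_moe_preprocess(topk_idx, num_experts=3)
print(seg_indptr.tolist())                             # [0, 2, 3, 5]
```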
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
List,
NamedTuple,
Optional,
Protocol,
Tuple,
Union,
runtime_checkable,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
@@ -24,7 +44,6 @@ except ImportError:
use_deepep = False
from enum import Enum, IntEnum, auto
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
logger = logging.getLogger(__name__)
class DeepEPNormalOutput(NamedTuple):
"""DeepEP normal dispatch output."""
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
num_recv_tokens_per_expert: List[int]
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_normal
class DeepEPLLOutput(NamedTuple):
"""DeepEP low latency dispatch output."""
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
masked_m: torch.Tensor
expected_m: int
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll
assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)
class DeepEPDispatchMode(IntEnum):
NORMAL = auto()
LOW_LATENCY = auto()
@@ -139,7 +189,7 @@ class DeepEPBuffer:
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
class DeepEPConfig:
class DeepEPConfig(BaseDispatcherConfig):
_instance = None
def __init__(self):
@@ -255,63 +305,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
return hidden_states, topk_idx, topk_weights, previous_event
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
(
hidden_states,
topk_idx,
topk_weights,
num_recv_tokens_per_expert_list,
event,
) = self._dispatch_core(
hidden_states, topk_idx, topk_weights, previous_event
)
event.current_stream_wait() if self.async_finish else ()
return (
hidden_states,
topk_idx,
topk_weights,
None,
num_recv_tokens_per_expert_list,
None,
None,
None,
)
else:
(
hidden_states,
topk_idx,
topk_weights,
num_recv_tokens_per_expert_list,
event,
) = self._dispatch_core(
hidden_states, topk_idx, topk_weights, previous_event
)
event.current_stream_wait() if self.async_finish else ()
if hidden_states.shape[0] > 0:
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
)
else:
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,),
device=hidden_states.device,
dtype=torch.int64,
)
masked_m = expected_m = None
return (
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
None,
seg_indptr,
masked_m,
expected_m,
)
(
hidden_states,
topk_idx,
topk_weights,
num_recv_tokens_per_expert,
event,
) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
event.current_stream_wait() if self.async_finish else ()
return DeepEPNormalOutput(
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
)
def _dispatch_core(
self,
@@ -343,7 +347,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
num_recv_tokens_per_expert,
self.handle,
event,
) = buffer.dispatch(
@@ -362,7 +366,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert_list,
num_recv_tokens_per_expert,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
@@ -372,58 +376,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
num_recv_tokens_per_expert,
event,
)
def _deepep_permute(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
fp8_dtype: Optional[torch.dtype] = None,
use_fp8_w8a8: bool = False,
use_block_quant: bool = False,
):
"""
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
"""
if _use_aiter:
# skip permutation here as aiter fused_moe has fused inside
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
)
return reorder_topk_ids, seg_indptr, hidden_states
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
topk_idx, self.num_experts
)
num_total_tokens = reorder_topk_ids.numel()
gateup_input = torch.empty(
(int(num_total_tokens), hidden_states.shape[1]),
device=hidden_states.device,
dtype=(
fp8_dtype
if (use_fp8_w8a8 and not use_block_quant)
else hidden_states.dtype
),
)
# PreReorder
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
self.src2dst,
topk_idx,
None,
self.router_topk,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
return reorder_topk_ids, seg_indptr, gateup_input
def combine_a(
self,
hidden_states: torch.Tensor,
@@ -544,15 +500,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
masked_m
)
reorder_topk_ids = seg_indptr = None
return (
return DeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
None,
seg_indptr,
masked_m,
expected_m,
)
@@ -636,7 +587,7 @@ class _Stage(Enum):
AFTER_COMBINE_A = auto()
class DeepEPDispatcher:
class DeepEPDispatcher(BaseDispatcher):
def __init__(
self,
group: torch.distributed.ProcessGroup,
@@ -676,7 +627,7 @@ class DeepEPDispatcher:
self._stage = _Stage.INITIAL
def dispatch(self, *args, **kwargs) -> Tuple:
def dispatch(self, *args, **kwargs) -> DispatchOutput:
self.dispatch_a(*args, **kwargs)
ret = self.dispatch_b()
return ret
......
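
Editor's note: because `DeepEPNormalOutput` and `DeepEPLLOutput` are `NamedTuple`s, the kernels above can keep tuple-style unpacking (as `forward_deepgemm_contiguous` and `forward_deepgemm_masked` do) while other call sites read fields by name. A small self-contained illustration; the tensor shapes and values are toy stand-ins for real dispatch results.

```python
from typing import NamedTuple, Tuple
import torch

class DeepEPLLOutput(NamedTuple):
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int

out = DeepEPLLOutput(
    hidden_states_fp8=(torch.zeros(2, 8, 16), torch.ones(2, 8, 1)),  # (data, scales) stand-ins
    topk_idx=torch.zeros(8, 2, dtype=torch.int64),
    topk_weights=torch.ones(8, 2),
    masked_m=torch.tensor([3, 5]),
    expected_m=4,
)

# Positional unpacking, as in forward_deepgemm_masked:
hidden_states_fp8, _, _, masked_m, expected_m = out
# Named access, as in the generic combine() call:
topk_idx, topk_weights = out.topk_idx, out.topk_weights
```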
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
import torch
class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto()
deepep_ll = auto()
def is_standard(self) -> bool:
return self == DispatchOutputFormat.standard
def is_deepep_normal(self) -> bool:
return self == DispatchOutputFormat.deepep_normal
def is_deepep_ll(self) -> bool:
return self == DispatchOutputFormat.deepep_ll
@runtime_checkable
class DispatchOutput(Protocol):
"""Protocol for dispatch outputs in different formats."""
@property
def format(self) -> DispatchOutputFormat: ...
class BaseDispatcherConfig(ABC):
"""Base class for dispatcher configs."""
pass
class BaseDispatcher(ABC):
"""Base class for dispatchers."""
@abstractmethod
def dispatch(self, *args, **kwargs) -> DispatchOutput:
pass
@abstractmethod
def combine(self, *args, **kwargs) -> torch.Tensor:
pass
from __future__ import annotations
from typing import NamedTuple
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
DispatchOutput,
DispatchOutputFormat,
)
class StandardDispatchOutput(NamedTuple):
"""Standard dispatch output."""
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.standard
assert isinstance(StandardDispatchOutput, DispatchOutput)
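
Editor's note: the module-level `assert isinstance(...)` checks here and in the DeepEP dispatcher pass because `DispatchOutput` is a `runtime_checkable` `Protocol`: the check only verifies that a `format` attribute is present, so the output classes need no common base. Below is a standalone sketch of that mechanism plus the format-based routing that `DeepEPMoE.moe_impl` now performs; `_ToyNormalOutput`, `_ToyLLOutput`, and the string-returning handlers are stubs, not the real kernel paths.

```python
from enum import Enum, auto
from typing import NamedTuple, Protocol, runtime_checkable

class DispatchOutputFormat(Enum):
    standard = auto()
    deepep_normal = auto()
    deepep_ll = auto()

    def is_deepep_normal(self) -> bool:
        return self == DispatchOutputFormat.deepep_normal

    def is_deepep_ll(self) -> bool:
        return self == DispatchOutputFormat.deepep_ll

@runtime_checkable
class DispatchOutput(Protocol):
    @property
    def format(self) -> DispatchOutputFormat: ...

class _ToyNormalOutput(NamedTuple):      # stand-in for DeepEPNormalOutput
    payload: str

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.deepep_normal

class _ToyLLOutput(NamedTuple):          # stand-in for DeepEPLLOutput
    payload: str

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.deepep_ll

# Structural check: runtime_checkable protocols only test attribute presence.
assert isinstance(_ToyNormalOutput, DispatchOutput)
assert isinstance(_ToyNormalOutput("x"), DispatchOutput)

def moe_impl(dispatch_output: DispatchOutput) -> str:
    # Same shape as DeepEPMoE.moe_impl: choose the kernel path from the output
    # format instead of re-resolving the DeepEP mode on every forward pass.
    if dispatch_output.format.is_deepep_normal():
        return f"contiguous (grouped GEMM) path: {dispatch_output.payload}"
    elif dispatch_output.format.is_deepep_ll():
        return f"masked (low-latency) path: {dispatch_output.payload}"
    raise ValueError(f"Unhandled dispatch output: {dispatch_output}")

print(moe_impl(_ToyNormalOutput("prefill batch")))
print(moe_impl(_ToyLLOutput("decode batch")))
```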
@@ -594,41 +594,13 @@ class DeepseekV2MoE(nn.Module):
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_batch=forward_batch,
)
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
hidden_states=final_hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
if shared_output is not None:
x = shared_output
@@ -689,8 +661,7 @@ class DeepseekV2MoE(nn.Module):
def op_dispatch_a(self, state):
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
self.deepep_dispatcher.dispatch_a(
self.experts.deepep_dispatcher.dispatch_a(
hidden_states=state.hidden_states_mlp_input,
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
@@ -703,46 +674,32 @@ class DeepseekV2MoE(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
(
state.hidden_states_experts_input,
state.topk_idx_dispatched,
state.topk_weights_dispatched,
state.reorder_topk_ids,
state.num_recv_tokens_per_expert,
state.seg_indptr,
state.masked_m,
state.expected_m,
) = self.deepep_dispatcher.dispatch_b(
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_experts(self, state):
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_experts_input"),
topk_idx=state.topk_idx_dispatched,
topk_weights=state.topk_weights_dispatched,
reorder_topk_ids=state.pop("reorder_topk_ids"),
seg_indptr=state.pop("seg_indptr"),
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_batch=state.forward_batch,
state.hidden_states_experts_output = self.experts.moe_impl(
dispatch_output=state.dispatch_output,
)
def op_combine_a(self, state):
if self.ep_size > 1:
self.deepep_dispatcher.combine_a(
self.experts.deepep_dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.pop("topk_idx_dispatched"),
topk_weights=state.pop("topk_weights_dispatched"),
topk_idx=state.dispatch_output.topk_idx,
topk_weights=state.dispatch_output.topk_weights,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
state.pop("dispatch_output")
def op_combine_b(self, state):
if self.ep_size > 1:
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
state.hidden_states_after_combine = (
self.experts.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
)
def op_output(self, state):
......
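
Editor's note: on the model side, the two-batch-overlap op pipeline (`op_dispatch_b` / `op_experts` / `op_combine_a` above) now stores the single typed result in `state.dispatch_output` and later stages read `topk_idx` / `topk_weights` from it, instead of stashing eight separate state entries. A hedged sketch of that pattern; the plain dict and `_ToyDispatchOutput` are stand-ins for the real state object and dispatch outputs.

```python
from typing import NamedTuple

class _ToyDispatchOutput(NamedTuple):   # stand-in for DeepEPNormalOutput / DeepEPLLOutput
    hidden_states: str
    topk_idx: str
    topk_weights: str

def op_dispatch_b(state: dict) -> None:
    # One object in, instead of hidden_states_experts_input, reorder_topk_ids,
    # seg_indptr, masked_m, expected_m, num_recv_tokens_per_expert, ...
    state["dispatch_output"] = _ToyDispatchOutput("h", "idx", "w")

def op_experts(state: dict) -> None:
    out = state["dispatch_output"]
    state["hidden_states_experts_output"] = f"experts({out.hidden_states})"

def op_combine_a(state: dict) -> None:
    out = state["dispatch_output"]
    # combine still needs topk_idx / topk_weights, read by name from the output
    state["combined"] = (state.pop("hidden_states_experts_output"), out.topk_idx, out.topk_weights)
    state.pop("dispatch_output")        # released once combine has what it needs

state: dict = {}
for op in (op_dispatch_b, op_experts, op_combine_a):
    op(state)
print(state)   # {'combined': ('experts(h)', 'idx', 'w')}
```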
@@ -144,19 +144,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
)
self.top_k = config.num_experts_per_tok
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=config.num_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
async_finish=True, # TODO
return_recv_hook=True,
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
@@ -207,41 +194,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_batch=forward_batch,
)
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
hidden_states=final_hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
return final_hidden_states
def op_gate(self, state):
@@ -278,8 +236,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def op_dispatch_a(self, state):
if self.ep_size > 1:
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
self.deepep_dispatcher.dispatch_a(
self.experts.deepep_dispatcher.dispatch_a(
hidden_states=state.pop("hidden_states_mlp_input"),
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
@@ -292,46 +249,32 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
(
state.hidden_states_experts_input,
state.topk_idx_dispatched,
state.topk_weights_dispatched,
state.reorder_topk_ids,
state.num_recv_tokens_per_expert,
state.seg_indptr,
state.masked_m,
state.expected_m,
) = self.deepep_dispatcher.dispatch_b(
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_experts(self, state):
state.hidden_states_experts_output = self.experts(
hidden_states=state.pop("hidden_states_experts_input"),
topk_idx=state.topk_idx_dispatched,
topk_weights=state.topk_weights_dispatched,
reorder_topk_ids=state.pop("reorder_topk_ids"),
seg_indptr=state.pop("seg_indptr"),
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_batch=state.forward_batch,
state.hidden_states_experts_output = self.experts.moe_impl(
dispatch_output=state.dispatch_output,
)
def op_combine_a(self, state):
if self.ep_size > 1:
self.deepep_dispatcher.combine_a(
self.experts.deepep_dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.pop("topk_idx_dispatched"),
topk_weights=state.pop("topk_weights_dispatched"),
topk_idx=state.dispatch_output.topk_idx,
topk_weights=state.dispatch_output.topk_weights,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
state.pop("dispatch_output")
def op_combine_b(self, state):
if self.ep_size > 1:
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
state.hidden_states_after_combine = (
self.experts.deepep_dispatcher.combine_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
)
def op_output(self, state):
......
from __future__ import annotations
import dataclasses
import logging
from dataclasses import replace
from typing import Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
import torch
@@ -20,6 +22,9 @@ from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
logger = logging.getLogger(__name__)
@@ -802,7 +807,7 @@ class MaybeTboDeepEPDispatcher:
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
def dispatch(self, **kwargs):
def dispatch(self, **kwargs) -> DispatchOutput:
return self._execute("dispatch", **kwargs)
def dispatch_a(self, **kwargs):
@@ -811,7 +816,7 @@
def dispatch_b(self, **kwargs):
return self._execute("dispatch_b", **kwargs)
def combine(self, **kwargs):
def combine(self, **kwargs) -> torch.Tensor:
return self._execute("combine", **kwargs)
def combine_a(self, **kwargs):
......
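
Editor's note: `MaybeTboDeepEPDispatcher` keeps one inner dispatcher per two-batch-overlap sub-batch (presumably a `DeepEPDispatcher` each) and forwards every `dispatch*` / `combine*` call to the selected inner via `_execute`, as shown above; in this commit it only gains `DispatchOutput` / `torch.Tensor` return annotations. The toy classes below (`_ToyInnerDispatcher`, `_ToyMaybeTboDispatcher`) are stand-ins that sketch the delegation pattern, not the real wrapper.

```python
from typing import Any, Optional

class _ToyInnerDispatcher:
    def __init__(self, name: str) -> None:
        self.name = name

    def dispatch(self, **kwargs: Any) -> str:
        return f"{self.name}: dispatch({kwargs})"

    def combine(self, **kwargs: Any) -> str:
        return f"{self.name}: combine({kwargs})"

class _ToyMaybeTboDispatcher:
    """Stand-in for MaybeTboDeepEPDispatcher's delegation via _execute."""

    def __init__(self, num_inners: int) -> None:
        self._inners = [_ToyInnerDispatcher(f"subbatch{i}") for i in range(num_inners)]

    def _execute(self, name: str, tbo_subbatch_index: Optional[int] = None, **kwargs: Any):
        # Route to the per-subbatch inner dispatcher; default to the first one.
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

    def dispatch(self, **kwargs: Any):
        return self._execute("dispatch", **kwargs)

    def combine(self, **kwargs: Any):
        return self._execute("combine", **kwargs)

d = _ToyMaybeTboDispatcher(num_inners=2)
print(d.dispatch(tbo_subbatch_index=1, hidden_states="h"))
print(d.combine(topk_idx="idx"))
```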