"docs/vscode:/vscode.git/clone" did not exist on "d50ce994213a264dfb746cd5e4ebc0f148f03b17"
Unverified Commit 9c138a04 authored by Cheng Wan, committed by GitHub

[3/N] MoE Refactor: Simplify DeepEP Output (#8421)

parent c8f549d9
from __future__ import annotations
import logging
- from typing import List, Optional, Tuple
+ from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
next_power_of_2,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
)
_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
@@ -797,6 +806,24 @@ class DeepEPMoE(EPMoE):
"alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
)
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=deepep_mode,
async_finish=True, # TODO
return_recv_hook=True,
)
if self.deepep_mode.enable_low_latency():
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -837,37 +864,128 @@ class DeepEPMoE(EPMoE):
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
- reorder_topk_ids: torch.Tensor,
- seg_indptr: torch.Tensor,
- masked_m: torch.Tensor,
- expected_m: int,
- num_recv_tokens_per_expert: List[int],
forward_batch: ForwardBatch,
):
dispatch_output = self.dispatch(
hidden_states, topk_idx, topk_weights, forward_batch
)
hidden_states = self.moe_impl(dispatch_output)
hidden_states = self.combine(
hidden_states,
dispatch_output.topk_idx,
dispatch_output.topk_weights,
forward_batch,
)
return hidden_states
def dispatch(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
):
return self.deepep_dispatcher.dispatch(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
def moe_impl(self, dispatch_output: DispatchOutput):
if _use_aiter:
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
- return self.forward_aiter(hidden_states, topk_idx, topk_weights)
+ return self.forward_aiter(dispatch_output)
- resolved_deepep_mode = self.deepep_mode.resolve(
-     forward_batch.is_extend_in_batch
- )
- if resolved_deepep_mode == DeepEPMode.normal:
+ if dispatch_output.format.is_deepep_normal():
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
- return self.forward_deepgemm_contiguous(
-     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
- )
+ return self.forward_deepgemm_contiguous(dispatch_output)
else:
- return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+ return self.forward_normal(dispatch_output)
- elif resolved_deepep_mode == DeepEPMode.low_latency:
+ elif dispatch_output.format.is_deepep_ll():
- return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
+ return self.forward_deepgemm_masked(dispatch_output)
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
- def forward_normal(
+ def combine(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
):
return self.deepep_dispatcher.combine(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
)
def _prepare_for_normal(
self,
hidden_states: torch.Tensor,
- reorder_topk_ids: torch.Tensor,
- seg_indptr: torch.Tensor,
+ topk_idx: torch.Tensor,
):
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_permute_triton_kernel,
deepep_run_moe_deep_preprocess,
)
if hidden_states.shape[0] == 0:
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,),
device=hidden_states.device,
dtype=torch.int64,
)
return reorder_topk_ids, seg_indptr, hidden_states
else:
if _use_aiter:
# skip permutation here as aiter fused_moe has fused inside
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,),
device=hidden_states.device,
dtype=torch.int64,
)
return reorder_topk_ids, seg_indptr, hidden_states
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
topk_idx, self.num_experts
)
num_total_tokens = reorder_topk_ids.numel()
gateup_input = torch.empty(
(int(num_total_tokens), hidden_states.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# PreReorder
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
self.src2dst,
topk_idx,
None,
self.router_topk,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
return reorder_topk_ids, seg_indptr, gateup_input
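The Triton preprocessing above packs every (token, top-k slot) pair into a contiguous per-expert segment and records a src2dst map for the permute kernel. A rough pure-PyTorch reference of the same bookkeeping, illustrative only (it is not the actual kernel and it ignores the idx == -1 "invalid" entries a real DeepEP dispatch can contain):

import torch

def reference_preprocess(topk_idx: torch.Tensor, num_experts: int):
    # topk_idx: [num_tokens, top_k], expert id per (token, slot) pair
    flat = topk_idx.reshape(-1)
    # Sort so rows belonging to the same expert become contiguous.
    reorder_topk_ids, dst2src = torch.sort(flat, stable=True)
    # Invert the permutation: src2dst[i] is the destination row of source slot i.
    src2dst = torch.empty_like(dst2src)
    src2dst[dst2src] = torch.arange(flat.numel(), device=flat.device)
    # seg_indptr[e]:seg_indptr[e+1] delimits expert e's rows in the permuted input.
    counts = torch.bincount(flat, minlength=num_experts)
    seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64, device=flat.device)
    seg_indptr[1:] = torch.cumsum(counts, dim=0)
    return reorder_topk_ids, src2dst, seg_indptr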
def forward_normal(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, topk_idx = (
dispatch_output.hidden_states,
dispatch_output.topk_idx,
)
reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
hidden_states, topk_idx
)
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
@@ -983,10 +1101,13 @@ class DeepEPMoE(EPMoE):
def forward_aiter(
self,
- hidden_states: torch.Tensor,
- topk_idx: torch.Tensor,
- topk_weights: torch.Tensor,
+ dispatch_output: DeepEPNormalOutput,
):
hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states,
dispatch_output.topk_idx,
dispatch_output.topk_weights,
)
if hidden_states.shape[0] == 0:
return hidden_states
# in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +1135,11 @@ class DeepEPMoE(EPMoE):
def forward_deepgemm_contiguous(
self,
- hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
- topk_idx,
- topk_weights,
- num_recv_tokens_per_expert: List[int],
+ dispatch_output: DeepEPNormalOutput,
):
hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
assert self.quant_method is not None
assert self.activation == "silu"
@@ -1138,10 +1259,9 @@ class DeepEPMoE(EPMoE):
def forward_deepgemm_masked(
self,
- hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
- masked_m: torch.Tensor,
- expected_m: int,
+ dispatch_output: DeepEPLLOutput,
):
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.activation == "silu"
......
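DeepEPLLOutput (defined in the token dispatcher below) is a NamedTuple, so the positional unpacking in forward_deepgemm_masked is equivalent to attribute access and depends on field order. A small illustration with dummy values, assuming the import path that layer.py itself uses:

import torch

from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPLLOutput

out = DeepEPLLOutput(
    hidden_states_fp8=(torch.empty(0), torch.empty(0)),
    topk_idx=torch.empty(0, dtype=torch.int64),
    topk_weights=torch.empty(0),
    masked_m=torch.zeros(1, dtype=torch.int32),
    expected_m=0,
)
hidden_states_fp8, _, _, masked_m, expected_m = out
assert masked_m is out.masked_m and expected_m == out.expected_m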
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
List,
NamedTuple,
Optional,
Protocol,
Tuple,
Union,
runtime_checkable,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
@@ -24,7 +44,6 @@ except ImportError:
use_deepep = False
from enum import Enum, IntEnum, auto
- from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
logger = logging.getLogger(__name__)
class DeepEPNormalOutput(NamedTuple):
"""DeepEP normal dispatch output."""
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
num_recv_tokens_per_expert: List[int]
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_normal
class DeepEPLLOutput(NamedTuple):
"""DeepEP low latency dispatch output."""
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
masked_m: torch.Tensor
expected_m: int
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll
assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)
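Because DispatchOutput is a runtime-checkable Protocol (see the new base_dispatcher module further down), these NamedTuples satisfy it structurally, and callers can branch on the format enum without importing concrete dispatcher classes. A minimal sketch with dummy tensors, mirroring the branching in DeepEPMoE.moe_impl (illustrative only):

import torch

from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPNormalOutput, DispatchOutput

def describe(dispatch_output: DispatchOutput) -> str:
    if dispatch_output.format.is_deepep_normal():
        return "contiguous (normal) layout"
    elif dispatch_output.format.is_deepep_ll():
        return "masked (low-latency) layout"
    return "standard layout"

normal = DeepEPNormalOutput(
    hidden_states=torch.empty(0, 16),
    topk_idx=torch.empty(0, 2, dtype=torch.int64),
    topk_weights=torch.empty(0, 2),
    num_recv_tokens_per_expert=[0, 0],
)
assert describe(normal) == "contiguous (normal) layout"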
class DeepEPDispatchMode(IntEnum):
NORMAL = auto()
LOW_LATENCY = auto()
@@ -139,7 +189,7 @@ class DeepEPBuffer:
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
- class DeepEPConfig:
+ class DeepEPConfig(BaseDispatcherConfig):
_instance = None
def __init__(self):
@@ -255,63 +305,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
return hidden_states, topk_idx, topk_weights, previous_event
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-     (
-         hidden_states,
-         topk_idx,
-         topk_weights,
-         num_recv_tokens_per_expert_list,
-         event,
-     ) = self._dispatch_core(
-         hidden_states, topk_idx, topk_weights, previous_event
-     )
-     event.current_stream_wait() if self.async_finish else ()
-     return (
-         hidden_states,
-         topk_idx,
-         topk_weights,
-         None,
-         num_recv_tokens_per_expert_list,
-         None,
-         None,
-         None,
-     )
- else:
-     (
-         hidden_states,
-         topk_idx,
-         topk_weights,
-         num_recv_tokens_per_expert_list,
-         event,
-     ) = self._dispatch_core(
-         hidden_states, topk_idx, topk_weights, previous_event
-     )
-     event.current_stream_wait() if self.async_finish else ()
-     if hidden_states.shape[0] > 0:
-         reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-             hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-         )
-     else:
-         reorder_topk_ids = torch.empty(
-             (0,), device=hidden_states.device, dtype=torch.int64
-         )
-         seg_indptr = torch.zeros(
-             (self.num_experts + 1,),
-             device=hidden_states.device,
-             dtype=torch.int64,
-         )
-     masked_m = expected_m = None
-     return (
-         hidden_states,
-         topk_idx,
-         topk_weights,
-         reorder_topk_ids,
-         None,
-         seg_indptr,
-         masked_m,
-         expected_m,
-     )
+ (
+     hidden_states,
+     topk_idx,
+     topk_weights,
+     num_recv_tokens_per_expert,
+     event,
+ ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+ event.current_stream_wait() if self.async_finish else ()
+ return DeepEPNormalOutput(
+     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+ )
def _dispatch_core(
self,
@@ -343,7 +347,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
- num_recv_tokens_per_expert_list,
+ num_recv_tokens_per_expert,
self.handle,
event,
) = buffer.dispatch(
@@ -362,7 +366,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
- num_recv_tokens_per_expert_list,
+ num_recv_tokens_per_expert,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
@@ -372,58 +376,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
- num_recv_tokens_per_expert_list,
+ num_recv_tokens_per_expert,
event,
)
- def _deepep_permute(
-     self,
-     hidden_states: torch.Tensor,
-     topk_idx: torch.Tensor,
-     fp8_dtype: Optional[torch.dtype] = None,
-     use_fp8_w8a8: bool = False,
-     use_block_quant: bool = False,
- ):
-     """
-     Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-     https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-     """
-     if _use_aiter:
-         # skip permutation here as aiter fused_moe has fused inside
-         reorder_topk_ids = torch.empty(
-             (0,), device=hidden_states.device, dtype=torch.int64
-         )
-         seg_indptr = torch.zeros(
-             (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-         )
-         return reorder_topk_ids, seg_indptr, hidden_states
-     reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-         topk_idx, self.num_experts
-     )
-     num_total_tokens = reorder_topk_ids.numel()
-     gateup_input = torch.empty(
-         (int(num_total_tokens), hidden_states.shape[1]),
-         device=hidden_states.device,
-         dtype=(
-             fp8_dtype
-             if (use_fp8_w8a8 and not use_block_quant)
-             else hidden_states.dtype
-         ),
-     )
-     # PreReorder
-     deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-         hidden_states,
-         gateup_input,
-         self.src2dst,
-         topk_idx,
-         None,
-         self.router_topk,
-         hidden_states.shape[1],
-         BLOCK_SIZE=512,
-     )
-     return reorder_topk_ids, seg_indptr, gateup_input
def combine_a(
self,
hidden_states: torch.Tensor,
@@ -544,15 +500,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
masked_m
)
- reorder_topk_ids = seg_indptr = None
- return (
-     hidden_states,
-     topk_idx,
-     topk_weights,
-     reorder_topk_ids,
-     None,
-     seg_indptr,
-     masked_m,
-     expected_m,
- )
+ return DeepEPLLOutput(
+     hidden_states,
+     topk_idx,
+     topk_weights,
+     masked_m,
+     expected_m,
+ )
@@ -636,7 +587,7 @@ class _Stage(Enum):
AFTER_COMBINE_A = auto()
- class DeepEPDispatcher:
+ class DeepEPDispatcher(BaseDispatcher):
def __init__(
self,
group: torch.distributed.ProcessGroup,
@@ -676,7 +627,7 @@ class DeepEPDispatcher:
self._stage = _Stage.INITIAL
- def dispatch(self, *args, **kwargs) -> Tuple:
+ def dispatch(self, *args, **kwargs) -> DispatchOutput:
self.dispatch_a(*args, **kwargs)
ret = self.dispatch_b()
return ret
......
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
import torch
class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto()
deepep_ll = auto()
def is_standard(self) -> bool:
return self == DispatchOutputFormat.standard
def is_deepep_normal(self) -> bool:
return self == DispatchOutputFormat.deepep_normal
def is_deepep_ll(self) -> bool:
return self == DispatchOutputFormat.deepep_ll
@runtime_checkable
class DispatchOutput(Protocol):
"""Protocol for dispatch outputs in different formats."""
@property
def format(self) -> DispatchOutputFormat: ...
class BaseDispatcherConfig(ABC):
"""Base class for dispatcher configs."""
pass
class BaseDispatcher(ABC):
"""Base class for dispatchers."""
@abstractmethod
def dispatch(self, *args, **kwargs) -> DispatchOutput:
pass
@abstractmethod
def combine(self, *args, **kwargs) -> torch.Tensor:
pass
from __future__ import annotations
from typing import NamedTuple
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
DispatchOutput,
DispatchOutputFormat,
)
class StandardDispatchOutput(NamedTuple):
"""Standard dispatch output."""
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.standard
assert isinstance(StandardDispatchOutput, DispatchOutput)
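A minimal sketch of how another backend could plug into these interfaces; ToyOutput and ToyDispatcher are illustrative names that are not part of this commit, and the identity behavior is only a stand-in for a real all-to-all:

from typing import NamedTuple

import torch

from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
    BaseDispatcher,
    DispatchOutput,
    DispatchOutputFormat,
)


class ToyOutput(NamedTuple):
    hidden_states: torch.Tensor

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.standard


class ToyDispatcher(BaseDispatcher):
    """Identity dispatcher: every rank keeps its own tokens (no expert parallelism)."""

    def dispatch(self, hidden_states: torch.Tensor, **kwargs) -> DispatchOutput:
        return ToyOutput(hidden_states)

    def combine(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        return hidden_states


# Same structural check the new modules use at definition time.
assert isinstance(ToyOutput, DispatchOutput)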
@@ -594,41 +594,13 @@ class DeepseekV2MoE(nn.Module):
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
- if self.ep_size > 1:
-     # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-     (
-         hidden_states,
-         topk_idx,
-         topk_weights,
-         reorder_topk_ids,
-         num_recv_tokens_per_expert,
-         seg_indptr,
-         masked_m,
-         expected_m,
-     ) = self.deepep_dispatcher.dispatch(
-         hidden_states=hidden_states,
-         topk_idx=topk_idx,
-         topk_weights=topk_weights,
-         forward_batch=forward_batch,
-     )
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
- reorder_topk_ids=reorder_topk_ids,
- seg_indptr=seg_indptr,
- masked_m=masked_m,
- expected_m=expected_m,
- num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_batch=forward_batch,
)
- if self.ep_size > 1:
-     final_hidden_states = self.deepep_dispatcher.combine(
-         hidden_states=final_hidden_states,
-         topk_idx=topk_idx,
-         topk_weights=topk_weights,
-         forward_batch=forward_batch,
-     )
if shared_output is not None:
x = shared_output
@@ -689,8 +661,7 @@ class DeepseekV2MoE(nn.Module):
def op_dispatch_a(self, state):
if self.ep_size > 1:
- # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
- self.deepep_dispatcher.dispatch_a(
+ self.experts.deepep_dispatcher.dispatch_a(
hidden_states=state.hidden_states_mlp_input,
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
@@ -703,46 +674,32 @@ class DeepseekV2MoE(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
- (
-     state.hidden_states_experts_input,
-     state.topk_idx_dispatched,
-     state.topk_weights_dispatched,
-     state.reorder_topk_ids,
-     state.num_recv_tokens_per_expert,
-     state.seg_indptr,
-     state.masked_m,
-     state.expected_m,
- ) = self.deepep_dispatcher.dispatch_b(
+ state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_experts(self, state):
- state.hidden_states_experts_output = self.experts(
-     hidden_states=state.pop("hidden_states_experts_input"),
-     topk_idx=state.topk_idx_dispatched,
-     topk_weights=state.topk_weights_dispatched,
-     reorder_topk_ids=state.pop("reorder_topk_ids"),
-     seg_indptr=state.pop("seg_indptr"),
-     masked_m=state.pop("masked_m"),
-     expected_m=state.pop("expected_m"),
-     num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-     forward_batch=state.forward_batch,
- )
+ state.hidden_states_experts_output = self.experts.moe_impl(
+     dispatch_output=state.dispatch_output,
+ )
def op_combine_a(self, state):
if self.ep_size > 1:
- self.deepep_dispatcher.combine_a(
+ self.experts.deepep_dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"),
- topk_idx=state.pop("topk_idx_dispatched"),
- topk_weights=state.pop("topk_weights_dispatched"),
+ topk_idx=state.dispatch_output.topk_idx,
+ topk_weights=state.dispatch_output.topk_weights,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
state.pop("dispatch_output")
def op_combine_b(self, state):
if self.ep_size > 1:
- state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-     tbo_subbatch_index=state.get("tbo_subbatch_index"),
+ state.hidden_states_after_combine = (
+     self.experts.deepep_dispatcher.combine_b(
+         tbo_subbatch_index=state.get("tbo_subbatch_index"),
+     )
)
def op_output(self, state):
......
@@ -144,19 +144,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
)
self.top_k = config.num_experts_per_tok
- self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-     group=parallel_state.get_tp_group().device_group,
-     router_topk=self.top_k,
-     permute_fusion=True,
-     num_experts=self.num_experts,
-     num_local_experts=config.num_experts // self.tp_size,
-     hidden_size=config.hidden_size,
-     params_dtype=config.torch_dtype,
-     deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-     async_finish=True,  # TODO
-     return_recv_hook=True,
- )
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
@@ -207,41 +194,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
- if self.ep_size > 1:
-     # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-     (
-         hidden_states,
-         topk_idx,
-         topk_weights,
-         reorder_topk_ids,
-         num_recv_tokens_per_expert,
-         seg_indptr,
-         masked_m,
-         expected_m,
-     ) = self.deepep_dispatcher.dispatch(
-         hidden_states=hidden_states,
-         topk_idx=topk_idx,
-         topk_weights=topk_weights,
-         forward_batch=forward_batch,
-     )
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
- reorder_topk_ids=reorder_topk_ids,
- seg_indptr=seg_indptr,
- masked_m=masked_m,
- expected_m=expected_m,
- num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_batch=forward_batch,
)
- if self.ep_size > 1:
-     final_hidden_states = self.deepep_dispatcher.combine(
-         hidden_states=final_hidden_states,
-         topk_idx=topk_idx,
-         topk_weights=topk_weights,
-         forward_batch=forward_batch,
-     )
return final_hidden_states
def op_gate(self, state):
@@ -278,8 +236,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def op_dispatch_a(self, state):
if self.ep_size > 1:
- # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
- self.deepep_dispatcher.dispatch_a(
+ self.experts.deepep_dispatcher.dispatch_a(
hidden_states=state.pop("hidden_states_mlp_input"),
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
@@ -292,46 +249,32 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
- (
-     state.hidden_states_experts_input,
-     state.topk_idx_dispatched,
-     state.topk_weights_dispatched,
-     state.reorder_topk_ids,
-     state.num_recv_tokens_per_expert,
-     state.seg_indptr,
-     state.masked_m,
-     state.expected_m,
- ) = self.deepep_dispatcher.dispatch_b(
+ state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
def op_experts(self, state):
- state.hidden_states_experts_output = self.experts(
-     hidden_states=state.pop("hidden_states_experts_input"),
-     topk_idx=state.topk_idx_dispatched,
-     topk_weights=state.topk_weights_dispatched,
-     reorder_topk_ids=state.pop("reorder_topk_ids"),
-     seg_indptr=state.pop("seg_indptr"),
-     masked_m=state.pop("masked_m"),
-     expected_m=state.pop("expected_m"),
-     num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-     forward_batch=state.forward_batch,
- )
+ state.hidden_states_experts_output = self.experts.moe_impl(
+     dispatch_output=state.dispatch_output,
+ )
def op_combine_a(self, state):
if self.ep_size > 1:
- self.deepep_dispatcher.combine_a(
+ self.experts.deepep_dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"),
- topk_idx=state.pop("topk_idx_dispatched"),
- topk_weights=state.pop("topk_weights_dispatched"),
+ topk_idx=state.dispatch_output.topk_idx,
+ topk_weights=state.dispatch_output.topk_weights,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
state.pop("dispatch_output")
def op_combine_b(self, state):
if self.ep_size > 1:
- state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-     tbo_subbatch_index=state.get("tbo_subbatch_index"),
+ state.hidden_states_after_combine = (
+     self.experts.deepep_dispatcher.combine_b(
+         tbo_subbatch_index=state.get("tbo_subbatch_index"),
+     )
)
def op_output(self, state):
......
from __future__ import annotations
import dataclasses
import logging
from dataclasses import replace
- from typing import Dict, List, Optional, Sequence, Union
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
import torch
@@ -20,6 +22,9 @@ from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
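Importing DispatchOutput only under TYPE_CHECKING (together with from __future__ import annotations) keeps the annotations below purely static, so the new return-type hints do not introduce a runtime import cycle with the token dispatcher module.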
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
logger = logging.getLogger(__name__)
@@ -802,7 +807,7 @@ class MaybeTboDeepEPDispatcher:
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
- def dispatch(self, **kwargs):
+ def dispatch(self, **kwargs) -> DispatchOutput:
return self._execute("dispatch", **kwargs)
def dispatch_a(self, **kwargs):
@@ -811,7 +816,7 @@ class MaybeTboDeepEPDispatcher:
def dispatch_b(self, **kwargs):
return self._execute("dispatch_b", **kwargs)
- def combine(self, **kwargs):
+ def combine(self, **kwargs) -> torch.Tensor:
return self._execute("combine", **kwargs)
def combine_a(self, **kwargs):
......