Unverified commit bfc3b3f7, authored by Cheng Wan, committed by GitHub

[9/N] MoE Refactor: cleanup dispatcher interfaces (#11847)

parent da5bde4d
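At a glance: this step of the refactor collapses the per-call dispatcher plumbing (topk_idx, topk_weights, forward_batch, input_global_scale) into a single TopKOutput argument plus a dispatcher-owned quant config, and renames moe_impl to run_moe_core. A minimal sketch of the resulting call flow, using names from the diff below (the layer setup itself is assumed):

    import torch
    from sglang.srt.layers.moe.topk import TopKOutput

    def moe_forward(layer, hidden_states: torch.Tensor, topk_output: TopKOutput) -> torch.Tensor:
        # 1) Route tokens; the dispatcher (standard or DeepEP) is chosen once at init time.
        dispatch_output = layer.dispatcher.dispatch(
            hidden_states=hidden_states, topk_output=topk_output
        )
        # 2) Expert computation, factored into run_moe_core so overlap paths can reuse it.
        combine_input = layer.run_moe_core(dispatch_output=dispatch_output)
        # 3) Scatter results back to the original token order.
        return layer.dispatcher.combine(combine_input)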
@@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper:
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
     _global_num_tokens: Optional[List[int]]
+    _is_extend_in_batch: bool

     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper:
     def get_dp_device(cls) -> torch.device:
         return cls._device

+    @classmethod
+    def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
+        cls._is_extend_in_batch = is_extend_in_batch
+
+    @classmethod
+    def get_is_extend_in_batch(cls) -> bool:
+        return cls._is_extend_in_batch
+

 def set_dp_buffer_len(
     global_dp_buffer_len: int,
@@ -188,6 +197,14 @@ def get_dp_device() -> torch.device:
     return _DpGatheredBufferWrapper.get_dp_device()


+def set_is_extend_in_batch(is_extend_in_batch: bool):
+    _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
+
+
+def get_is_extend_in_batch() -> bool:
+    return _DpGatheredBufferWrapper.get_is_extend_in_batch()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
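These accessors exist so dispatchers can resolve DeepEP's auto mode without receiving a ForwardBatch argument. A hedged sketch of the intended producer/consumer split (the producer call site is an assumption; only the accessor names come from this diff):

    from sglang.srt.layers.dp_attention import (
        get_is_extend_in_batch,
        set_is_extend_in_batch,
    )

    def on_new_batch(is_extend_in_batch: bool) -> None:
        # Assumed producer: set once per batch, wherever the ForwardBatch is prepared.
        set_is_extend_in_batch(is_extend_in_batch)

    def pick_deepep_mode() -> str:
        # Consumer inside a dispatcher: prefill (extend) resolves to normal mode,
        # decode to low-latency (behavior of DeepEPMode.resolve assumed here).
        return "normal" if get_is_extend_in_batch() else "low_latency"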
...
@@ -566,7 +566,9 @@ def ep_scatter(
     scale_hidden_size = ceil_div(scale_hidden_size, 4)
     assert m_indices.shape[0] % BLOCK_E == 0
-    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert (
+        recv_x_scale.dtype == output_tensor_scale.dtype
+    ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
     assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
     _fwd_kernel_ep_scatter_1[(grid,)](
...
@@ -20,18 +20,14 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
-from sglang.srt.layers.quantization.modelopt_quant import (
-    CUTEDSL_MOE_NVFP4_DISPATCH,
-    ModelOptNvFp4FusedMoEMethod,
-)
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 from sglang.srt.utils.offloader import get_offloader
@@ -109,23 +105,6 @@ class DeepEPMoE(FusedMoE):
         self.deepep_mode = get_deepep_mode()

-        # TODO: move to the beginning of the file
-        from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-
-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            params_dtype=params_dtype,
-            deepep_mode=self.deepep_mode,
-            async_finish=True,  # TODO
-            return_recv_hook=True,
-        )
-
         if self.deepep_mode.enable_low_latency() and not _is_npu:
             # NPU supports low_latency deepep without deepgemm
             assert (
@@ -165,19 +144,16 @@ class DeepEPMoE(FusedMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
         forward_shared_experts=None,
         alt_stream=None,
         disable_sbo=False,
     ):
         # We have to call SBO inside MoE to be compatible with hooks used in offloading
         return single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
             # SBO args
             experts=self,
             forward_shared_experts=forward_shared_experts,
@@ -188,25 +164,14 @@ class DeepEPMoE(FusedMoE):
     def dispatch(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
-        return self.deepep_dispatcher.dispatch(
+        return self.dispatcher.dispatch(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
-            input_global_scale=(
-                self.w13_input_scale_quant
-                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-                and self.quant_method.enable_flashinfer_cutedsl_moe
-                and CUTEDSL_MOE_NVFP4_DISPATCH
-                else None
-            ),
+            topk_output=topk_output,
         )

-    def moe_impl(
+    def run_moe_core(
         self,
         dispatch_output: DispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
@@ -240,16 +205,14 @@ class DeepEPMoE(FusedMoE):
     def combine(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional[Dict[str, Any]] = None,
     ):
-        return self.deepep_dispatcher.combine(
+        return self.dispatcher.combine(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
-            forward_batch=forward_batch,
             overlap_args=overlap_args,
         )
@@ -257,9 +220,9 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        hidden_states, topk_idx, topk_weights = (
+        hidden_states, topk_ids, topk_weights = (
             dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
+            dispatch_output.topk_ids,
             dispatch_output.topk_weights,
         )
         if hidden_states.shape[0] == 0:
@@ -267,15 +230,15 @@ class DeepEPMoE(FusedMoE):
         # in original deepep, idx == -1 meaning invalid and will not be processed.
         # aiter does not accept -1, we use a expert mask to make these idx invalid
         # (idx == num_local_experts) meaning not used in aiter fused_moe
-        topk_idx_copy = topk_idx.to(torch.int32)
-        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
+        topk_ids_copy = topk_ids.to(torch.int32)
+        topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

         return fused_moe(
             hidden_states,
             self.w13_weight,
             self.w2_weight,
             topk_weights,
-            topk_idx_copy,
+            topk_ids_copy,
             w1_scale=self.w13_weight_scale_inv,
             w2_scale=self.w2_weight_scale_inv,
             quant_type=QuantType.per_128x128,
@@ -291,18 +254,21 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: DeepEPNormalOutput,
     ):
-        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
-            dispatch_output
-        )
-        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+        (
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            num_recv_tokens_per_expert,
+        ) = dispatch_output

         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

         if num_recv_tokens_per_expert is None:
-            return hidden_states_fp8.bfloat16()
+            return hidden_states.bfloat16()
         all_tokens = sum(num_recv_tokens_per_expert)
         if all_tokens <= 0:
-            return hidden_states_fp8.bfloat16()
+            return hidden_states.bfloat16()
-        M, K = hidden_states_fp8.size()
+        M, K = hidden_states.size()
         N = self.w13_weight.size(1)
         scale_block_size = 128
@@ -323,35 +289,35 @@ class DeepEPMoE(FusedMoE):
             ),
         )

-        hidden_states_fp8_shape = hidden_states_fp8.shape
-        hidden_states_fp8_device = hidden_states_fp8.device
-        hidden_states_fp8_dtype = hidden_states_fp8.dtype
+        hidden_states_shape = hidden_states.shape
+        hidden_states_device = hidden_states.device
+        hidden_states_dtype = hidden_states.dtype

         input_tensor = [
             torch.empty(
                 (all_tokens, K),
-                device=hidden_states_fp8.device,
-                dtype=hidden_states_fp8.dtype,
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
             ),
             (
                 # TODO check whether need `zeros`
                 torch.zeros(
                     (ceil_div(K // 128, 4), all_tokens),
-                    device=hidden_states_fp8.device,
+                    device=hidden_states.device,
                     dtype=torch.int,
                 ).transpose(0, 1)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                 else torch.empty(
                     (all_tokens, K // 128),
-                    device=hidden_states_fp8.device,
+                    device=hidden_states.device,
                     dtype=torch.float32,
                 )
             ),
         ]
         m_indices = torch.empty(
-            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+            all_tokens, device=hidden_states.device, dtype=torch.int32
         )
-        output_index = torch.empty_like(topk_idx)
+        output_index = torch.empty_like(topk_ids)

         if get_offloader().forbid_copy_engine_usage:
             num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
@@ -367,9 +333,9 @@ class DeepEPMoE(FusedMoE):
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

         ep_scatter(
-            hidden_states_fp8,
+            hidden_states,
             hidden_states_scale,
-            topk_idx,
+            topk_ids,
             num_recv_tokens_per_expert_gpu,
             expert_start_loc,
             input_tensor[0],
@@ -378,11 +344,11 @@ class DeepEPMoE(FusedMoE):
             output_index,
             scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
-        dispose_tensor(hidden_states_fp8)
+        dispose_tensor(hidden_states)

         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8_device,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
@@ -403,7 +369,7 @@ class DeepEPMoE(FusedMoE):
         del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8_device,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
@@ -425,11 +391,11 @@ class DeepEPMoE(FusedMoE):
         del down_input_fp8, down_input_scale

         gather_out = torch.empty(
-            hidden_states_fp8_shape,
-            device=hidden_states_fp8_device,
+            hidden_states_shape,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
-        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)

         return gather_out
@@ -438,13 +404,13 @@ class DeepEPMoE(FusedMoE):
         dispatch_output: DeepEPLLOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
-        hidden_states, _, _, masked_m, _ = dispatch_output
+        hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

         output = self.quant_method.apply_without_routing_weights(
             layer=self,
-            x=hidden_states,
+            x=(hidden_states, hidden_states_scale),
             masked_m=masked_m,
             moe_runner_config=self.moe_runner_config,
             down_gemm_overlap_args=down_gemm_overlap_args,
@@ -466,25 +432,28 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: DeepEPLLOutput,
     ):
-        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
+        hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
+        assert (
+            hidden_states_scale.dtype == torch.float32
+        ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"

         # GroupGemm-0
-        num_groups, m, k = hidden_states_fp8[0].size()
+        num_groups, m, k = hidden_states.size()
         n = self.w13_weight.size(1)
         expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+            (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            hidden_states_fp8,
+            (hidden_states, hidden_states_scale),
             self.w13_weight_fp8,
             gateup_output,
             masked_m,
             expected_m,
         )
-        dispose_tensor(hidden_states_fp8[0])
+        dispose_tensor(hidden_states)

         # Act
         down_input = torch.empty(
@@ -557,11 +526,9 @@ class DeepEPMoE(FusedMoE):
         def _forward_normal(dispatch_output: DeepEPNormalOutput):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPNormalOutput)
-            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
-
-            if isinstance(hidden_states, tuple):
-                per_token_scale = hidden_states[1]
-                hidden_states = hidden_states[0]
+            hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
+                dispatch_output
+            )

             group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
                 hidden_states.device
@@ -571,7 +538,7 @@ class DeepEPMoE(FusedMoE):
                 hidden_states = torch_npu.npu_grouped_matmul(
                     x=[hidden_states],
                     weight=[self.w13_weight.permute(0, 2, 1)],
-                    # per_token_scale=[per_token_scale],
+                    # per_token_scale=[hidden_states_scale],
                     split_item=2,
                     group_list_type=group_list_type,
                     group_type=0,
@@ -591,7 +558,7 @@ class DeepEPMoE(FusedMoE):
                 )[0]
             else:
                 if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
-                    hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
                         hidden_states
                     )

                 # gmm1: gate_up_proj
@@ -599,7 +566,7 @@ class DeepEPMoE(FusedMoE):
                     x=[hidden_states],
                     weight=[self.w13_weight],
                     scale=[self.w13_weight_scale.to(output_dtype)],
-                    per_token_scale=[per_token_scale],
+                    per_token_scale=[hidden_states_scale],
                     split_item=2,
                     group_list_type=group_list_type,
                     group_type=0,
@@ -631,11 +598,14 @@ class DeepEPMoE(FusedMoE):
         def _forward_ll(dispatch_output: DeepEPLLOutput):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPLLOutput)
-            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
-
-            if isinstance(hidden_states, tuple):
-                per_token_scale = hidden_states[1]
-                hidden_states = hidden_states[0]
+            (
+                hidden_states,
+                hidden_states_scale,
+                topk_ids,
+                topk_weights,
+                group_list,
+                _,
+            ) = dispatch_output

             group_list = group_list.to(torch.int64)
@@ -644,7 +614,7 @@ class DeepEPMoE(FusedMoE):
                 hidden_states = torch_npu.npu_grouped_matmul(
                     x=[hidden_states],
                     weight=[self.w13_weight.permute(0, 2, 1)],
-                    # per_token_scale=[per_token_scale],
+                    # per_token_scale=[hidden_states_scale],
                     split_item=2,
                     group_list_type=group_list_type,
                     group_type=0,
@@ -678,7 +648,7 @@ class DeepEPMoE(FusedMoE):
                 hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
                     x=hidden_states,
                     weight_scale=self.w13_weight_scale.to(torch.float32),
-                    activation_scale=per_token_scale,
+                    activation_scale=hidden_states_scale,
                     bias=None,
                     quant_scale=None,
                     quant_offset=None,
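Taken together, the DeepEPMoE interface is now three phases keyed off one TopKOutput. A hedged sketch of how a caller such as the single-batch-overlap path can drive it (names from this diff; the overlapped work between phases is elided):

    def forward_with_overlap(experts, hidden_states, topk_output):
        # Phase 1: all-to-all dispatch of tokens to expert ranks.
        dispatch_output = experts.dispatch(
            hidden_states=hidden_states, topk_output=topk_output
        )
        # ...shared-expert or other work can overlap here...
        # Phase 2: the grouped expert GEMMs (formerly moe_impl).
        hidden_states = experts.run_moe_core(dispatch_output)
        # Phase 3: all-to-all combine back to the original layout.
        return experts.combine(
            hidden_states=hidden_states,
            topk_ids=topk_output.topk_ids,
            topk_weights=topk_output.topk_weights,
        )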
...
@@ -11,14 +11,19 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
+    get_deepep_mode,
+    get_moe_a2a_backend,
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
+from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
+from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardDispatcher,
     StandardDispatchOutput,
@@ -32,6 +37,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -71,6 +77,27 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     return tile_tokens_dim
+
+def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
+    a2a_backend = get_moe_a2a_backend()
+    if a2a_backend.is_none():
+        return StandardDispatcher(moe_runner_config)
+    elif a2a_backend.is_deepep():
+        return MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=moe_runner_config.top_k,
+            permute_fusion=True,
+            num_experts=moe_runner_config.num_experts,
+            num_local_experts=moe_runner_config.num_local_experts,
+            hidden_size=moe_runner_config.hidden_size,
+            params_dtype=moe_runner_config.params_dtype,
+            deepep_mode=get_deepep_mode(),
+            async_finish=True,
+            return_recv_hook=True,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}")
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
@@ -132,8 +159,6 @@ class FusedMoE(torch.nn.Module):
         self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.expert_map_cpu = None
-        self.expert_map_gpu = None

         enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
@@ -149,19 +174,6 @@ class FusedMoE(torch.nn.Module):
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-        if self.moe_ep_size > 1:
-            # TODO(ch-wan): support shared experts fusion
-            # Create a tensor of size num_experts filled with -1
-            self.expert_map_cpu = torch.full(
-                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
-            )
-            # Create a expert map for the local experts
-            self.expert_map_cpu[
-                self.moe_ep_rank
-                * self.num_local_experts : (self.moe_ep_rank + 1)
-                * self.num_local_experts
-            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")

         assert intermediate_size % self.moe_tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
@@ -219,7 +231,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.quant_method.create_moe_runner(self, self.moe_runner_config)
-        self.dispatcher = StandardDispatcher()
+        self.dispatcher = create_moe_dispatcher(self.moe_runner_config)

         self.should_fuse_routed_scaling_factor_in_topk = isinstance(
             self.quant_method, ModelOptNvFp4FusedMoEMethod
@@ -453,9 +465,12 @@ class FusedMoE(torch.nn.Module):
             expert_data.copy_(loaded_weight)

     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.expert_map_cpu is None:
-            return expert_id
-        return self.expert_map_cpu[expert_id].item()
+        start_idx = self.moe_ep_rank * self.num_local_experts
+        end_idx = (self.moe_ep_rank + 1) * self.num_local_experts
+        if start_idx <= expert_id < end_idx:
+            return expert_id - start_idx
+        else:
+            return -1

     def weight_loader(
         self,
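The per-layer expert_map tensors are replaced by pure rank arithmetic. For example, with num_local_experts = 32 on moe_ep_rank = 1, the rank owns global experts 32..63, so global expert 40 maps to local index 8 and expert 100 to -1 (not local). A self-contained check of the same arithmetic:

    def map_global_to_local(expert_id: int, ep_rank: int, num_local: int) -> int:
        # Mirrors _map_global_expert_id_to_local_expert_id above.
        start, end = ep_rank * num_local, (ep_rank + 1) * num_local
        return expert_id - start if start <= expert_id < end else -1

    assert map_global_to_local(40, ep_rank=1, num_local=32) == 8
    assert map_global_to_local(100, ep_rank=1, num_local=32) == -1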
@@ -804,32 +819,18 @@ class FusedMoE(torch.nn.Module):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None

-        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
-            if self.expert_map_cpu is not None and self.expert_map_gpu is None:
-                # If we are in EP mode, we need to move the expert map to GPU.
-                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
-
-            if self.expert_map_gpu is not None:
-                if TopKOutputChecker.format_is_standard(topk_output):
-                    topk_output = topk_output._replace(
-                        topk_ids=self.expert_map_gpu[topk_output.topk_ids]
-                    )
-                elif TopKOutputChecker.format_is_triton_kernel(topk_output):
-                    raise NotImplementedError()
-
         dispatch_output = self.dispatcher.dispatch(
             hidden_states=hidden_states, topk_output=topk_output
         )

-        # TODO: consider using symmetric memory
-        combine_input = self.quant_method.apply(
-            layer=self,
+        combine_input = self.run_moe_core(
             dispatch_output=dispatch_output,
             **kwargs,
         )

         final_hidden_states = self.dispatcher.combine(combine_input)

-        # TODO: should we add some conditions here?
         final_hidden_states = final_hidden_states[
             ..., :origin_hidden_states_dim
         ].contiguous()
@@ -839,6 +840,14 @@ class FusedMoE(torch.nn.Module):
         return final_hidden_states

+    def run_moe_core(self, dispatch_output: DispatchOutput, **kwargs) -> CombineInput:
+        # TODO: consider using symmetric memory
+        return self.quant_method.apply(
+            layer=self,
+            dispatch_output=dispatch_output,
+            **kwargs,
+        )
+
     @classmethod
     def make_expert_params_mapping(
         cls,
...
@@ -23,6 +23,7 @@ from sglang.srt.layers.moe.token_dispatcher.mooncake import (
 )
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardCombineInput,
+    StandardDispatcher,
     StandardDispatchOutput,
 )
@@ -38,6 +39,7 @@ __all__ = [
     "MooncakeCombineInput",
     "MooncakeDispatchOutput",
     "MooncakeEPDispatcher",
+    "StandardDispatcher",
     "StandardDispatchOutput",
     "StandardCombineInput",
     "DeepEPConfig",
...
@@ -73,7 +73,7 @@ class DispatchOutputFormat(Enum):
 class DispatchOutput(Protocol):
     """Protocol for dispatch outputs in different formats."""

-    # TODO: add hidden_states to the protocol
+    hidden_states: torch.Tensor

     @property
     def format(self) -> DispatchOutputFormat: ...
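Promoting hidden_states from a TODO comment to a required attribute means any DispatchOutput can be inspected generically (e.g. for shape or device) without knowing its concrete type. A minimal sketch of a conforming implementation (a hypothetical type; the real ones are the NamedTuples elsewhere in this PR, and the enum member name is assumed):

    from typing import NamedTuple
    import torch

    class MyDispatchOutput(NamedTuple):  # hypothetical example
        hidden_states: torch.Tensor  # now required by the DispatchOutput protocol
        topk_ids: torch.Tensor
        topk_weights: torch.Tensor

        @property
        def format(self) -> DispatchOutputFormat:
            return DispatchOutputFormat.STANDARD  # member name assumed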
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers import deep_gemm_wrapper
+from sglang.srt.layers.dp_attention import get_is_extend_in_batch
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
@@ -15,6 +16,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import (
     DeepEPMode,
     get_deepep_config,
@@ -51,8 +53,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist

-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()

 logger = logging.getLogger(__name__)
@@ -61,9 +61,9 @@ logger = logging.getLogger(__name__)
 class DeepEPNormalOutput(NamedTuple):
     """DeepEP normal dispatch output."""

-    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
-    # hidden_states_scale
-    topk_idx: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: Optional[torch.Tensor]
+    topk_ids: torch.Tensor
     topk_weights: torch.Tensor
     num_recv_tokens_per_expert: List[int]

@@ -75,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple):
 class DeepEPLLOutput(NamedTuple):
     """DeepEP low latency dispatch output."""

-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: Optional[torch.Tensor]
+    topk_ids: torch.Tensor
     topk_weights: torch.Tensor
     masked_m: torch.Tensor
     expected_m: int
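With the scale split into its own field, consumers do a single flat unpack instead of special-casing a nested (tensor, scale) pair; the scale is simply None when dispatch ran in BF16. A small sketch of the consumer-side pattern this enables (an assumed helper, not part of the PR):

    def split_normal_output(out: DeepEPNormalOutput):
        # One flat unpack; hidden_states_scale is None for BF16 dispatch.
        hidden_states, hidden_states_scale, topk_ids, topk_weights, num_recv = out
        x = (
            hidden_states
            if hidden_states_scale is None
            else (hidden_states, hidden_states_scale)
        )
        return x, topk_ids, topk_weights, num_recv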
@@ -314,9 +315,7 @@ class _DeepEPDispatcherImplBase:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
         raise NotImplementedError

@@ -326,7 +325,7 @@ class _DeepEPDispatcherImplBase:
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
@@ -345,15 +344,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         self.async_finish = async_finish
         self.src2dst = None
+        self.quant_config = {}

     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
-        topk_idx = topk_idx.to(torch.int64)
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+        topk_ids = topk_ids.to(torch.int64)
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and not get_moe_runner_backend().is_cutlass()
@@ -367,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
             )
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, previous_event
+        return hidden_states, topk_ids, topk_weights, previous_event

-    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
+    def dispatch_b(self, hidden_states, topk_ids, topk_weights, previous_event):
         (
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             num_recv_tokens_per_expert,
             event,
-        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
+
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_scale = hidden_states
+        else:
+            hidden_states_scale = None
+
         return DeepEPNormalOutput(
-            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            num_recv_tokens_per_expert,
         )

     def _dispatch_core(
         self,
         x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         previous_event,
     ):
@@ -397,7 +406,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             is_token_in_rank,
             previous_event,
         ) = buffer.get_dispatch_layout(
-            topk_idx,
+            topk_ids,
             self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
@@ -409,14 +418,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         (
             recv_x,
-            recv_topk_idx,
+            recv_topk_ids,
             recv_topk_weights,
             num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
             x,
-            topk_idx=topk_idx,
+            topk_idx=topk_ids,
             topk_weights=topk_weights,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
@@ -437,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return (
             recv_x,
-            recv_topk_idx,
+            recv_topk_ids,
             recv_topk_weights,
             num_recv_tokens_per_expert,
             event,
@@ -446,40 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
-        from sglang.srt.layers.moe.ep_moe.kernels import (
-            deepep_post_reorder_triton_kernel,
-        )
-
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
             output = hidden_states
         else:
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    topk_idx,
-                    topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
+            raise NotImplementedError()  # triton runner was supported but it's temporarily disabled
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event

@@ -514,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             self.num_experts,
         )

+    def set_quant_config(self, quant_config: dict):
+        self.quant_config = quant_config
+
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
@@ -525,28 +513,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         """
         self.return_recv_hook = return_recv_hook
         self.device_module = torch.get_device_module()
+        self.quant_config = {}

     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
         buffer = self._get_buffer()
-        topk_idx = topk_idx.to(torch.int64)
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+        topk_ids = topk_ids.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
             + self.num_experts
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
-            input_global_scale,
-            topk_idx,
+            topk_ids,
         )
         return (
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
@@ -557,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def dispatch_b(
         self,
         hidden_states,
-        topk_idx,
+        topk_ids,
         topk_weights,
         masked_m,
         expected_m,
@@ -570,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )

+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_scale = hidden_states
+        else:
+            hidden_states_scale = None
+
         deepep_output = DeepEPLLOutput(
             hidden_states,
-            topk_idx,
+            hidden_states_scale,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
@@ -582,10 +575,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
     ):
         use_nvfp4 = use_fp8 = False
+        input_global_scale = self.quant_config.get("input_global_scale", None)
         if input_global_scale is not None:
             use_nvfp4 = True
         elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
@@ -595,7 +588,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
                 hidden_states,
-                topk_idx,
+                topk_ids,
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
                 use_fp8=use_fp8,
@@ -618,13 +611,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             overlap_args=overlap_args,
         )
@@ -644,7 +637,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def _combine_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
@@ -658,7 +651,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         with ctx:
             combined_hidden_states, event, hook = buffer.low_latency_combine(
                 x=hidden_states,
-                topk_idx=topk_idx,
+                topk_idx=topk_ids,
                 topk_weights=topk_weights,
                 handle=self.handle,
                 async_finish=not self.return_recv_hook,
@@ -688,6 +681,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             self.num_experts,
         )

+    def set_quant_config(self, quant_config: dict):
+        self.quant_config = quant_config
+

 @dataclass

 class _Stage(Enum):
@@ -745,25 +741,20 @@ class DeepEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_batch).dispatch_a(
+        inner_state = self._get_impl().dispatch_a(
             hidden_states=hidden_states,
-            input_global_scale=input_global_scale,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
+            topk_output=topk_output,
         )
-        self._dispatch_intermediate_state = forward_batch, inner_state
+        self._dispatch_intermediate_state = inner_state

     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_batch, inner_state = self._dispatch_intermediate_state
+        inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_batch).dispatch_b(*inner_state)
+        return self._get_impl().dispatch_b(*inner_state)

     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -773,30 +764,28 @@ class DeepEPDispatcher(BaseDispatcher):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional["CombineOverlapArgs"] = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_batch).combine_a(
+        inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
             overlap_args=overlap_args,
         )
-        self._combine_intermediate_state = forward_batch, inner_state
+        self._combine_intermediate_state = inner_state

     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_batch, inner_state = self._combine_intermediate_state
+        inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_batch).combine_b(*inner_state)
+        return self._get_impl().combine_b(*inner_state)

-    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
+    def _get_impl(self) -> _DeepEPDispatcherImplBase:
+        is_extend_in_batch = get_is_extend_in_batch()
+        resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
         if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
@@ -807,3 +796,9 @@ class DeepEPDispatcher(BaseDispatcher):
     def _update_stage(self, old_stage, new_stage):
         assert self._stage == old_stage
         self._stage = new_stage
+
+    def set_quant_config(self, quant_config: dict):
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher.set_quant_config(quant_config)
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher.set_quant_config(quant_config)
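set_quant_config replaces the old per-call input_global_scale parameter: quantization settings are pushed into the dispatcher once and read lazily inside _dispatch_core. A hedged sketch of the producer side (the exact NVFP4 call site is an assumption; the dict key comes from this diff):

    def configure_nvfp4_dispatch(layer) -> None:
        # Assumed producer: run once after the quant method is initialized,
        # instead of recomputing the scale on every dispatch() call.
        layer.dispatcher.set_quant_config(
            {"input_global_scale": layer.w13_input_scale_quant}
        )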
...
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from typing import NamedTuple, Optional, Tuple

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import get_is_extend_in_batch
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
@@ -12,6 +13,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.utils import get_int_env_var
@@ -27,16 +29,15 @@ from enum import Enum, auto
 import torch
 import torch.distributed as dist

-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-
 logger = logging.getLogger(__name__)


 class MooncakeDispatchOutput(NamedTuple):
     """Mooncake EP dispatch output."""

-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    topk_ids: torch.Tensor
     topk_weights: torch.Tensor
     masked_m: torch.Tensor
     expected_m: int
@@ -164,23 +165,23 @@ class _MooncakeEPDispatcherImpl:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
+        topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
         buffer = self._get_buffer()
-        topk_idx = topk_idx.to(torch.int64)
+        topk_ids = topk_ids.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
             + self.num_experts
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             use_fp8=True,
         )
         return (
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
@@ -191,7 +192,7 @@ class _MooncakeEPDispatcherImpl:
     def dispatch_b(
         self,
         hidden_states,
-        topk_idx,
+        topk_ids,
         topk_weights,
         masked_m,
         expected_m,
@@ -206,7 +207,7 @@ class _MooncakeEPDispatcherImpl:
         return MooncakeDispatchOutput(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
@@ -215,14 +216,14 @@ class _MooncakeEPDispatcherImpl:
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         use_fp8: bool = False,
     ):
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.dispatch(
                 hidden_states,
-                topk_idx,
+                topk_ids,
                 self.active_ranks,
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
@@ -237,12 +238,12 @@ class _MooncakeEPDispatcherImpl:
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
         )
         return hidden_states, event, hook
@@ -254,13 +255,13 @@ class _MooncakeEPDispatcherImpl:
     def _combine_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
         buffer = self._get_buffer()
         combined_hidden_states, event, hook = buffer.combine(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             self.active_ranks,
             -1 if self.first_execution else self.timeout_us,
...@@ -332,24 +333,20 @@ class MooncakeEPDispatcher(BaseDispatcher): ...@@ -332,24 +333,20 @@ class MooncakeEPDispatcher(BaseDispatcher):
def dispatch_a( def dispatch_a(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_global_scale: Optional[torch.Tensor], topk_output: TopKOutput,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
): ):
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A) self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
inner_state = self._get_impl(forward_batch).dispatch_a( inner_state = self._get_impl().dispatch_a(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_output=topk_output,
topk_weights=topk_weights,
) )
self._dispatch_intermediate_state = forward_batch, inner_state self._dispatch_intermediate_state = inner_state
def dispatch_b(self): def dispatch_b(self):
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B) self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
forward_batch, inner_state = self._dispatch_intermediate_state inner_state = self._dispatch_intermediate_state
del self._dispatch_intermediate_state del self._dispatch_intermediate_state
return self._get_impl(forward_batch).dispatch_b(*inner_state) return self._get_impl().dispatch_b(*inner_state)
def combine(self, *args, **kwargs) -> Tuple: def combine(self, *args, **kwargs) -> Tuple:
self.combine_a(*args, **kwargs) self.combine_a(*args, **kwargs)
...@@ -359,29 +356,27 @@ class MooncakeEPDispatcher(BaseDispatcher): ...@@ -359,29 +356,27 @@ class MooncakeEPDispatcher(BaseDispatcher):
def combine_a( def combine_a(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_idx: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
overlap_args: Optional = None, overlap_args: Optional = None,
): ):
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl(forward_batch).combine_a( inner_state = self._get_impl().combine_a(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_ids=topk_ids,
topk_weights=topk_weights, topk_weights=topk_weights,
) )
self._combine_intermediate_state = forward_batch, inner_state self._combine_intermediate_state = inner_state
def combine_b(self): def combine_b(self):
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL) self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
forward_batch, inner_state = self._combine_intermediate_state inner_state = self._combine_intermediate_state
del self._combine_intermediate_state del self._combine_intermediate_state
return self._get_impl(forward_batch).combine_b(*inner_state) return self._get_impl().combine_b(*inner_state)
def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl: def _get_impl(self) -> _MooncakeEPDispatcherImpl:
resolved_deepep_mode = self.deepep_mode.resolve( is_extend_in_batch = get_is_extend_in_batch()
forward_batch.is_extend_in_batch resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
)
if resolved_deepep_mode == DeepEPMode.NORMAL: if resolved_deepep_mode == DeepEPMode.NORMAL:
raise NotImplementedError raise NotImplementedError
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY: elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
...@@ -392,3 +387,6 @@ class MooncakeEPDispatcher(BaseDispatcher): ...@@ -392,3 +387,6 @@ class MooncakeEPDispatcher(BaseDispatcher):
def _update_stage(self, old_stage, new_stage): def _update_stage(self, old_stage, new_stage):
assert self._stage == old_stage assert self._stage == old_stage
self._stage = new_stage self._stage = new_stage
def set_quant_config(self, quant_config: dict):
pass
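For readers tracking the interface change: the dispatcher's split-phase protocol is easiest to see stripped of the communication kernels. Below is a minimal sketch of the stage machine and intermediate-state handoff, using only names visible in this diff (`_Stage`, `dispatch_a`/`dispatch_b`, `_dispatch_intermediate_state`); the Mooncake buffers and `TopKOutput` plumbing are elided.

```python
# Minimal sketch of the staged A/B dispatch protocol enforced above.
# Only the stage machine and state handoff are shown; the real class
# drives Mooncake buffers and returns a MooncakeDispatchOutput.
from enum import Enum, auto


class _Stage(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()


class StagedDispatcherSketch:
    def __init__(self):
        self._stage = _Stage.INITIAL

    def _update_stage(self, old_stage, new_stage):
        assert self._stage == old_stage
        self._stage = new_stage

    def dispatch_a(self, hidden_states, topk_output):
        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
        # Start communication, then stash whatever dispatch_b needs.
        self._dispatch_intermediate_state = (hidden_states, topk_output)

    def dispatch_b(self):
        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
        inner_state = self._dispatch_intermediate_state
        del self._dispatch_intermediate_state  # drop references promptly
        return inner_state
```

The key simplification in the diff is that `forward_batch` no longer rides along in the intermediate state: `_get_impl()` now reads `is_extend_in_batch` from the DP-attention global instead of taking a `ForwardBatch` argument.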
...@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple ...@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple
import torch import torch
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
)
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher.base import ( from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher, BaseDispatcher,
CombineInput, CombineInput,
...@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import ( ...@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
DispatchOutput, DispatchOutput,
DispatchOutputFormat, DispatchOutputFormat,
) )
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.moe.utils import get_moe_runner_backend
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
...@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput) ...@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)
class StandardDispatcher(BaseDispatcher): class StandardDispatcher(BaseDispatcher):
def __init__(self, moe_runner_config: MoeRunnerConfig):
self.moe_ep_size = get_moe_expert_parallel_world_size()
self.enable_flashinfer_cutlass_moe = (
get_moe_runner_backend().is_flashinfer_cutlass()
)
self.num_experts = moe_runner_config.num_experts
self.num_local_experts = moe_runner_config.num_local_experts
self.moe_ep_rank = get_moe_expert_parallel_rank()
self.local_expert_mapping = None
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, topk_output: TopKOutput self, hidden_states: torch.Tensor, topk_output: TopKOutput
) -> DispatchOutput: ) -> DispatchOutput:
if (
self.moe_ep_size > 1
and not self.enable_flashinfer_cutlass_moe
and TopKOutputChecker.format_is_standard(topk_output)
):
if self.local_expert_mapping is None:
self.local_expert_mapping = torch.full(
(self.num_experts,), -1, dtype=torch.int32, device="cuda"
)
self.local_expert_mapping[
self.moe_ep_rank
* self.num_local_experts : (self.moe_ep_rank + 1)
* self.num_local_experts
] = torch.arange(
0, self.num_local_experts, dtype=torch.int32, device="cuda"
)
if self.local_expert_mapping is not None:
if TopKOutputChecker.format_is_standard(topk_output):
topk_output = topk_output._replace(
topk_ids=self.local_expert_mapping[topk_output.topk_ids]
)
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
raise NotImplementedError()
return StandardDispatchOutput( return StandardDispatchOutput(
hidden_states=hidden_states, topk_output=topk_output hidden_states=hidden_states, topk_output=topk_output
) )
...@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher): ...@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
# TODO: this branch should be removed in the future # TODO: this branch should be removed in the future
assert isinstance(combine_input, torch.Tensor) assert isinstance(combine_input, torch.Tensor)
return combine_input return combine_input
def set_quant_config(self, quant_config: dict):
pass
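The new expert-parallel branch in `StandardDispatcher` remaps global expert ids to rank-local ones and marks experts owned by other ranks with -1. A self-contained sketch of that remapping with toy sizes (CPU tensors here for portability; the diff allocates the mapping on `"cuda"`):

```python
import torch

# Toy EP config: 8 global experts, 4 per rank; this rank owns experts 4..7.
num_experts, num_local_experts, moe_ep_rank = 8, 4, 1

local_expert_mapping = torch.full((num_experts,), -1, dtype=torch.int32)
local_expert_mapping[
    moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts
] = torch.arange(0, num_local_experts, dtype=torch.int32)

topk_ids = torch.tensor([[0, 5], [4, 7]])  # global ids chosen by routing
print(local_expert_mapping[topk_ids])
# tensor([[-1,  1],
#         [ 0,  3]], dtype=torch.int32)  # -1 = expert lives on another rank
```

Because the standard `TopKOutput` is a `NamedTuple`, the dispatcher can swap the remapped ids in with `_replace` without touching the weights or router logits.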
...@@ -365,9 +365,10 @@ class TopK(CustomOp): ...@@ -365,9 +365,10 @@ class TopK(CustomOp):
def empty_topk_output(self, device: torch.device) -> TopKOutput: def empty_topk_output(self, device: torch.device) -> TopKOutput:
topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device) topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
# FIXME: router_logits should be of size (0, num_experts)
router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
return StandardTopKOutput(topk_weights, topk_idx, router_logits) return StandardTopKOutput(topk_weights, topk_ids, router_logits)
# ------------------------------- TopK implementation ------------------------------------- # ------------------------------- TopK implementation -------------------------------------
......
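Besides the `topk_idx` → `topk_ids` rename, `empty_topk_output` is worth a shape check: it returns zero-row tensors so ranks that receive no tokens still hand a well-formed `TopKOutput` to the dispatcher. A quick sketch of the shapes (the in-diff FIXME notes `router_logits` should really be `(0, num_experts)`):

```python
import torch

top_k = 8  # toy stand-in for topk_config.top_k - num_fused_shared_experts
topk_weights = torch.empty((0, top_k), dtype=torch.float32)
topk_ids = torch.full((0, top_k), -1, dtype=torch.int32)
router_logits = torch.empty((0, top_k), dtype=torch.float32)  # FIXME: (0, num_experts)

# Zero tokens, but every downstream consumer still sees valid 2-D tensors.
assert topk_ids.shape == (0, top_k) and topk_weights.numel() == 0
```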
...@@ -1244,6 +1244,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -1244,6 +1244,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
(1 / w2_input_scale).to(torch.float32), requires_grad=False (1 / w2_input_scale).to(torch.float32), requires_grad=False
) )
layer.dispatcher.set_quant_config(
{"input_global_scale": layer.w13_input_scale_quant}
)
# Validate weight scales # Validate weight scales
for name, weight_scale in [ for name, weight_scale in [
("w13", layer.w13_weight_scale), ("w13", layer.w13_weight_scale),
......
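`set_quant_config` is the new hook that lets a quantization method push per-layer scales into whatever dispatcher the layer owns, instead of threading `input_global_scale` through every `dispatch_a` call. A minimal sketch of the contract, assuming only the dict key visible in this diff; dispatchers with nothing to configure (Mooncake and Standard above) simply `pass`:

```python
import torch


class QuantAwareDispatcherSketch:
    """Sketch of a dispatcher that consumes the pushed quant config."""

    def __init__(self):
        self.input_global_scale = None

    def set_quant_config(self, quant_config: dict):
        # Same key ModelOptNvFp4FusedMoEMethod passes above.
        self.input_global_scale = quant_config.get("input_global_scale")


dispatcher = QuantAwareDispatcherSketch()
dispatcher.set_quant_config({"input_global_scale": torch.tensor([1.0])})
assert dispatcher.input_global_scale is not None
```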
...@@ -339,7 +339,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ...@@ -339,7 +339,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
hidden_states, topk_idx, topk_weights = ( hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states, dispatch_output.hidden_states,
dispatch_output.topk_idx, dispatch_output.topk_ids,
dispatch_output.topk_weights, dispatch_output.topk_weights,
) )
if isinstance(hidden_states, tuple): if isinstance(hidden_states, tuple):
......
...@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import ( ...@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
set_dp_buffer_len, set_dp_buffer_len,
set_is_extend_in_batch,
) )
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
...@@ -639,6 +640,7 @@ class CudaGraphRunner: ...@@ -639,6 +640,7 @@ class CudaGraphRunner:
# Clean intermediate result cache for DP attention # Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens) set_dp_buffer_len(global_dp_buffer_len, num_tokens)
set_is_extend_in_batch(False)
kwargs = {} kwargs = {}
if ( if (
......
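Pinning `set_is_extend_in_batch(False)` during graph capture and replay makes the dispatcher's mode resolution deterministic: CUDA graphs are captured for decode-shaped batches, and the flag steers `DeepEPMode.resolve`. A hedged sketch of the assumed resolution semantics (the real enum lives outside this diff; AUTO picking NORMAL for extend batches and LOW_LATENCY otherwise is an assumption consistent with the `_get_impl` branch above):

```python
from enum import Enum


class DeepEPModeSketch(Enum):
    NORMAL = "normal"            # prefill/extend-friendly dispatch
    LOW_LATENCY = "low_latency"  # masked, decode-friendly dispatch
    AUTO = "auto"

    def resolve(self, is_extend_in_batch: bool) -> "DeepEPModeSketch":
        # Assumed semantics: fixed modes resolve to themselves; AUTO uses
        # the batch flag to pick between the prefill and decode paths.
        if self is not DeepEPModeSketch.AUTO:
            return self
        return (
            DeepEPModeSketch.NORMAL
            if is_extend_in_batch
            else DeepEPModeSketch.LOW_LATENCY
        )
```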
...@@ -44,6 +44,7 @@ from sglang.srt.layers.dp_attention import ( ...@@ -44,6 +44,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_rank, get_attention_dp_rank,
get_attention_tp_size, get_attention_tp_size,
set_dp_buffer_len, set_dp_buffer_len,
set_is_extend_in_batch,
) )
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
...@@ -688,6 +689,7 @@ class ForwardBatch: ...@@ -688,6 +689,7 @@ class ForwardBatch:
self.global_dp_buffer_len = buffer_len self.global_dp_buffer_len = buffer_len
set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens) set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
set_is_extend_in_batch(self.is_extend_in_batch)
bs = self.batch_size bs = self.batch_size
......
...@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import ( ...@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
set_dp_buffer_len, set_dp_buffer_len,
set_is_extend_in_batch,
) )
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache from sglang.srt.layers.torchao_utils import save_gemlite_cache
...@@ -377,6 +378,9 @@ class PiecewiseCudaGraphRunner: ...@@ -377,6 +378,9 @@ class PiecewiseCudaGraphRunner:
# Clean intermediate result cache for DP attention # Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens) set_dp_buffer_len(global_dp_buffer_len, num_tokens)
# FIXME: this is hacky. `is_extend_in_batch` determines the deepep mode.
# It is True in this context, but we force it to False so the low-latency deepep mode is used.
set_is_extend_in_batch(False)
kwargs = {} kwargs = {}
with set_forward_context(forward_batch, self.attention_layers): with set_forward_context(forward_batch, self.attention_layers):
......
...@@ -380,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module): ...@@ -380,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module):
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
topk_weights, topk_idx, _ = self.topk( topk_output = self.topk(
hidden_states, hidden_states,
router_logits, router_logits,
num_token_non_padded=forward_batch.num_token_non_padded, num_token_non_padded=forward_batch.num_token_non_padded,
...@@ -389,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module): ...@@ -389,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module):
), ),
) )
else: else:
topk_idx = torch.full( topk_output = self.topk.empty_topk_output(hidden_states.device)
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self.ep_size > 1:
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
forward_batch=forward_batch,
)
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_output=topk_output,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_batch=forward_batch,
) )
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states,
topk_idx,
topk_weights,
forward_batch=forward_batch,
)
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None: if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states += shared_output
return final_hidden_states return final_hidden_states
......
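With dispatch and combine folded into the experts module, the Bailing sparse-MoE forward collapses to the shape below; the same pattern repeats in the DeepSeek-V2, GLM-4, Qwen2-MoE, and Qwen3-MoE hunks that follow. A sketch with the collaborating modules passed in explicitly (the real block builds them in `__init__`; `gate`, `topk`, and `experts` are stand-ins):

```python
import torch
from torch import nn


class SparseMoeBlockSketch(nn.Module):
    """Sketch of the post-refactor forward path; modules are stand-ins."""

    def __init__(self, gate, topk, experts, shared_experts=None):
        super().__init__()
        self.gate = gate
        self.topk = topk
        self.experts = experts
        self.shared_experts = shared_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        shared_output = (
            self.shared_experts(hidden_states) if self.shared_experts else None
        )
        if hidden_states.shape[0] > 0:
            topk_output = self.topk(hidden_states, router_logits)
        else:
            topk_output = self.topk.empty_topk_output(hidden_states.device)
        # EP dispatch/combine now happen inside self.experts via its dispatcher.
        final_hidden_states = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )
        if shared_output is not None:
            final_hidden_states += shared_output
        return final_hidden_states
```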
...@@ -74,7 +74,6 @@ from sglang.srt.layers.linear import ( ...@@ -74,7 +74,6 @@ from sglang.srt.layers.linear import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import ( from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend, get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather, should_use_flashinfer_cutlass_moe_fp4_allgather,
should_use_flashinfer_trtllm_moe, should_use_flashinfer_trtllm_moe,
...@@ -112,10 +111,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader ...@@ -112,10 +111,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args from sglang.srt.server_args import get_global_server_args
from sglang.srt.single_batch_overlap import SboFlags from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.two_batch_overlap import ( from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
)
from sglang.srt.utils import ( from sglang.srt.utils import (
BumpAllocator, BumpAllocator,
LazyValue, LazyValue,
...@@ -649,19 +645,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -649,19 +645,6 @@ class DeepseekV2MoE(nn.Module):
else None else None
) )
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=get_deepep_mode(),
async_finish=True,
return_recv_hook=True,
)
self._enable_a2a_moe = ( self._enable_a2a_moe = (
get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
) )
...@@ -874,7 +857,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -874,7 +857,7 @@ class DeepseekV2MoE(nn.Module):
router_logits = self.gate(hidden_states) router_logits = self.gate(hidden_states)
if not self._fuse_shared_experts_inside_sbo: if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(hidden_states) shared_output = self._forward_shared_experts(hidden_states)
topk_weights, topk_idx, _ = self.topk( topk_output = self.topk(
hidden_states, hidden_states,
router_logits, router_logits,
num_token_non_padded=forward_batch.num_token_non_padded, num_token_non_padded=forward_batch.num_token_non_padded,
...@@ -883,9 +866,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -883,9 +866,7 @@ class DeepseekV2MoE(nn.Module):
), ),
) )
else: else:
topk_weights, topk_idx, _ = self.topk.empty_topk_output( topk_output = self.topk.empty_topk_output(hidden_states.device)
hidden_states.device
)
if self._fuse_shared_experts_inside_sbo: if self._fuse_shared_experts_inside_sbo:
shared_output = None shared_output = None
...@@ -896,9 +877,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -896,9 +877,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_output=topk_output,
topk_weights=topk_weights,
forward_batch=forward_batch,
**( **(
dict( dict(
forward_shared_experts=_forward_shared_experts_and_put_results, forward_shared_experts=_forward_shared_experts_and_put_results,
...@@ -960,7 +939,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -960,7 +939,7 @@ class DeepseekV2MoE(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer( with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id self.layer_id
): ):
state.topk_weights_local, state.topk_idx_local, _ = self.topk( state.topk_output = self.topk(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
num_token_non_padded=state.forward_batch.num_token_non_padded, num_token_non_padded=state.forward_batch.num_token_non_padded,
...@@ -969,21 +948,13 @@ class DeepseekV2MoE(nn.Module): ...@@ -969,21 +948,13 @@ class DeepseekV2MoE(nn.Module):
), ),
) )
else: else:
state.topk_idx_local = torch.full( state.topk_output = self.topk.empty_topk_output(hidden_states.device)
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
def op_dispatch_a(self, state): def op_dispatch_a(self, state):
if self.ep_size > 1: if self.ep_size > 1:
self.experts.deepep_dispatcher.dispatch_a( self.experts.dispatcher.dispatch_a(
hidden_states=state.hidden_states_mlp_input, hidden_states=state.hidden_states_mlp_input,
input_global_scale=None, topk_output=state.pop("topk_output"),
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"), tbo_subbatch_index=state.get("tbo_subbatch_index"),
) )
...@@ -992,32 +963,29 @@ class DeepseekV2MoE(nn.Module): ...@@ -992,32 +963,29 @@ class DeepseekV2MoE(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer( with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id self.layer_id
): ):
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b( state.dispatch_output = self.experts.dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"), tbo_subbatch_index=state.get("tbo_subbatch_index"),
) )
def op_experts(self, state): def op_experts(self, state):
state.hidden_states_experts_output = self.experts.moe_impl( state.hidden_states_experts_output = self.experts.run_moe_core(
dispatch_output=state.dispatch_output, dispatch_output=state.dispatch_output,
) )
def op_combine_a(self, state): def op_combine_a(self, state):
if self.ep_size > 1: if self.ep_size > 1:
self.experts.deepep_dispatcher.combine_a( self.experts.dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"), hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.dispatch_output.topk_idx, topk_ids=state.dispatch_output.topk_ids,
topk_weights=state.dispatch_output.topk_weights, topk_weights=state.dispatch_output.topk_weights,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"), tbo_subbatch_index=state.get("tbo_subbatch_index"),
) )
state.pop("dispatch_output") state.pop("dispatch_output")
def op_combine_b(self, state): def op_combine_b(self, state):
if self.ep_size > 1: if self.ep_size > 1:
state.hidden_states_after_combine = ( state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
self.experts.deepep_dispatcher.combine_b( tbo_subbatch_index=state.get("tbo_subbatch_index"),
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
) )
def op_output(self, state): def op_output(self, state):
......
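For the two-batch-overlap path, the ops now reach the dispatcher through `self.experts.dispatcher` and hand tensors off via the shared state object, popping them as soon as the next op no longer needs them. A sketch of the op sequence, using only the op names visible in this diff (the gating op that fills `state.topk_output` is elided):

```python
def run_moe_tbo_ops(block, state):
    # state.topk_output was produced by the preceding gating op above.
    block.op_dispatch_a(state)  # pops topk_output, kicks off all-to-all
    block.op_dispatch_b(state)  # waits on comm; fills state.dispatch_output
    block.op_experts(state)     # run_moe_core over the dispatch output
    block.op_combine_a(state)   # pops experts output, then dispatch_output
    block.op_combine_b(state)   # fills state.hidden_states_after_combine
    block.op_output(state)      # remaining output handling (body outside this hunk)
```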
...@@ -27,7 +27,6 @@ from sglang.srt.distributed import ( ...@@ -27,7 +27,6 @@ from sglang.srt.distributed import (
get_pp_group, get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
...@@ -49,7 +48,7 @@ from sglang.srt.layers.linear import ( ...@@ -49,7 +48,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
...@@ -71,7 +70,6 @@ from sglang.srt.models.deepseek_v2 import ( ...@@ -71,7 +70,6 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2MoE, DeepseekV2MoE,
) )
from sglang.srt.server_args import get_global_server_args from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import ( from sglang.srt.utils import (
BumpAllocator, BumpAllocator,
LazyValue, LazyValue,
...@@ -477,19 +475,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -477,19 +475,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
else None else None
) )
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=get_deepep_mode(),
async_finish=True,
return_recv_hook=True,
)
self._enable_a2a_moe = ( self._enable_a2a_moe = (
get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
) )
......
...@@ -219,7 +219,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -219,7 +219,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
shared_output = self._forward_shared_experts(hidden_states) shared_output = self._forward_shared_experts(hidden_states)
topk_weights, topk_idx, _ = self.topk( topk_output = self.topk(
hidden_states, hidden_states,
router_logits, router_logits,
num_token_non_padded=forward_batch.num_token_non_padded, num_token_non_padded=forward_batch.num_token_non_padded,
...@@ -228,14 +228,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -228,14 +228,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
), ),
) )
else: else:
topk_weights, topk_idx, _ = self.topk.empty_topk_output( topk_output = self.topk.empty_topk_output(hidden_states.device)
hidden_states.device
)
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_output=topk_output,
topk_weights=topk_weights,
forward_batch=forward_batch,
) )
if shared_output is not None: if shared_output is not None:
......
...@@ -180,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -180,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
if hidden_states.shape[0] > 0: if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
topk_weights, topk_idx, _ = self.topk( topk_output = self.topk(
hidden_states, hidden_states,
router_logits, router_logits,
num_token_non_padded=forward_batch.num_token_non_padded, num_token_non_padded=forward_batch.num_token_non_padded,
...@@ -189,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -189,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
), ),
) )
else: else:
topk_idx = torch.full( topk_output = self.topk.empty_topk_output(hidden_states.device)
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_output=topk_output,
topk_weights=topk_weights,
forward_batch=forward_batch,
) )
return final_hidden_states return final_hidden_states
...@@ -219,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -219,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer( with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id self.layer_id
): ):
state.topk_weights_local, state.topk_idx_local, _ = self.topk( state.topk_output = self.topk(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
num_token_non_padded=state.forward_batch.num_token_non_padded, num_token_non_padded=state.forward_batch.num_token_non_padded,
...@@ -228,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -228,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
), ),
) )
else: else:
state.topk_idx_local = torch.full( state.topk_output = self.topk.empty_topk_output(hidden_states.device)
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
state.topk_weights_local = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
def op_dispatch_a(self, state): def op_dispatch_a(self, state):
if self.ep_size > 1: if self.ep_size > 1:
self.experts.deepep_dispatcher.dispatch_a( self.experts.dispatcher.dispatch_a(
hidden_states=state.pop("hidden_states_mlp_input"), hidden_states=state.pop("hidden_states_mlp_input"),
topk_idx=state.pop("topk_idx_local"), topk_output=state.pop("topk_output"),
topk_weights=state.pop("topk_weights_local"),
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"), tbo_subbatch_index=state.get("tbo_subbatch_index"),
) )
...@@ -250,32 +236,29 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -250,32 +236,29 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer( with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id self.layer_id
): ):
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b( state.dispatch_output = self.experts.dispatcher.dispatch_b(
tbo_subbatch_index=state.get("tbo_subbatch_index"), tbo_subbatch_index=state.get("tbo_subbatch_index"),
) )
def op_experts(self, state): def op_experts(self, state):
state.hidden_states_experts_output = self.experts.moe_impl( state.hidden_states_experts_output = self.experts.run_moe_core(
dispatch_output=state.dispatch_output, dispatch_output=state.dispatch_output,
) )
def op_combine_a(self, state): def op_combine_a(self, state):
if self.ep_size > 1: if self.ep_size > 1:
self.experts.deepep_dispatcher.combine_a( self.experts.dispatcher.combine_a(
hidden_states=state.pop("hidden_states_experts_output"), hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.dispatch_output.topk_idx, topk_ids=state.dispatch_output.topk_ids,
topk_weights=state.dispatch_output.topk_weights, topk_weights=state.dispatch_output.topk_weights,
forward_batch=state.forward_batch,
tbo_subbatch_index=state.get("tbo_subbatch_index"), tbo_subbatch_index=state.get("tbo_subbatch_index"),
) )
state.pop("dispatch_output") state.pop("dispatch_output")
def op_combine_b(self, state): def op_combine_b(self, state):
if self.ep_size > 1: if self.ep_size > 1:
state.hidden_states_after_combine = ( state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
self.experts.deepep_dispatcher.combine_b( tbo_subbatch_index=state.get("tbo_subbatch_index"),
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
) )
def op_output(self, state): def op_output(self, state):
......