"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "735a0ca2e480b40fc714751b73848c08cf4eed43"
Unverified Commit 29589512 authored by Cheng Wan, committed by GitHub

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

parent 584e1ab2
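Before the hunks, a minimal sketch of the call pattern this cleanup moves to: routing arguments are bundled in `TopKConfig` and runner-level flags in `MoeRunnerConfig`, which are then passed to `select_experts` and `fused_moe`. The tensor shapes, dtypes, and device below are illustrative assumptions, not part of the commit.

```python
# Sketch of the refactored MoE call pattern (shapes/dtypes are illustrative).
import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

num_tokens, hidden_size, intermediate_size, num_experts, topk = 16, 1024, 2048, 8, 2
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size,
                 dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(num_experts, hidden_size, intermediate_size,
                 dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

# Routing arguments now travel in a TopKConfig instead of loose kwargs.
topk_config = TopKConfig(top_k=topk, renormalize=True)
topk_output = select_experts(x, router_logits, topk_config)

# Runner-level flags (inplace, activation, no_combine, ...) travel in MoeRunnerConfig.
moe_runner_config = MoeRunnerConfig(inplace=True)
out = fused_moe(x, w1, w2, topk_output, moe_runner_config=moe_runner_config)
```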
@@ -11,6 +11,7 @@ import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
+from sglang.srt.layers.moe.fused_moe_triton import override_config
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_moe,
     get_config_dtype_str,
@@ -18,7 +19,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     get_default_config,
     get_moe_configs,
 )
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.utils import is_hip
 _is_hip = is_hip()
@@ -117,17 +119,23 @@ def benchmark_config(
     w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
     input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
-    topk_output = select_experts(x, input_gating, topk, renormalize=True)
+    topk_config = TopKConfig(
+        top_k=topk,
+        renormalize=True,
+    )
+    topk_output = select_experts(x, input_gating, topk_config)
     def prepare(i: int):
         input_gating = gating_output[i]
-        new_topk_output = select_experts(x, input_gating, topk, renormalize=True)
+        new_topk_output = select_experts(x, input_gating, topk_config)
         topk_output.topk_weights.copy_(new_topk_output.topk_weights)
         topk_output.topk_ids.copy_(new_topk_output.topk_ids)
         topk_output.router_logits.copy_(new_topk_output.router_logits)
     def run():
-        from sglang.srt.layers.moe.fused_moe_triton import override_config
+        moe_runner_config = MoeRunnerConfig(
+            inplace=True,
+        )
         with override_config(config):
             fused_moe(
@@ -135,7 +143,7 @@ def benchmark_config(
                 w1,
                 w2,
                 topk_output,
-                inplace=True,
+                moe_runner_config=moe_runner_config,
                 use_fp8_w8a8=use_fp8_w8a8,
                 use_int8_w8a8=use_int8_w8a8,
                 use_int8_w8a16=use_int8_w8a16,
...
@@ -213,12 +213,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Arguments | Description | Defaults |
 |-----------|-------------|----------|
 | `--ep-size` | The expert parallelism size. | 1 |
-| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | None |
-| `--enable-flashinfer-cutlass-moe` | Enabling Flashinfer Cutlass MoE implementation for high throughput. | False |
-| `--enable-flashinfer-trtllm-moe` | Enabling Flashinfer Trtllm MoE implementation for low latency. | False |
+| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
+| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' |
 | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
 | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
-| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | None |
+| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
 | `--init-expert-location` | Initial location of EP experts. | trivial |
 | `--enable-eplb` | Enable EPLB algorithm. | False |
 | `--eplb-algorithm` | Chosen EPLB algorithm. | auto |
@@ -280,7 +279,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False |
 | `--disable-fast-image-processor` | Disable fast image processor. | False |
 | `--enable-return-hidden-states` | Enable returning hidden states. | False |
-| `--enable-triton-kernel-moe` | Enable Triton kernel for MoE. | False |
 ## Debug tensor dumps
...
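For reference, the options in the table above map onto `ServerArgs` fields (dashes become underscores); a hypothetical combination is sketched below. The model path and the chosen values are placeholders and the flag-to-field mapping is an assumption, not something stated by this commit.

```python
# Hypothetical expert-parallel configuration using the options documented above.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="/path/to/moe-model",  # placeholder path
    ep_size=8,                        # expert parallelism size
    moe_a2a_backend="deepep",         # all-to-all backend for expert parallelism
    moe_runner_backend="triton",      # runner backend selection consolidated by this refactor
    deepep_mode="auto",               # low_latency for decode, normal for prefill
)
```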
@@ -61,7 +61,6 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -300,11 +299,6 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
         spec_algorithm=SpeculativeAlgorithm.NONE,
         speculative_num_draft_tokens=None,
-        enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
-        enable_deepep_moe=MoeA2ABackend(
-            model_runner.server_args.moe_a2a_backend
-        ).is_deepep(),
-        deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
         require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
         disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
     )
...
@@ -25,7 +25,6 @@ import torch
 import torch.distributed
 from sglang.srt.eplb.expert_location import ExpertLocationMetadata
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, get_bool_env_var
@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
             )
         if server_args.expert_distribution_recorder_mode == "stat_approx":
-            if server_args.moe_a2a_backend is not None and (
+            if server_args.moe_a2a_backend != "none" and (
                 server_args.deepep_mode == "normal"
             ):
                 return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
             else:
                 raise NotImplementedError
-        if server_args.moe_a2a_backend is not None:
+        if server_args.moe_a2a_backend != "none":
             if server_args.deepep_mode == "normal":
                 return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
             elif server_args.deepep_mode == "low_latency":
...
@@ -17,7 +17,7 @@ from enum import Enum, auto
 from functools import partial
 from typing import Dict, Optional
-import torch.distributed
+import torch
 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
@@ -35,6 +35,7 @@ from sglang.srt.layers.dp_attention import (
     get_global_dp_buffer,
     get_local_dp_buffer,
 )
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -111,7 +112,7 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if not global_server_args_dict["moe_a2a_backend"].is_standard()
+                if not get_moe_a2a_backend().is_none()
                 else ScatterMode.FULL
             )
         else:
...
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.utils import (
+    DeepEPMode,
+    MoeA2ABackend,
+    MoeRunnerBackend,
+    get_deepep_config,
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    get_tbo_token_distribution_threshold,
+    initialize_moe_config,
+    is_tbo_enabled,
+    should_use_flashinfer_trtllm_moe,
+)
+__all__ = [
+    "DeepEPMode",
+    "MoeA2ABackend",
+    "MoeRunnerConfig",
+    "MoeRunnerBackend",
+    "initialize_moe_config",
+    "get_moe_a2a_backend",
+    "get_moe_runner_backend",
+    "get_deepep_mode",
+    "should_use_flashinfer_trtllm_moe",
+    "is_tbo_enabled",
+    "get_tbo_token_distribution_threshold",
+    "get_deepep_config",
+]
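A short sketch of how call sites consume the accessors exported above instead of reading `global_server_args_dict`; it mirrors the `get_moe_impl_class` changes further down and assumes the MoE config has already been initialized via `initialize_moe_config`. The helper name and return strings are illustrative, not the actual dispatch logic.

```python
# Illustrative dispatch on the new accessors; not the actual get_moe_impl_class logic.
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    get_moe_runner_backend,
)

def describe_moe_path() -> str:
    if get_moe_a2a_backend().is_deepep():
        return f"deepep a2a, mode={get_deepep_mode()}"
    runner = get_moe_runner_backend()
    if runner.is_flashinfer_trtllm():
        return "flashinfer trtllm runner"
    if runner.is_flashinfer_cutlass():
        return "flashinfer cutlass runner"
    if runner.is_triton_kernel():
        return "triton kernel runner"
    return "default triton runner"
```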
 from __future__ import annotations
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 import torch
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import (
-    Fp8Config,
-    Fp8MoEMethod,
-    get_tile_tokens_dim,
-)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -89,12 +90,11 @@ class EPMoE(FusedMoE):
         num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
         with_bias: bool = False,
     ):
         super().__init__(
@@ -106,13 +106,12 @@ class EPMoE(FusedMoE):
             top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=prefix,
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            activation_alpha=activation_alpha,
-            swiglu_limit=swiglu_limit,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
             with_bias=with_bias,
         )
@@ -163,7 +162,8 @@ class EPMoE(FusedMoE):
         )
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
@@ -327,8 +327,8 @@ class EPMoE(FusedMoE):
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
-        if self.routed_scaling_factor is not None:
-            output *= self.routed_scaling_factor
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
         return output
@@ -349,11 +349,9 @@ class DeepEPMoE(EPMoE):
         num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -364,12 +362,11 @@ class DeepEPMoE(EPMoE):
             num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=prefix,
             activation=activation,
             routed_scaling_factor=routed_scaling_factor,
         )
-        self.deepep_mode = deepep_mode
+        self.deepep_mode = get_deepep_mode()
         # TODO: move to the beginning of the file
         from sglang.srt.distributed.parallel_state import get_tp_group
@@ -383,7 +380,7 @@ class DeepEPMoE(EPMoE):
             num_local_experts=self.num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
-            deepep_mode=deepep_mode,
+            deepep_mode=self.deepep_mode,
             async_finish=True,  # TODO
             return_recv_hook=True,
         )
@@ -458,15 +455,19 @@ class DeepEPMoE(EPMoE):
         )
     def moe_impl(self, dispatch_output: DispatchOutput):
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
         if _use_aiter:
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
+            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
             return self.forward_npu(dispatch_output)
-        if dispatch_output.format.is_deepep_normal():
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
-        elif dispatch_output.format.is_deepep_ll():
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -490,7 +491,7 @@ class DeepEPMoE(EPMoE):
     def forward_aiter(
         self,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
         hidden_states, topk_idx, topk_weights = (
             dispatch_output.hidden_states,
@@ -516,7 +517,7 @@ class DeepEPMoE(EPMoE):
             quant_type=QuantType.per_128x128,
             activation=(
                 ActivationType.Silu
-                if self.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             expert_mask=self.expert_mask,
@@ -531,7 +532,7 @@ class DeepEPMoE(EPMoE):
         )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
         if num_recv_tokens_per_expert is None:
             return hidden_states_fp8.bfloat16()
         all_tokens = sum(num_recv_tokens_per_expert)
@@ -652,7 +653,7 @@ class DeepEPMoE(EPMoE):
     ):
         hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()
@@ -783,12 +784,12 @@ class DeepEPMoE(EPMoE):
 def get_moe_impl_class():
-    if global_server_args_dict["moe_a2a_backend"].is_deepep():
+    if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
     # NEW: Direct FP4 detection (bypasses EP requirements)
     # Check for FP4 quantization with TRTLLM flag, regardless of EP
-    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+    if get_moe_runner_backend().is_flashinfer_trtllm():
         try:
             # Check the quantization argument directly
             quantization = global_server_args_dict.get("quantization")
@@ -803,7 +804,7 @@ def get_moe_impl_class():
     if should_use_flashinfer_trtllm_moe():
         return FlashInferFusedMoE
-    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
+    if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
         return EPMoE
...
@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
 It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
 """
-from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 def fused_moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output: TopKOutput,
-    *,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:
-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
     topk_weights, topk_ids, _ = topk_output
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         x1 = F.silu(x1)
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         x1 = F.gelu(x1)
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
 def moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output: TopKOutput,
-    *,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:
-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
     topk_weights, topk_ids, _ = topk_output
@@ -72,12 +61,12 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         act = SiluAndMul()
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         act = GeluAndMul()
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
     outputs = []
     start_idx = 0
...
@@ -2,17 +2,20 @@
 """Fused MoE kernel."""
+from __future__ import annotations
 import functools
 import json
 import logging
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     scaled_fp8_quant,
@@ -1025,8 +1028,8 @@ def inplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1053,8 +1056,8 @@ def inplace_fused_experts(
         block_shape,
         False,
         routed_scaling_factor,
-        activation_alpha,
-        swiglu_limit,
+        gemm1_alpha,
+        gemm1_limit,
     )
@@ -1081,8 +1084,8 @@ def inplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> None:
     pass
@@ -1119,8 +1122,8 @@ def outplace_fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1147,8 +1150,8 @@ def outplace_fused_experts(
         block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
-        activation_alpha=activation_alpha,
-        swiglu_limit=swiglu_limit,
+        gemm1_alpha=gemm1_alpha,
+        gemm1_limit=gemm1_limit,
     )
@@ -1176,8 +1179,8 @@ def outplace_fused_experts_fake(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -1194,12 +1197,10 @@ def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_output: TopKOutput,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1212,14 +1213,10 @@ def fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
-    if inplace:
-        assert not no_combine, "no combine + inplace makes no sense"
+    if moe_runner_config.inplace:
+        assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
             hidden_states,
             w1,
@@ -1228,8 +1225,8 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            activation,
-            apply_router_weight_on_input,
+            moe_runner_config.activation,
+            moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
@@ -1242,9 +1239,9 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
-            routed_scaling_factor,
-            activation_alpha,
-            swiglu_limit,
+            moe_runner_config.routed_scaling_factor,
+            moe_runner_config.gemm1_alpha,
+            moe_runner_config.gemm1_clamp_limit,
         )
         return hidden_states
     else:
@@ -1256,8 +1253,8 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            activation,
-            apply_router_weight_on_input,
+            moe_runner_config.activation,
+            moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
@@ -1270,10 +1267,10 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            activation_alpha=activation_alpha,
-            swiglu_limit=swiglu_limit,
+            no_combine=moe_runner_config.no_combine,
+            routed_scaling_factor=moe_runner_config.routed_scaling_factor,
+            gemm1_alpha=moe_runner_config.gemm1_alpha,
+            gemm1_limit=moe_runner_config.gemm1_clamp_limit,
         )
@@ -1370,11 +1367,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
 @torch.compile
-def swiglu_with_alpha_and_limit(x, alpha, limit):
+def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
     gate, up = x[..., ::2], x[..., 1::2]
-    gate = gate.clamp(min=None, max=limit)
-    up = up.clamp(min=-limit, max=limit)
-    return gate * torch.sigmoid(gate * alpha) * (up + 1)
+    gate = gate.clamp(min=None, max=gemm1_limit)
+    up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
+    return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
 def fused_experts_impl(
@@ -1402,8 +1399,8 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1533,12 +1530,12 @@ def fused_experts_impl(
            block_shape=block_shape,
        )
        if activation == "silu":
-            if activation_alpha is not None:
-                assert swiglu_limit is not None
+            if gemm1_alpha is not None:
+                assert gemm1_limit is not None
                intermediate_cache2 = swiglu_with_alpha_and_limit(
                    intermediate_cache1.view(-1, N),
-                    activation_alpha,
-                    swiglu_limit,
+                    gemm1_alpha,
+                    gemm1_limit,
                )
            elif _is_cuda:
                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
@@ -1547,10 +1544,8 @@ def fused_experts_impl(
                    intermediate_cache2, intermediate_cache1.view(-1, N)
                )
        elif activation == "gelu":
-            assert (
-                activation_alpha is None
-            ), "activation_alpha is not supported for gelu"
-            assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
+            assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
+            assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
            if _is_cuda:
                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
            else:
@@ -1641,12 +1636,10 @@ def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_output: TopKOutput,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1659,10 +1652,6 @@ def fused_moe(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1672,11 +1661,10 @@ def fused_moe(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - topk_output (TopKOutput): The top-k output of the experts.
+    - topk_output (StandardTopKOutput): The top-k output of the experts.
+    - moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
     - b1 (Optional[torch.Tensor]): Optional bias for w1.
     - b2 (Optional[torch.Tensor]): Optional bias for w2.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1696,9 +1684,9 @@ def fused_moe(
         a2.
     - block_shape: (Optional[List[int]]): Optional block size for block-wise
         quantization.
-    - activation_alpha (Optional[float]): Optional alpha for the activation
+    - gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
         function.
-    - swiglu_limit (Optional[float]): Optional limit for the swiglu activation
+    - gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
         function.
     Returns:
@@ -1710,11 +1698,9 @@ def fused_moe(
         w1,
         w2,
         topk_output,
+        moe_runner_config=moe_runner_config,
         b1=b1,
         b2=b2,
-        inplace=inplace,
-        activation=activation,
-        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
@@ -1727,8 +1713,4 @@ def fused_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,
-        no_combine=no_combine,
-        routed_scaling_factor=routed_scaling_factor,
-        activation_alpha=activation_alpha,
-        swiglu_limit=swiglu_limit,
     )
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
-import datetime
-import glob
 import logging
-import os
-import sys
 from enum import Enum
 from typing import List, Optional, Tuple
@@ -22,8 +18,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.layers.moe.topk import StandardTopKOutput
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
+from sglang.srt.layers.moe import (
+    MoeRunnerConfig,
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
+from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -126,7 +126,6 @@ class FusedMoE(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
@@ -134,9 +133,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-        enable_flashinfer_cutlass_moe: Optional[bool] = False,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
         use_weight_loader_fused: bool = False,
         with_bias=False,
     ):
@@ -153,9 +151,17 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None
-        # For activation
-        self.activation_alpha = activation_alpha
-        self.swiglu_limit = swiglu_limit
+        self.moe_runner_config = MoeRunnerConfig(
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+        )
+        enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -184,20 +190,12 @@ class FusedMoE(torch.nn.Module):
             * self.num_local_experts
         ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-        self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
-        self.activation = activation
-        self.apply_router_weight_on_input = apply_router_weight_on_input
         self.use_presharded_weights = use_presharded_weights
-        self.inplace = inplace
-        self.no_combine = no_combine
-        self.use_triton_kernels = (
-            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
-        )
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -207,14 +205,12 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
         self.quant_config = quant_config
-        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
-            "enable_flashinfer_mxfp4_moe", False
-        )
+        self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
         # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
-            and self.use_enable_flashinfer_mxfp4_moe
+            and self.use_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.quant_method.create_weights(
@@ -794,7 +790,7 @@ class FusedMoE(torch.nn.Module):
                 f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
             )
-    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None
@@ -803,40 +799,22 @@ class FusedMoE(torch.nn.Module):
             # If we are in EP mode, we need to move the expert map to GPU.
             self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
-        if self.expert_map_gpu is not None and isinstance(
-            topk_output, StandardTopKOutput
-        ):
-            topk_output = topk_output._replace(
-                topk_ids=self.expert_map_gpu[topk_output.topk_ids]
-            )
+        if self.expert_map_gpu is not None:
+            if TopKOutputChecker.format_is_standard(topk_output):
+                topk_output = topk_output._replace(
+                    topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+                )
+            elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+                raise NotImplementedError()
         # Matrix multiply.
         with use_symmetric_memory(get_tp_group()) as sm:
-            kwargs = {}
-            if self.activation_alpha is not None:
-                kwargs["activation_alpha"] = self.activation_alpha
-            if self.swiglu_limit is not None:
-                kwargs["swiglu_limit"] = self.swiglu_limit
             final_hidden_states = self.quant_method.apply(
                 layer=self,
                 x=hidden_states,
                 topk_output=topk_output,
-                activation=self.activation,
-                apply_router_weight_on_input=self.apply_router_weight_on_input,
-                routed_scaling_factor=self.routed_scaling_factor,
-                **(
-                    dict(
-                        tp_rank=self.moe_tp_rank,
-                        tp_size=self.moe_tp_size,
-                        ep_rank=self.moe_ep_rank,
-                        ep_size=self.moe_ep_size,
-                    )
-                    if self.quant_method.__class__.__name__
-                    == "ModelOptNvFp4FusedMoEMethod"
-                    else {}
-                ),
-                **kwargs,
+                moe_runner_config=self.moe_runner_config,
             )
             sm.tag(final_hidden_states)
@@ -944,24 +922,10 @@ class FusedMoE(torch.nn.Module):
 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
         super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
         self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
@@ -974,20 +938,14 @@ class FlashInferFusedMoE(FusedMoE):
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        # TRTLLM mode expects (TopK_config, router_logits) tuple
-        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
-            raise ValueError(
-                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
-            )
-        _, router_logits = topk_output
+        assert TopKOutputChecker.format_is_bypassed(topk_output)
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
             x=hidden_states,
-            router_logits=router_logits,
-            activation=self.activation,
-            routed_scaling_factor=self.routed_scaling_factor,
+            topk_output=topk_output,
+            moe_runner_config=self.moe_runner_config,
         )
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1000,28 +958,8 @@ class FlashInferFP4MoE(FusedMoE):
     """FP4 TRTLLM MoE implementation using FlashInfer."""
     def __init__(self, *args, **kwargs):
-        # Extract DeepSeek-specific parameters
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        # Extract additional TopK parameters that were previously extracted in forward
-        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
         super().__init__(*args, **kwargs)
-        # Store DeepSeek parameters
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.routed_scaling_factor = routed_scaling_factor
     # ---------------------------------------------------------------------
     # Helper: quantize hidden states to FP4 each forward pass
     # ---------------------------------------------------------------------
@@ -1052,21 +990,17 @@ class FlashInferFP4MoE(FusedMoE):
         return hs_fp4, hs_sf
-    def forward(self, hidden_states: torch.Tensor, topk_output):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         """Forward pass using FP4 TRTLLM kernel.
         Args:
             hidden_states: Input tensor
-            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+            topk_output: TopKOutput object with Bypassed format
         """
-        # TRTLLM mode expects (TopK_config, router_logits) tuple
-        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
-            raise ValueError(
-                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
-            )
-        _, router_logits = topk_output
+        assert TopKOutputChecker.format_is_bypassed(topk_output)
+        router_logits = topk_output.router_logits
+        topk_config = topk_output.topk_config
         hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
@@ -1074,7 +1008,7 @@ class FlashInferFP4MoE(FusedMoE):
         result = trtllm_fp4_block_scale_moe(
             routing_logits=router_logits,
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
             hidden_states=hs_fp4,
             hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
             gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
@@ -1094,15 +1028,15 @@ class FlashInferFP4MoE(FusedMoE):
             output1_scale_gate_scalar=self.g1_alphas.data,
             output2_scale_scalar=self.g2_alphas.data,
             num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
+            top_k=topk_config.top_k,
+            n_group=topk_config.num_expert_group,
+            topk_group=topk_config.topk_group,
             intermediate_size=self.intermediate_size_per_partition,
             local_expert_offset=self.moe_ep_rank * self.num_local_experts,
             local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
+            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
             tile_tokens_dim=_get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_local_experts
+                hidden_states.shape[0], topk_config.top_k, self.num_local_experts
             ),
             routing_method_type=RoutingMethodType.DeepSeekV3,
             do_finalize=True,
...
@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 from triton_kernels.swiglu import swiglu_fn
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -55,8 +56,7 @@ def triton_kernel_moe_forward(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
-    inplace: bool = False,
-    activation: str = "silu",
+    moe_runner_config: MoeRunnerConfig,
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
@@ -69,7 +69,10 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
-    assert topk_output.format.is_triton_kernel()
+    from sglang.srt.layers.moe.topk import TopKOutputChecker
+    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
     routing_data, gather_idx, scatter_idx = topk_output
     return triton_kernel_fused_experts(
@@ -79,8 +82,8 @@ def triton_kernel_moe_forward(
         routing_data,
         gather_idx,
         scatter_idx,
-        inplace=inplace,
-        activation=activation,
+        inplace=False,  # triton kernel doesn't support inplace
+        activation=moe_runner_config.activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         per_channel_quant=per_channel_quant,
@@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward(
     w2_pcg,
     b2: torch.Tensor,
     topk_output: TopKOutput,
-    inplace: bool = False,
-    activation: str = "silu",
+    moe_runner_config: MoeRunnerConfig,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[int] = None,
 ) -> torch.Tensor:
-    assert topk_output.format.is_triton_kernel()
+    from sglang.srt.layers.moe.topk import TopKOutputChecker
+    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
     routing_data, gather_idx, scatter_idx = topk_output
     return triton_kernel_fused_experts_with_bias(
@@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward(
         routing_data=routing_data,
         gather_indx=gather_idx,
         scatter_indx=scatter_idx,
inplace=inplace, inplace=False, # triton kernel doesn't support inplace
activation=activation, activation=moe_runner_config.activation,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward( ...@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
activation_alpha=activation_alpha, gemm1_alpha=moe_runner_config.gemm1_alpha,
swiglu_limit=swiglu_limit, gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
) )
...@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias( ...@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None, gemm1_clamp_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported" assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported" assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported" assert expert_map == None, "expert_map is not supported"
...@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias( ...@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias(
act = FusedActivation( act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
(activation_alpha, swiglu_limit), (gemm1_alpha, gemm1_clamp_limit),
2, 2,
) )
......
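Across both triton entry points the loose keyword arguments collapse into a single MoeRunnerConfig. A minimal sketch of the new call shape, using an invented _MoeRunnerConfig stand-in rather than the real class:

from dataclasses import dataclass
from typing import Optional

@dataclass
class _MoeRunnerConfig:                 # standalone subset of MoeRunnerConfig
    activation: str = "silu"
    gemm1_alpha: Optional[float] = None
    gemm1_clamp_limit: Optional[float] = None

def _triton_moe_forward_sketch(moe_runner_config: _MoeRunnerConfig) -> dict:
    # Mirrors the refactored signature: per-call kwargs such as `inplace`,
    # `activation`, `activation_alpha`, and `swiglu_limit` are gone; the runner
    # config carries them, and the triton-kernel path always runs out-of-place.
    return {
        "inplace": False,
        "activation": moe_runner_config.activation,
        "gemm1_alpha": moe_runner_config.gemm1_alpha,
        "gemm1_clamp_limit": moe_runner_config.gemm1_clamp_limit,
    }

print(_triton_moe_forward_sketch(_MoeRunnerConfig(gemm1_alpha=1.702, gemm1_clamp_limit=7.0)))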
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
__all__ = ["MoeRunnerConfig"]
from dataclasses import dataclass
from typing import Optional
@dataclass
class MoeRunnerConfig:
activation: str = "silu"
apply_router_weight_on_input: bool = False
inplace: bool = True
no_combine: bool = False
routed_scaling_factor: Optional[float] = None
gemm1_alpha: Optional[float] = None
gemm1_clamp_limit: Optional[float] = None
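Given the defaults above, construction at a call site is short. A usage sketch, assuming an sglang checkout at this revision is importable; the overridden values are illustrative only.

from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig

cfg = MoeRunnerConfig(
    activation="silu",
    inplace=False,                   # override the default of True where needed
    routed_scaling_factor=2.5,       # illustrative value
)
print(cfg.activation, cfg.gemm1_alpha, cfg.gemm1_clamp_limit)   # -> silu None None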
...@@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( ...@@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher, BaseDispatcher,
BaseDispatcherConfig, BaseDispatcherConfig,
DispatchOutput, DispatchOutput,
DispatchOutputChecker,
DispatchOutputFormat, DispatchOutputFormat,
) )
from sglang.srt.layers.moe.token_dispatcher.deepep import ( from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig, DeepEPConfig,
DeepEPDispatcher, DeepEPDispatcher,
DeepEPLLOutput, DeepEPLLOutput,
DeepEPNormalOutput, DeepEPNormalOutput,
) )
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
__all__ = [ __all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher", "BaseDispatcher",
"BaseDispatcherConfig", "BaseDispatcherConfig",
"DispatchOutput", "DispatchOutput",
"DispatchOutputFormat", "DispatchOutputFormat",
"DispatchOutputChecker",
"StandardDispatchOutput",
"DeepEPConfig", "DeepEPConfig",
"DeepEPDispatcher", "DeepEPDispatcher",
"DeepEPNormalOutput", "DeepEPNormalOutput",
......
...@@ -2,35 +2,76 @@ from __future__ import annotations ...@@ -2,35 +2,76 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum, auto from enum import Enum, auto
from typing import Protocol, runtime_checkable from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
import torch import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
StandardDispatchOutput,
)
class MoEA2ABackend(Enum):
none = "none"
deepep = "deepep"
def is_none(self):
return self == MoEA2ABackend.none
def is_deepep(self):
return self == MoEA2ABackend.deepep
class DispatchOutputChecker:
@staticmethod
def format_is_standard(
dispatch_output: DispatchOutput,
) -> TypeGuard[StandardDispatchOutput]:
return dispatch_output.format.is_standard()
@staticmethod
def format_is_deepep_normal(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPNormalOutput]:
return dispatch_output.format.is_deepep_normal()
@staticmethod
def format_is_deepep_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPLLOutput]:
return dispatch_output.format.is_deepep_ll()
@staticmethod
def format_is_deepep(
dispatch_output: DispatchOutput,
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
return dispatch_output.format.is_deepep()
@staticmethod
def format_is_ascent_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[AscendDeepEPLLOutput]:
return dispatch_output.format.is_ascent_ll()
class DispatchOutputFormat(Enum): class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto() STANDARD = auto()
deepep_ll = auto() DEEPEP_NORMAL = auto()
DEEPEP_LL = auto()
ASCENT_LL = auto()
def is_standard(self) -> bool: def is_standard(self) -> bool:
return self == DispatchOutputFormat.standard return self == DispatchOutputFormat.STANDARD
def is_deepep_normal(self) -> bool: def is_deepep_normal(self) -> bool:
return self == DispatchOutputFormat.deepep_normal return self == DispatchOutputFormat.DEEPEP_NORMAL
def is_deepep_ll(self) -> bool: def is_deepep_ll(self) -> bool:
return self == DispatchOutputFormat.deepep_ll return self == DispatchOutputFormat.DEEPEP_LL
def is_deepep(self) -> bool:
return self in [
DispatchOutputFormat.DEEPEP_NORMAL,
DispatchOutputFormat.DEEPEP_LL,
]
def is_ascent_ll(self) -> bool:
return self == DispatchOutputFormat.ASCENT_LL
@runtime_checkable @runtime_checkable
......
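The checker-plus-TypeGuard pattern introduced here is easiest to see in isolation. The following self-contained mirror (all underscore names are invented) shows how the runtime check doubles as static narrowing for mypy/pyright:

from enum import Enum, auto
from typing import NamedTuple, TypeGuard, Union

class _Format(Enum):
    STANDARD = auto()
    DEEPEP_LL = auto()

class _StandardOut(NamedTuple):
    hidden_states: list

    @property
    def format(self) -> _Format:
        return _Format.STANDARD

class _DeepEPLLOut(NamedTuple):
    hidden_states_fp8: list

    @property
    def format(self) -> _Format:
        return _Format.DEEPEP_LL

_DispatchOut = Union[_StandardOut, _DeepEPLLOut]

def _format_is_deepep_ll(out: _DispatchOut) -> TypeGuard[_DeepEPLLOut]:
    # Same shape as DispatchOutputChecker.format_is_deepep_ll: a runtime check
    # that also narrows the static type at the call site.
    return out.format is _Format.DEEPEP_LL

out: _DispatchOut = _DeepEPLLOut(hidden_states_fp8=[1, 2, 3])
if _format_is_deepep_ll(out):
    print(out.hidden_states_fp8)   # the type checker knows this field exists here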
...@@ -2,27 +2,17 @@ from __future__ import annotations ...@@ -2,27 +2,17 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
TYPE_CHECKING,
List,
NamedTuple,
Optional,
Protocol,
Tuple,
Union,
runtime_checkable,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher, BaseDispatcher,
BaseDispatcherConfig, BaseDispatcherConfig,
DispatchOutput, DispatchOutput,
DispatchOutputFormat, DispatchOutputFormat,
) )
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import ( from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
get_int_env_var, get_int_env_var,
...@@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple): ...@@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_normal return DispatchOutputFormat.DEEPEP_NORMAL
class DeepEPLLOutput(NamedTuple): class DeepEPLLOutput(NamedTuple):
...@@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple): ...@@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll return DispatchOutputFormat.DEEPEP_LL
class AscendDeepEPLLOutput(NamedTuple): class AscendDeepEPLLOutput(NamedTuple):
...@@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple): ...@@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll return DispatchOutputFormat.ASCENT_LL
assert isinstance(DeepEPNormalOutput, DispatchOutput) assert isinstance(DeepEPNormalOutput, DispatchOutput)
...@@ -128,8 +118,8 @@ class DeepEPBuffer: ...@@ -128,8 +118,8 @@ class DeepEPBuffer:
hidden_size: int, hidden_size: int,
param_bytes: int, param_bytes: int,
deepep_mode: DeepEPMode, deepep_mode: DeepEPMode,
num_max_dispatch_tokens_per_rank: int = None, num_max_dispatch_tokens_per_rank: int = -1,
num_experts: int = None, num_experts: int = -1,
): ):
if cls._buffer is not None: if cls._buffer is not None:
return cls._buffer return cls._buffer
...@@ -156,8 +146,8 @@ class DeepEPBuffer: ...@@ -156,8 +146,8 @@ class DeepEPBuffer:
num_rdma_bytes, num_rdma_bytes,
) )
if deepep_mode.enable_low_latency(): if deepep_mode.enable_low_latency():
assert num_max_dispatch_tokens_per_rank is not None assert num_max_dispatch_tokens_per_rank != -1
assert num_experts is not None and num_experts % group.size() == 0 assert num_experts != -1 and num_experts % group.size() == 0
num_rdma_bytes = max( num_rdma_bytes = max(
Buffer.get_low_latency_rdma_size_hint( Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, num_max_dispatch_tokens_per_rank,
...@@ -181,7 +171,7 @@ class DeepEPBuffer: ...@@ -181,7 +171,7 @@ class DeepEPBuffer:
).multi_processor_count ).multi_processor_count
if ( if (
(deepep_mode != DeepEPMode.LOW_LATENCY) (deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"] and not is_tbo_enabled()
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2) and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
): ):
logger.warning( logger.warning(
...@@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig): ...@@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig):
_instance = None _instance = None
def __init__(self): def __init__(self):
config_str = global_server_args_dict["deepep_config"] config_str = get_deepep_config()
if config_str: if config_str:
config_parsed = load_json_config(config_str) config_parsed = load_json_config(config_str)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
......
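Two conventions from this file are worth noting: integer sentinels replace Optional[int] = None defaults in DeepEPBuffer, and server flags are now read through module getters such as is_tbo_enabled() and get_deepep_config() rather than global_server_args_dict. A tiny standalone sketch of the sentinel style (names invented here):

def _buffer_args_sketch(group_size: int,
                        num_max_dispatch_tokens_per_rank: int = -1,
                        num_experts: int = -1) -> int:
    # Mirrors the new defaults: -1 means "not provided", and the values are only
    # validated when the low-latency path actually needs them.
    assert num_max_dispatch_tokens_per_rank != -1
    assert num_experts != -1 and num_experts % group_size == 0
    return num_experts // group_size

print(_buffer_args_sketch(group_size=8, num_max_dispatch_tokens_per_rank=128, num_experts=256))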
...@@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple): ...@@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.standard return DispatchOutputFormat.STANDARD
assert isinstance(StandardDispatchOutput, DispatchOutput) assert isinstance(StandardDispatchOutput, DispatchOutput)
...@@ -14,9 +14,18 @@ ...@@ -14,9 +14,18 @@
from __future__ import annotations from __future__ import annotations
import logging
import math import math
from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable from typing import (
Callable,
NamedTuple,
Optional,
Protocol,
TypeGuard,
runtime_checkable,
)
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import ( ...@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo, ExpertLocationDispatchInfo,
topk_ids_logical_to_physical, topk_ids_logical_to_physical,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.utils import ( from sglang.srt.utils import (
cpu_has_amx_support, cpu_has_amx_support,
get_bool_env_var, get_bool_env_var,
...@@ -43,6 +55,7 @@ try: ...@@ -43,6 +55,7 @@ try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError: except ImportError:
pass pass
logger = logging.getLogger(__name__)
_is_cuda = is_cuda() _is_cuda = is_cuda()
...@@ -65,13 +78,48 @@ if _use_aiter: ...@@ -65,13 +78,48 @@ if _use_aiter:
if _is_npu: if _is_npu:
import torch_npu import torch_npu
# -------------------------------- TopKConfig ---------------------------------------
@dataclass
class TopKConfig:
top_k: int
use_grouped_topk: bool = False
topk_group: int = 0
num_expert_group: int = 0
renormalize: bool = True
num_fused_shared_experts: int = 0
custom_routing_function: Optional[Callable] = None
correction_bias: Optional[torch.Tensor] = None
torch_native: bool = False
routed_scaling_factor: Optional[float] = None
apply_routed_scaling_factor_on_output: bool = False
# -------------------------------- TopKOutput --------------------------------------- # -------------------------------- TopKOutput ---------------------------------------
class TopKOutputChecker:
@staticmethod
def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
return topk_output.format.is_standard()
@staticmethod
def format_is_triton_kernel(
topk_output: TopKOutput,
) -> TypeGuard[TritonKernelTopKOutput]:
return topk_output.format.is_triton_kernel()
@staticmethod
def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
return topk_output.format.is_bypassed()
class TopKOutputFormat(Enum): class TopKOutputFormat(Enum):
STANDARD = auto() STANDARD = auto()
TRITON_KERNEL = auto() TRITON_KERNEL = auto()
BYPASSED = auto()
def is_standard(self) -> bool: def is_standard(self) -> bool:
return self == TopKOutputFormat.STANDARD return self == TopKOutputFormat.STANDARD
...@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum): ...@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
def is_triton_kernel(self) -> bool: def is_triton_kernel(self) -> bool:
return self == TopKOutputFormat.TRITON_KERNEL return self == TopKOutputFormat.TRITON_KERNEL
def is_bypassed(self) -> bool:
return self == TopKOutputFormat.BYPASSED
@runtime_checkable @runtime_checkable
class TopKOutput(Protocol): class TopKOutput(Protocol):
...@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple): ...@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
return TopKOutputFormat.TRITON_KERNEL return TopKOutputFormat.TRITON_KERNEL
class BypassedTopKOutput(NamedTuple):
"""Bypassed top-k output format."""
hidden_states: torch.Tensor
router_logits: torch.Tensor
topk_config: TopKConfig
num_token_non_padded: Optional[torch.Tensor] = None
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.BYPASSED
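# Illustration only: with the classes above, a backend that defers routing to its
# own kernel hands back the raw logits plus the routing config instead of
# pre-computed top-k ids. Here `hs` and `logits` stand for torch tensors supplied
# by the caller:
#
#   cfg = TopKConfig(top_k=4, renormalize=True)
#   out = BypassedTopKOutput(hidden_states=hs, router_logits=logits, topk_config=cfg)
#   assert TopKOutputChecker.format_is_bypassed(out)   # also narrows the static type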
# -------------------------------- TopK --------------------------------------- # -------------------------------- TopK ---------------------------------------
...@@ -124,8 +189,8 @@ class TopK(CustomOp): ...@@ -124,8 +189,8 @@ class TopK(CustomOp):
top_k: int, top_k: int,
*, *,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: int = 0,
num_expert_group: Optional[int] = None, num_expert_group: int = 0,
renormalize: bool = True, renormalize: bool = True,
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
...@@ -136,19 +201,23 @@ class TopK(CustomOp): ...@@ -136,19 +201,23 @@ class TopK(CustomOp):
# NOTE: scoring_func is not used for now, but we keep it for future use # NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details # see https://github.com/sgl-project/sglang/pull/4505 for more details
super().__init__() super().__init__()
if use_grouped_topk: if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
self.top_k = top_k
self.use_grouped_topk = use_grouped_topk
self.renormalize = renormalize
self.topk_group = topk_group
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
self.topk_config = TopKConfig(
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
def forward_native( def forward_native(
self, self,
...@@ -158,20 +227,11 @@ class TopK(CustomOp): ...@@ -158,20 +227,11 @@ class TopK(CustomOp):
num_token_non_padded: Optional[torch.Tensor] = None, num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput: ) -> TopKOutput:
torch_native = True self.topk_config.torch_native = True
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
...@@ -187,24 +247,28 @@ class TopK(CustomOp): ...@@ -187,24 +247,28 @@ class TopK(CustomOp):
if self.use_triton_kernels: if self.use_triton_kernels:
# renormalize=True is equivalent to sm_first=False # renormalize=True is equivalent to sm_first=False
routing_data, gather_idx, scatter_idx = routing( routing_data, gather_idx, scatter_idx = routing(
router_logits, self.top_k, sm_first=not self.renormalize router_logits,
self.topk_config.top_k,
sm_first=not self.topk_config.renormalize,
) )
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx) return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
elif (
should_use_flashinfer_trtllm_moe()
or get_moe_runner_backend().is_flashinfer_mxfp4()
):
return BypassedTopKOutput(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
else: else:
torch_native = False self.topk_config.torch_native = False
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
...@@ -220,15 +284,7 @@ class TopK(CustomOp): ...@@ -220,15 +284,7 @@ class TopK(CustomOp):
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
...@@ -244,35 +300,29 @@ class TopK(CustomOp): ...@@ -244,35 +300,29 @@ class TopK(CustomOp):
global_num_experts = router_logits.shape[-1] global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256: if global_num_experts == 256 and self.topk_config.renormalize is False:
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
router_logits = router_logits.to(torch.float32) router_logits = router_logits.to(torch.float32)
return torch_npu.npu_moe_gating_top_k( return torch_npu.npu_moe_gating_top_k(
router_logits, router_logits,
k=self.top_k, k=self.topk_config.top_k,
bias=self.correction_bias.to(torch.float32), bias=self.topk_config.correction_bias.to(torch.float32),
k_group=self.topk_group, k_group=self.topk_config.topk_group,
group_count=self.num_expert_group, group_count=self.topk_config.num_expert_group,
group_select_mode=1, group_select_mode=1,
renorm=0, renorm=0,
norm_type=1, norm_type=1,
routed_scaling_factor=1, routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20), eps=float(1e-20),
) )
else: else:
torch_native = True self.topk_config.torch_native = True
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
...@@ -670,20 +720,23 @@ else: ...@@ -670,20 +720,23 @@ else:
def select_experts( def select_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, topk_config: TopKConfig,
*, *,
use_grouped_topk: bool = False,
renormalize: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None, num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput: ) -> StandardTopKOutput:
top_k = topk_config.top_k
use_grouped_topk = topk_config.use_grouped_topk
topk_group = topk_config.topk_group
num_expert_group = topk_config.num_expert_group
renormalize = topk_config.renormalize
num_fused_shared_experts = topk_config.num_fused_shared_experts
custom_routing_function = topk_config.custom_routing_function
correction_bias = topk_config.correction_bias
torch_native = topk_config.torch_native
routed_scaling_factor = topk_config.routed_scaling_factor
router_logits, correction_bias = ( router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs( expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits, router_logits=router_logits,
......
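Call sites now build a TopKConfig once and hand it to select_experts instead of threading ten keyword arguments through. A minimal sketch, assuming a working sglang + torch environment; shapes and values are illustrative.

import torch
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

hidden_states = torch.randn(16, 1024)
router_logits = torch.randn(16, 64)

topk_config = TopKConfig(
    top_k=4,
    renormalize=True,
    torch_native=True,          # use the reference PyTorch routing path
)
topk_output = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    topk_config=topk_config,
)
print(topk_output.topk_ids.shape)   # expected: (16, 4)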
from __future__ import annotations
import importlib.util import importlib.util
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from packaging import version as pkg_version from packaging import version as pkg_version
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import logger
if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
class MoeA2ABackend(Enum): class MoeA2ABackend(Enum):
STANDARD = ("standard", "none") NONE = "none"
DEEPEP = "deepep" DEEPEP = "deepep"
@classmethod @classmethod
def _missing_(cls, value): def _missing_(cls, value):
if value is None: if value is None:
return cls.STANDARD return cls.NONE
for member in cls: for member in cls:
if value in member.value: if value == member.value:
return member return member
raise ValueError(f"No {cls.__name__} member for value {value}") raise ValueError(f"No {cls.__name__} member for value {value}")
def is_none(self):
return self == MoeA2ABackend.NONE
def is_deepep(self): def is_deepep(self):
return self == MoeA2ABackend.DEEPEP return self == MoeA2ABackend.DEEPEP
def is_standard(self):
return self == MoeA2ABackend.STANDARD class MoeRunnerBackend(Enum):
AUTO = "auto"
TRITON = "triton"
TRITON_KERNEL = "triton_kernel"
FLASHINFER = "flashinfer_trtllm"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
def is_auto(self):
return self == MoeRunnerBackend.AUTO
def is_triton(self):
return self == MoeRunnerBackend.TRITON
def is_triton_kernel(self):
return self == MoeRunnerBackend.TRITON_KERNEL
def is_flashinfer_trtllm(self):
return self == MoeRunnerBackend.FLASHINFER
def is_flashinfer_cutlass(self):
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4
class DeepEPMode(Enum): class DeepEPMode(Enum):
NORMAL = "normal" NORMAL = "normal"
LOW_LATENCY = "low_latency" LOW_LATENCY = "low_latency"
AUTO = "auto" AUTO = "auto"
def enable_normal(self): def enable_normal(self) -> bool:
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO] return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
def enable_low_latency(self): def enable_low_latency(self) -> bool:
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO] return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
def resolve(self, is_extend_in_batch: bool): def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
if self != DeepEPMode.AUTO: if self != DeepEPMode.AUTO:
return self return self
...@@ -57,3 +82,96 @@ class DeepEPMode(Enum): ...@@ -57,3 +82,96 @@ class DeepEPMode(Enum):
return DeepEPMode.NORMAL return DeepEPMode.NORMAL
else: else:
return DeepEPMode.LOW_LATENCY return DeepEPMode.LOW_LATENCY
def is_normal(self) -> bool:
return self == DeepEPMode.NORMAL
def is_low_latency(self) -> bool:
return self == DeepEPMode.LOW_LATENCY
def is_auto(self) -> bool:
return self == DeepEPMode.AUTO
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
def initialize_moe_config(server_args: ServerArgs):
global MOE_A2A_BACKEND
global MOE_RUNNER_BACKEND
global DEEPEP_MODE
global DEEPEP_CONFIG
global IS_TBO_ENABLED
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
def get_moe_a2a_backend() -> MoeA2ABackend:
global MOE_A2A_BACKEND
if MOE_A2A_BACKEND is None:
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
MOE_A2A_BACKEND = MoeA2ABackend(None)
return MOE_A2A_BACKEND
def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
return MOE_RUNNER_BACKEND
def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
DEEPEP_MODE = DeepEPMode("auto")
return DEEPEP_MODE
def get_deepep_config() -> str:
global DEEPEP_CONFIG
if DEEPEP_CONFIG is None:
logger.warning("DEEPEP_CONFIG is not initialized, using default config")
DEEPEP_CONFIG = ""
return DEEPEP_CONFIG
def is_tbo_enabled() -> bool:
global IS_TBO_ENABLED
if IS_TBO_ENABLED is None:
logger.warning("IS_TBO_ENABLED is not initialized, using False")
IS_TBO_ENABLED = False
return IS_TBO_ENABLED
def get_tbo_token_distribution_threshold() -> float:
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
logger.warning(
"TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
)
TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
return TBO_TOKEN_DISTRIBUTION_THRESHOLD
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = get_moe_runner_backend().is_flashinfer_trtllm() and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
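Putting the new module-level plumbing together: the server is expected to call initialize_moe_config(server_args) once at startup, after which the MoE stack reads flags through the getters. A usage sketch, assuming sglang is importable and initialize_moe_config has not yet run in this process:

from sglang.srt.layers.moe import (
    DeepEPMode,
    get_deepep_config,
    get_moe_runner_backend,
    is_tbo_enabled,
)

# Explicit DeepEP modes resolve to themselves; only AUTO inspects the batch.
assert DeepEPMode("normal").resolve(is_extend_in_batch=False) is DeepEPMode.NORMAL

# Before initialize_moe_config(server_args) runs, each getter logs a warning and
# falls back to the defaults shown above.
print(get_moe_runner_backend().is_triton())          # True ("triton" is the fallback)
print(is_tbo_enabled(), repr(get_deepep_config()))   # False ''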
...@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod ...@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.utils import is_cuda, is_hip from sglang.srt.utils import is_cuda, is_hip
...@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert (
assert activation == "silu", "Only SiLU activation is supported." moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype
......
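The same refactor reaches the quantization layer: apply() implementations now receive the shared MoeRunnerConfig instead of an activation kwarg. A minimal sketch of the tightened guard, assuming sglang is importable:

from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig

def _apply_guard_sketch(moe_runner_config: MoeRunnerConfig) -> None:
    # Mirrors the assertion added to AWQMoEMethod.apply: the activation choice
    # travels on the config object, so unsupported values fail fast.
    assert (
        moe_runner_config.activation == "silu"
    ), "Only SiLU activation is supported."

_apply_guard_sketch(MoeRunnerConfig(activation="silu"))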