Unverified Commit 4844fac9 authored by Cheng Wan, committed by GitHub

Refactor TopK to ensure readability and extensibility (#9338)

parent b7d385e8
@@ -888,7 +888,7 @@ class DeepEPMoE(EPMoE):
         raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")


-def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
@@ -901,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
         # Check the quantization argument directly
-        quantization = global_server_args_dict.get("quantization")
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
            from sglang.srt.layers.moe.fused_moe_triton.layer import (
                FlashInferFP4MoE,
            )
@@ -911,7 +910,8 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
    except:
        pass
-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+        # FIXME: FlashInferFusedMoE only supports fp8 quant now
        return FlashInferFusedMoE
    if get_moe_runner_backend().is_flashinfer_cutlass():
        return FusedMoE
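Note: with this change, get_moe_impl_class no longer falls back to the global "quantization" server arg; callers must pass the layer's quant_config explicitly, and the modelopt_fp4 branch only fires when that config reports FP4. A minimal, self-contained sketch of the selection rule these hunks encode (names prefixed with _fake are stand-ins for illustration, not sglang API):

from typing import Optional

class _FakeQuantConfig:
    # Stand-in for QuantizationConfig; only get_name() matters for this check.
    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name

def _fake_pick_moe_impl(quant_config: Optional[_FakeQuantConfig]) -> str:
    # Mirrors the new branch above: FP4 kernels are eligible only when the layer
    # is actually quantized with modelopt_fp4; otherwise fall back to FusedMoE.
    if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
        return "FlashInferFP4MoE"
    return "FusedMoE"

assert _fake_pick_moe_impl(None) == "FusedMoE"
assert _fake_pick_moe_impl(_FakeQuantConfig("modelopt_fp4")) == "FlashInferFP4MoE"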
@@ -74,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)


-def _is_fp4_quantization_enabled():
-    """Check if ModelOpt FP4 quantization is enabled."""
-    try:
-        # Use the same simple check that works for class selection
-        quantization = global_server_args_dict.get("quantization")
-        return quantization == "modelopt_fp4"
-    except:
-        return False
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None


 # -------------------------------- TopKOutput ---------------------------------------
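Note: TopKOutputFormat itself is not shown in this diff beyond the three members used below (STANDARD, TRITON_KERNEL, BYPASSED). A simplified sketch of how the new optional field is meant to be read, assuming an enum with at least those members; TopKConfigSketch is a placeholder, not the real TopKConfig:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()
    BYPASSED = auto()

@dataclass
class TopKConfigSketch:
    top_k: int
    renormalize: bool = True
    # None means "resolve the format at forward time from the runner backend";
    # callers (e.g. unquantized MTP layers) may pin it to STANDARD explicitly.
    output_format: Optional[TopKOutputFormat] = None

cfg = TopKConfigSketch(top_k=8, output_format=TopKOutputFormat.STANDARD)
assert cfg.output_format is TopKOutputFormat.STANDARD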
@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-        force_topk: bool = False,
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -207,6 +213,14 @@ class TopK(CustomOp):
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None

+        if (
+            quant_config is not None
+            and quant_config.get_name() == "modelopt_fp4"
+            and should_use_flashinfer_trtllm_moe()
+        ):
+            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
+            correction_bias = correction_bias.to(torch.bfloat16)
+
         self.topk_config = TopKConfig(
             top_k=top_k,
             use_grouped_topk=use_grouped_topk,
@@ -218,11 +232,9 @@ class TopK(CustomOp):
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )

-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk
-
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -248,7 +260,19 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.use_triton_kernels:
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,
@@ -256,10 +280,7 @@
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif not self.force_topk and (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
             return BypassedTopKOutput(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
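Note: the forward path above replaces the old per-flag checks (use_triton_kernels, force_topk) with a single output_format resolution. A hedged summary of the precedence it encodes, written as a standalone helper (illustrative only, not sglang API; the backend checks are the ones named in the diff):

def resolve_output_format(explicit_format, is_triton_kernel_backend, is_trtllm_or_mxfp4_backend):
    # 1. An explicit output_format on TopKConfig always wins.
    if explicit_format is not None:
        return explicit_format
    # 2. The triton-kernel runner backend needs the TritonKernelTopKOutput path.
    if is_triton_kernel_backend:
        return "TRITON_KERNEL"
    # 3. FlashInfer TRT-LLM and mxfp4 backends do routing themselves, so TopK is bypassed.
    if is_trtllm_or_mxfp4_backend:
        return "BYPASSED"
    # 4. Everything else produces the standard top-k output.
    return "STANDARD"

assert resolve_output_format(None, False, False) == "STANDARD"
assert resolve_output_format("STANDARD", False, True) == "STANDARD"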
@@ -105,7 +105,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_multimodal",
     "enable_symm_mem",
-    "quantization",
     "enable_custom_logit_processor",
     "disaggregation_mode",
 ]
@@ -246,7 +246,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FusedMoE,
-    _is_fp4_quantization_enabled,
-)
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )

-        correction_bias = self.gate.e_score_correction_bias
-        # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
-        if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
-            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -386,10 +378,13 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=correction_bias,
+            correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-            force_topk=quant_config is None,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )

         self.shared_experts_is_int8 = False
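Note: two things move here. The bfloat16 cast of e_score_correction_bias is no longer done at the call site; it now happens inside TopK.__init__ when quant_config reports modelopt_fp4 and the FlashInfer TRT-LLM MoE path is active. And force_topk=quant_config is None is replaced by pinning output_format for unquantized layers such as the MTP head. A caller-side sketch of the latter rule only (the helper name is hypothetical, not part of this diff):

def _choose_topk_output_format(quant_config):
    # Unquantized layers must produce the standard (weights, ids) output even when a
    # bypassing FP4 backend is active; quantized layers pass None so TopK.forward
    # resolves the format from the runner backend.
    return "STANDARD" if quant_config is None else None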
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )
         self.top_k = config.num_experts_per_tok

-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -143,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -30,7 +30,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -935,7 +935,7 @@ class Qwen3NextForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
             use_grouped_topk=False,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_top_k,
             hidden_size=config.hidden_size,