Unverified Commit 4844fac9 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Refactor TopK to ensure readability and extensibility (#9338)

parent b7d385e8
......@@ -888,7 +888,7 @@ class DeepEPMoE(EPMoE):
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_a2a_backend().is_deepep():
return DeepEPMoE
......@@ -901,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
return FusedMoE
try:
# Check the quantization argument directly
quantization = global_server_args_dict.get("quantization")
if quantization == "modelopt_fp4":
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFP4MoE,
)
......@@ -911,7 +910,8 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
except:
pass
if should_use_flashinfer_trtllm_moe():
if should_use_flashinfer_trtllm_moe() and quant_config is not None:
# FIXME: FlashInferFusedMoE only supports fp8 quant now
return FlashInferFusedMoE
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
......
......@@ -74,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
logger = logging.getLogger(__name__)
def _is_fp4_quantization_enabled():
"""Check if ModelOpt FP4 quantization is enabled."""
try:
# Use the same simple check that works for class selection
quantization = global_server_args_dict.get("quantization")
return quantization == "modelopt_fp4"
except:
return False
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
......
......@@ -19,6 +19,7 @@ import math
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
TYPE_CHECKING,
Callable,
NamedTuple,
Optional,
......@@ -51,6 +52,9 @@ from sglang.srt.utils import (
is_npu,
)
if TYPE_CHECKING:
from sglang.srt.layers.quantization import QuantizationConfig
try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError:
......@@ -94,6 +98,7 @@ class TopKConfig:
torch_native: bool = False
routed_scaling_factor: Optional[float] = None
apply_routed_scaling_factor_on_output: bool = False
output_format: Optional[TopKOutputFormat] = None
# -------------------------------- TopKOutput ---------------------------------------
......@@ -196,9 +201,10 @@ class TopK(CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
quant_config: Optional[QuantizationConfig] = None,
routed_scaling_factor: Optional[float] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
force_topk: bool = False,
output_format: Optional[TopKOutputFormat] = None,
):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
......@@ -207,6 +213,14 @@ class TopK(CustomOp):
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
if (
quant_config is not None
and quant_config.get_name() == "modelopt_fp4"
and should_use_flashinfer_trtllm_moe()
):
# https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
correction_bias = correction_bias.to(torch.bfloat16)
self.topk_config = TopKConfig(
top_k=top_k,
use_grouped_topk=use_grouped_topk,
......@@ -218,11 +232,9 @@ class TopK(CustomOp):
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
output_format=output_format,
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.force_topk = force_topk
def forward_native(
self,
hidden_states: torch.Tensor,
......@@ -248,7 +260,19 @@ class TopK(CustomOp):
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
if self.use_triton_kernels:
if self.topk_config.output_format is not None:
output_format = self.topk_config.output_format
elif get_moe_runner_backend().is_triton_kernel():
output_format = TopKOutputFormat.TRITON_KERNEL
elif (
should_use_flashinfer_trtllm_moe()
or get_moe_runner_backend().is_flashinfer_mxfp4()
):
output_format = TopKOutputFormat.BYPASSED
else:
output_format = TopKOutputFormat.STANDARD
if output_format == TopKOutputFormat.TRITON_KERNEL:
# renormalize=True is equivalent to sm_first=False
routing_data, gather_idx, scatter_idx = routing(
router_logits,
......@@ -256,10 +280,7 @@ class TopK(CustomOp):
sm_first=not self.topk_config.renormalize,
)
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
elif not self.force_topk and (
should_use_flashinfer_trtllm_moe()
or get_moe_runner_backend().is_flashinfer_mxfp4()
):
elif output_format == TopKOutputFormat.BYPASSED:
return BypassedTopKOutput(
hidden_states=hidden_states,
router_logits=router_logits,
......
......@@ -105,7 +105,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"weight_loader_disable_mmap",
"enable_multimodal",
"enable_symm_mem",
"quantization",
"enable_custom_logit_processor",
"disaggregation_mode",
]
......
......@@ -246,7 +246,7 @@ class BailingMoESparseMoeBlock(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
self.experts = get_moe_impl_class(quant_config)(
num_experts=self.num_experts,
top_k=self.top_k,
layer_id=self.layer_id,
......
......@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE,
_is_fp4_quantization_enabled,
)
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
......@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("experts", prefix),
)
correction_bias = self.gate.e_score_correction_bias
# https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
correction_bias = correction_bias.to(torch.bfloat16)
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
......@@ -386,10 +378,13 @@ class DeepseekV2MoE(nn.Module):
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=correction_bias,
correction_bias=self.gate.e_score_correction_bias,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
force_topk=quant_config is None,
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
# and requires the output format to be standard. We use quant_config to determine the output format.
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
)
self.shared_experts_is_int8 = False
......
......@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
correction_bias=self.gate.e_score_correction_bias,
)
self.experts = get_moe_impl_class()(
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.moe_num_experts,
top_k=config.moe_k,
hidden_size=config.hidden_size,
......
......@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
......
......@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
)
self.top_k = config.num_experts_per_tok
experts_type = get_moe_impl_class()
experts_type = get_moe_impl_class(quant_config)
extra_kwargs = {}
if experts_type.__name__ == "FusedMoE":
quant_config_name = (
......
......@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
)
self.topk.forward = self.topk.forward_native
self.experts = get_moe_impl_class()(
self.experts = get_moe_impl_class(quant_config)(
num_experts=self.num_experts,
top_k=self.top_k,
layer_id=self.layer_id,
......@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -143,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
renormalize=config.norm_topk_prob,
)
self.experts = get_moe_impl_class()(
self.experts = get_moe_impl_class(quant_config)(
layer_id=self.layer_id,
top_k=config.num_experts_per_tok,
num_experts=config.num_experts,
......
......@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
use_grouped_topk=False,
)
self.experts = get_moe_impl_class()(
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok,
......
......@@ -30,7 +30,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -935,7 +935,7 @@ class Qwen3NextForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
......
......@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
use_grouped_topk=False,
)
self.experts = get_moe_impl_class()(
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.moe_num_experts,
top_k=config.moe_top_k,
hidden_size=config.hidden_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment