sglang commit 4844fac9 (unverified), authored Sep 14, 2025 by Cheng Wan, committed via GitHub on Sep 14, 2025
Refactor TopK to ensure readability and extensibility (#9338)
parent b7d385e8
Showing 14 changed files with 52 additions and 47 deletions.
python/sglang/srt/layers/moe/ep_moe/layer.py              +4   -4
python/sglang/srt/layers/moe/fused_moe_triton/layer.py    +0   -10
python/sglang/srt/layers/moe/topk.py                      +30  -9
python/sglang/srt/managers/schedule_batch.py              +0   -1
python/sglang/srt/models/bailing_moe.py                   +1   -1
python/sglang/srt/models/deepseek_v2.py                   +7   -12
python/sglang/srt/models/ernie4.py                        +1   -1
python/sglang/srt/models/glm4_moe.py                      +1   -1
python/sglang/srt/models/gpt_oss.py                       +1   -1
python/sglang/srt/models/longcat_flash.py                 +2   -2
python/sglang/srt/models/qwen2_moe.py                     +1   -1
python/sglang/srt/models/qwen3_moe.py                     +1   -1
python/sglang/srt/models/qwen3_next.py                    +2   -2
python/sglang/srt/models/step3_vl.py                      +1   -1
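At a high level, the commit replaces the force_topk flag and scattered backend checks with an explicit TopKOutputFormat value, and it threads each layer's quant_config through get_moe_impl_class and TopK instead of reading the global "quantization" server arg. Below is a minimal, illustrative sketch of the enum and the config field as they appear in the hunks that follow; any defaults not visible in those hunks are assumptions, not sglang's actual values.

    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Optional

    class TopKOutputFormat(Enum):
        # The three variants referenced by python/sglang/srt/layers/moe/topk.py below.
        STANDARD = auto()       # ordinary top-k weights and expert ids
        TRITON_KERNEL = auto()  # routing data for the triton_kernels routing path
        BYPASSED = auto()       # hand hidden_states/router_logits straight to the MoE backend

    @dataclass
    class TopKConfig:
        # Only fields visible in the hunks below; defaults outside them are assumptions.
        top_k: int
        torch_native: bool = False
        routed_scaling_factor: Optional[float] = None
        apply_routed_scaling_factor_on_output: bool = False
        output_format: Optional[TopKOutputFormat] = None  # new in this commit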
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -888,7 +888,7 @@ class DeepEPMoE(EPMoE):
             raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")


-def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE

@@ -901,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
         # Check the quantization argument directly
-        quantization = global_server_args_dict.get("quantization")
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             from sglang.srt.layers.moe.fused_moe_triton.layer import (
                 FlashInferFP4MoE,
             )

@@ -911,7 +910,8 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     except:
         pass
-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
         # FIXME: FlashInferFusedMoE only supports fp8 quant now
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
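The selection logic above now keys off the layer's quant_config object rather than a global server-args lookup. A self-contained sketch of the new check is below; the toy QuantizationConfig class and the wants_modelopt_fp4 helper are illustrative only (get_name() mirrors the method used in the hunk, but nothing here imports sglang).

    from typing import Optional

    class QuantizationConfig:
        """Toy stand-in for sglang's QuantizationConfig."""
        def __init__(self, name: str):
            self._name = name

        def get_name(self) -> str:
            return self._name

    def wants_modelopt_fp4(quant_config: Optional[QuantizationConfig]) -> bool:
        # New style: ask the layer's own quant_config directly (no global dict
        # lookup), and treat "no quant config" as "not FP4".
        return quant_config is not None and quant_config.get_name() == "modelopt_fp4"

    assert wants_modelopt_fp4(QuantizationConfig("modelopt_fp4")) is True
    assert wants_modelopt_fp4(None) is False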
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -74,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)


-def _is_fp4_quantization_enabled():
-    """Check if ModelOpt FP4 quantization is enabled."""
-    try:
-        # Use the same simple check that works for class selection
-        quantization = global_server_args_dict.get("quantization")
-        return quantization == "modelopt_fp4"
-    except:
-        return False
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
python/sglang/srt/layers/moe/topk.py

@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,

@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:

@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None


 # -------------------------------- TopKOutput ---------------------------------------
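The QuantizationConfig import added above is guarded by TYPE_CHECKING, so it is visible to type checkers but never executed at runtime, which keeps topk.py free of an import cycle on the quantization package. A generic sketch of the pattern follows; the module path and make_topk function are illustrative, not sglang's.

    from __future__ import annotations  # annotations stay unevaluated at runtime

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Seen only by type checkers (mypy/pyright); never imported at runtime,
        # so it cannot introduce a circular import.
        from mypackage.quantization import QuantizationConfig  # illustrative path

    def make_topk(quant_config: Optional[QuantizationConfig] = None) -> None:
        # At runtime the annotation is never evaluated; checkers still resolve it.
        ...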
@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-        force_topk: bool = False,
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -207,6 +213,14 @@ class TopK(CustomOp):
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None

+        if (
+            quant_config is not None
+            and quant_config.get_name() == "modelopt_fp4"
+            and should_use_flashinfer_trtllm_moe()
+        ):
+            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
+            correction_bias = correction_bias.to(torch.bfloat16)
+
         self.topk_config = TopKConfig(
             top_k=top_k,
             use_grouped_topk=use_grouped_topk,
@@ -218,11 +232,9 @@ class TopK(CustomOp):
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk

     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -248,7 +260,19 @@
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.use_triton_kernels:
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,

@@ -256,10 +280,7 @@
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif not self.force_topk and (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
             return BypassedTopKOutput(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
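The forward path above resolves an output format once and then dispatches on it: an explicit topk_config.output_format wins, otherwise the runner backend decides (triton kernel -> TRITON_KERNEL, flashinfer trtllm or mxfp4 -> BYPASSED, anything else -> STANDARD). The standalone sketch below mirrors that precedence; resolve_output_format and its boolean arguments are stand-ins for the backend probes, not sglang APIs.

    from enum import Enum, auto
    from typing import Optional

    class TopKOutputFormat(Enum):
        # Re-declared here so the sketch runs on its own.
        STANDARD = auto()
        TRITON_KERNEL = auto()
        BYPASSED = auto()

    def resolve_output_format(
        explicit: Optional[TopKOutputFormat],
        triton_kernel_backend: bool,
        trtllm_or_mxfp4_backend: bool,
    ) -> TopKOutputFormat:
        # Mirrors the precedence in the TopK forward hunk above:
        # explicit override > triton kernel backend > bypassing backends > standard.
        if explicit is not None:
            return explicit
        if triton_kernel_backend:
            return TopKOutputFormat.TRITON_KERNEL
        if trtllm_or_mxfp4_backend:
            return TopKOutputFormat.BYPASSED
        return TopKOutputFormat.STANDARD

    # An unquantized draft/MTP layer can force STANDARD even on a bypassing backend:
    assert resolve_output_format(TopKOutputFormat.STANDARD, False, True) is TopKOutputFormat.STANDARD
    assert resolve_output_format(None, True, False) is TopKOutputFormat.TRITON_KERNEL
    assert resolve_output_format(None, False, False) is TopKOutputFormat.STANDARD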
python/sglang/srt/managers/schedule_batch.py

@@ -105,7 +105,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_multimodal",
     "enable_symm_mem",
-    "quantization",
     "enable_custom_logit_processor",
     "disaggregation_mode",
 ]
python/sglang/srt/models/bailing_moe.py

@@ -246,7 +246,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
         )
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
python/sglang/srt/models/deepseek_v2.py

@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FusedMoE,
-    _is_fp4_quantization_enabled,
-)
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )

-        correction_bias = self.gate.e_score_correction_bias
-        # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
-        if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
-            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -386,10 +378,13 @@
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=correction_bias,
+            correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-            force_topk=quant_config is None,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )

         self.shared_experts_is_int8 = False
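The deepseek_v2 hunk above is the one call site that overrides automatic format selection: when the block is built without a quant_config (for example an unquantized MTP/draft layer), it pins output_format to STANDARD so a bypassing FP4 backend choice cannot leak in; otherwise it passes None and lets the backend-driven logic inside TopK decide. A tiny hedged sketch of that conditional follows; output_format_for_layer is a hypothetical helper, not part of sglang.

    from enum import Enum, auto
    from typing import Optional

    class TopKOutputFormat(Enum):
        STANDARD = auto()
        BYPASSED = auto()

    def output_format_for_layer(quant_config: Optional[object]) -> Optional[TopKOutputFormat]:
        # Unquantized layers (quant_config is None) must produce a standard top-k
        # output even if the serving backend would otherwise bypass routing;
        # quantized layers return None and defer to the selection in TopK.forward.
        return TopKOutputFormat.STANDARD if quant_config is None else None

    assert output_format_for_layer(None) is TopKOutputFormat.STANDARD
    assert output_format_for_layer(object()) is None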
python/sglang/srt/models/ernie4.py

@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
python/sglang/srt/models/glm4_moe.py

@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
python/sglang/srt/models/gpt_oss.py

@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )
         self.top_k = config.num_experts_per_tok
-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
python/sglang/srt/models/longcat_flash.py

@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,

@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
python/sglang/srt/models/qwen2_moe.py

@@ -143,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
python/sglang/srt/models/qwen3_moe.py

@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
python/sglang/srt/models/qwen3_next.py

@@ -30,7 +30,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -935,7 +935,7 @@ class Qwen3NextForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
python/sglang/srt/models/step3_vl.py

@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
             use_grouped_topk=False,
         )
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_top_k,
             hidden_size=config.hidden_size,