Commit 4844fac9 (unverified)
Refactor TopK to ensure readability and extensibility (#9338)

Authored Sep 14, 2025 by Cheng Wan; committed via GitHub on Sep 14, 2025.
Parent commit: b7d385e8

Showing 14 changed files with 52 additions and 47 deletions (+52 -47):
python/sglang/srt/layers/moe/ep_moe/layer.py             +4   -4
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +0  -10
python/sglang/srt/layers/moe/topk.py                    +30   -9
python/sglang/srt/managers/schedule_batch.py             +0   -1
python/sglang/srt/models/bailing_moe.py                  +1   -1
python/sglang/srt/models/deepseek_v2.py                  +7  -12
python/sglang/srt/models/ernie4.py                       +1   -1
python/sglang/srt/models/glm4_moe.py                     +1   -1
python/sglang/srt/models/gpt_oss.py                      +1   -1
python/sglang/srt/models/longcat_flash.py                +2   -2
python/sglang/srt/models/qwen2_moe.py                    +1   -1
python/sglang/srt/models/qwen3_moe.py                    +1   -1
python/sglang/srt/models/qwen3_next.py                   +2   -2
python/sglang/srt/models/step3_vl.py                     +1   -1
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -888,7 +888,7 @@ class DeepEPMoE(EPMoE):
             raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
 
 
-def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE

@@ -901,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
         # Check the quantization argument directly
-        quantization = global_server_args_dict.get("quantization")
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             from sglang.srt.layers.moe.fused_moe_triton.layer import (
                 FlashInferFP4MoE,
             )

@@ -911,7 +910,8 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     except:
         pass
 
-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+        # FIXME: FlashInferFusedMoE only supports fp8 quant now
        return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
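The hunks above change get_moe_impl_class to take the caller's quant_config instead of reading a quantization string out of global_server_args_dict. The standalone sketch below (not sglang code; the class names are simplified stand-ins) illustrates the resulting selection pattern: the decision depends only on the argument that is passed in, which keeps it local and easy to test.

# Illustrative sketch of class selection driven by an explicit quant_config argument.
# QuantizationConfig, ModelOptFp4Config, FusedMoE and FlashInferFP4MoE are stand-ins.
from typing import Optional


class QuantizationConfig:
    def get_name(self) -> str:
        raise NotImplementedError


class ModelOptFp4Config(QuantizationConfig):
    def get_name(self) -> str:
        return "modelopt_fp4"


class FusedMoE: ...          # stand-in default implementation
class FlashInferFP4MoE: ...  # stand-in FP4-specific implementation


def select_moe_impl(quant_config: Optional[QuantizationConfig]):
    # The caller passes its quant_config explicitly, so no global state is consulted.
    if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
        return FlashInferFP4MoE
    return FusedMoE


assert select_moe_impl(None) is FusedMoE
assert select_moe_impl(ModelOptFp4Config()) is FlashInferFP4MoE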
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -74,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)
 
 
-def _is_fp4_quantization_enabled():
-    """Check if ModelOpt FP4 quantization is enabled."""
-    try:
-        # Use the same simple check that works for class selection
-        quantization = global_server_args_dict.get("quantization")
-        return quantization == "modelopt_fp4"
-    except:
-        return False
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
python/sglang/srt/layers/moe/topk.py

@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,

@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:

@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None
 
 
 # -------------------------------- TopKOutput ---------------------------------------

@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-        force_topk: bool = False,
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details

@@ -207,6 +213,14 @@ class TopK(CustomOp):
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
 
+        if (
+            quant_config is not None
+            and quant_config.get_name() == "modelopt_fp4"
+            and should_use_flashinfer_trtllm_moe()
+        ):
+            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
+            correction_bias = correction_bias.to(torch.bfloat16)
+
         self.topk_config = TopKConfig(
             top_k=top_k,
             use_grouped_topk=use_grouped_topk,

@@ -218,11 +232,9 @@ class TopK(CustomOp):
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk
 
     def forward_native(
         self,
         hidden_states: torch.Tensor,

@@ -248,7 +260,19 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.use_triton_kernels:
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,

@@ -256,10 +280,7 @@ class TopK(CustomOp):
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif not self.force_topk and (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
             return BypassedTopKOutput(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
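The changes to TopK's forward path above replace the force_topk flag with an explicit TopKOutputFormat resolution step. The sketch below uses simplified stand-ins for the sglang enum and backend checks (it is not the actual implementation) to show the same precedence: an explicit output_format from the config wins, then the Triton-kernel and bypass backends, otherwise the standard format.

# Minimal sketch of the output-format resolution order; names are simplified stand-ins.
from enum import Enum, auto
from typing import Optional


class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()
    BYPASSED = auto()


def resolve_output_format(
    override: Optional[TopKOutputFormat],
    use_triton_kernel: bool,
    bypass_for_backend: bool,  # e.g. a FlashInfer TRT-LLM / MXFP4 style backend
) -> TopKOutputFormat:
    if override is not None:
        return override  # explicit output_format set on the config
    if use_triton_kernel:
        return TopKOutputFormat.TRITON_KERNEL
    if bypass_for_backend:
        return TopKOutputFormat.BYPASSED
    return TopKOutputFormat.STANDARD


# An unquantized layer that pins STANDARD keeps plain top-k output even when the
# serving backend would otherwise bypass it.
assert resolve_output_format(TopKOutputFormat.STANDARD, False, True) is TopKOutputFormat.STANDARD
assert resolve_output_format(None, False, True) is TopKOutputFormat.BYPASSED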
python/sglang/srt/managers/schedule_batch.py

@@ -105,7 +105,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_multimodal",
     "enable_symm_mem",
-    "quantization",
     "enable_custom_logit_processor",
     "disaggregation_mode",
 ]
python/sglang/srt/models/bailing_moe.py

@@ -246,7 +246,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
python/sglang/srt/models/deepseek_v2.py

@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FusedMoE,
-    _is_fp4_quantization_enabled,
-)
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (

@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )
 
-        correction_bias = self.gate.e_score_correction_bias
-        # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
-        if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
-            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,

@@ -386,10 +378,13 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=correction_bias,
+            correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-            force_topk=quant_config is None,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )
 
         self.shared_experts_is_int8 = False
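deepseek_v2.py no longer casts the gate's e_score_correction_bias itself; the equivalent guard now lives in TopK.__init__ (see the topk.py hunk above). A minimal sketch of that guard follows, with the backend check reduced to a boolean and the function name a stand-in rather than sglang API.

# Hedged sketch of the correction-bias cast that moved into TopK; not sglang code.
from typing import Optional

import torch


def maybe_cast_correction_bias(
    correction_bias: torch.Tensor,
    quant_name: Optional[str],
    use_flashinfer_trtllm_moe: bool,
) -> torch.Tensor:
    # Mirrors the condition added in TopK.__init__: only the ModelOpt FP4 +
    # FlashInfer TRT-LLM MoE combination needs a bfloat16 correction bias.
    if quant_name == "modelopt_fp4" and use_flashinfer_trtllm_moe:
        return correction_bias.to(torch.bfloat16)
    return correction_bias


bias = torch.zeros(8, dtype=torch.float32)
assert maybe_cast_correction_bias(bias, "modelopt_fp4", True).dtype == torch.bfloat16
assert maybe_cast_correction_bias(bias, None, False).dtype == torch.float32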
python/sglang/srt/models/ernie4.py

@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
python/sglang/srt/models/glm4_moe.py

@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
python/sglang/srt/models/gpt_oss.py

@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )
         self.top_k = config.num_experts_per_tok
 
-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
python/sglang/srt/models/longcat_flash.py

@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,

@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
python/sglang/srt/models/qwen2_moe.py

@@ -143,7 +143,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
             num_experts=config.num_experts,
python/sglang/srt/models/qwen3_moe.py

@@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
python/sglang/srt/models/qwen3_next.py

@@ -30,7 +30,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -935,7 +935,7 @@ class Qwen3NextForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
python/sglang/srt/models/step3_vl.py

@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_top_k,
             hidden_size=config.hidden_size,