change / sglang / Commits

Commit 15ad6c90 (Unverified)
[1/N] MoE Refactor: refactor `select_experts` (#7966)
Authored Jul 19, 2025 by Cheng Wan; committed by GitHub on Jul 19, 2025.
Parent: cfab0ff6

Changes: 39 files in the commit. This page shows 19 changed files with 162 additions and 110 deletions (+162 -110).
python/sglang/srt/models/deepseek_v2.py          +22 -30
python/sglang/srt/models/granitemoe.py           +8 -2
python/sglang/srt/models/grok.py                 +9 -3
python/sglang/srt/models/hunyuan.py              +8 -5
python/sglang/srt/models/llama4.py               +11 -11
python/sglang/srt/models/mixtral.py              +9 -2
python/sglang/srt/models/olmoe.py                +8 -5
python/sglang/srt/models/phimoe.py               +9 -3
python/sglang/srt/models/qwen2_moe.py            +9 -5
python/sglang/srt/models/qwen3_moe.py            +13 -18
python/sglang/test/test_block_fp8.py             +8 -3
python/sglang/test/test_block_fp8_ep.py          +1 -1
python/sglang/test/test_cutlass_w4a8_moe.py      +1 -3
python/sglang/test/test_fp4_moe.py               +1 -3
test/srt/test_block_int8.py                      +8 -3
test/srt/test_fused_moe.py                       +15 -4
test/srt/test_int8_kernel.py                     +7 -3
test/srt/test_triton_moe_channel_fp8_kernel.py   +7 -3
test/srt/test_triton_moe_wna16.py                +8 -3
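The recurring change across the diffs below: routing configuration (top_k, renormalize, use_grouped_topk, custom_routing_function, and so on) moves off the experts layer and into a dedicated TopK module built once in __init__; each forward pass then computes topk_output = self.topk(hidden_states, router_logits) and hands it to self.experts. The sketch below illustrates that call-site shape with toy stand-ins; ToyTopK and ToyMoEBlock are illustrative only (not sglang classes), and the per-expert loop is deliberately naive.

# Minimal sketch of the new call-site pattern, with toy modules in place of
# sglang's TopK / FusedMoE. Names prefixed "Toy" are assumptions for illustration.
import torch
import torch.nn as nn


class ToyTopK(nn.Module):
    """Owns the routing hyper-parameters, analogous to the new TopK module."""

    def __init__(self, top_k: int, renormalize: bool = True):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(probs, self.top_k, dim=-1)
        if self.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids


class ToyMoEBlock(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # New pattern: routing config lives on a dedicated module ...
        self.topk = ToyTopK(top_k=top_k, renormalize=True)
        # ... so the experts layer no longer needs top_k / renormalize kwargs.
        self.experts = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(num_experts)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        topk_weights, topk_ids = self.topk(hidden_states, router_logits)
        out = torch.zeros_like(hidden_states)
        # Naive dispatch: weight each selected expert's output and accumulate.
        for k in range(topk_ids.shape[-1]):
            idx = topk_ids[..., k]
            w = topk_weights[..., k : k + 1].to(hidden_states.dtype)
            for e, expert in enumerate(self.experts):
                mask = idx == e
                if mask.any():
                    out[mask] += w[mask] * expert(hidden_states[mask])
        return out


if __name__ == "__main__":
    x = torch.randn(4, 16)
    block = ToyMoEBlock(hidden_size=16, num_experts=8, top_k=2)
    print(block(x).shape)  # torch.Size([4, 16])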
python/sglang/srt/models/deepseek_v2.py

@@ -58,7 +58,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (

@@ -303,6 +303,17 @@ class DeepseekV2MoE(nn.Module):
             config=config,
             prefix=add_prefix("gate", prefix),
             is_nextn=is_nextn,
         )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts + self.num_fused_shared_experts

@@ -311,13 +322,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             layer_id=self.layer_id,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(

@@ -451,8 +456,9 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(
-                hidden_states=hidden_states, router_logits=router_logits
+                hidden_states=hidden_states, topk_output=topk_output
             )
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor

@@ -473,8 +479,9 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
+            hidden_states=hidden_states, topk_output=topk_output
         )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here

@@ -490,8 +497,9 @@ class DeepseekV2MoE(nn.Module):
     ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         fused_experts_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
+            hidden_states=hidden_states, topk_output=topk_output
         )
         assert use_intel_amx_backend(

@@ -549,17 +557,9 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         shared_output = self._forward_shared_experts(hidden_states)
-        topk_weights, topk_idx = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=True,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+        topk_weights, topk_idx, _ = self.topk(
+            hidden_states,
+            router_logits,
             num_token_non_padded=forward_batch.num_token_non_padded,
             expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                 layer_id=self.layer_id,

@@ -649,17 +649,9 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.topk_weights_local, state.topk_idx_local = select_experts(
+            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
                 num_token_non_padded=state.forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
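DeepseekV2MoE's TopK above is configured for grouped routing (use_grouped_topk=True, num_expert_group=config.n_group, topk_group=config.topk_group). As a rough reference for what grouped top-k does, the function below is a simplified plain-PyTorch sketch: rank expert groups by their best expert, keep only the topk_group best groups, then take the top_k experts among the survivors. It is not sglang's implementation and it omits correction_bias, num_fused_shared_experts, and routed_scaling_factor.

# Simplified grouped top-k routing sketch (illustration only, not sglang code).
import torch


def grouped_topk(
    router_logits: torch.Tensor,  # [num_tokens, num_experts]
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool = True,
):
    num_tokens, num_experts = router_logits.shape
    scores = torch.softmax(router_logits, dim=-1)
    # Score each group by its best expert and keep only the topk_group best groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, num_experts)
    )
    # Zero out experts that live in discarded groups, then do a standard top-k.
    masked_scores = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, k=top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


if __name__ == "__main__":
    logits = torch.randn(4, 16)
    w, ids = grouped_topk(logits, top_k=2, num_expert_group=4, topk_group=2)
    print(w.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])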
python/sglang/srt/models/granitemoe.py

@@ -15,6 +15,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention

@@ -60,6 +61,11 @@ class GraniteMoeMoE(nn.Module):
             prefix=f"{prefix}.gate",
         )
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True,
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,

@@ -67,7 +73,6 @@ class GraniteMoeMoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             reduce_results=True,
-            renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=f"{prefix}.experts",

@@ -78,7 +83,8 @@ class GraniteMoeMoE(nn.Module):
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
python/sglang/srt/models/grok.py

@@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.router import fused_moe_router_shim
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -108,6 +109,12 @@ class Grok1MoE(nn.Module):
             fused_moe_router_shim, self.router_logit_softcapping
         )
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+            custom_routing_function=custom_routing_function,
+        )

         kwargs = {}
         if global_server_args_dict["enable_ep_moe"]:
             MoEImpl = EPMoE

@@ -124,17 +131,16 @@ class Grok1MoE(nn.Module):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
-            custom_routing_function=custom_routing_function,
             activation="gelu",
             **kwargs,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # need to assert self.gate.quant_method is unquantized
-        return self.experts(hidden_states, self.gate.weight)
+        topk_output = self.topk(hidden_states, self.gate.weight)
+        return self.experts(hidden_states, topk_output)


 class Grok1Attention(nn.Module):
python/sglang/srt/models/hunyuan.py

@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -152,13 +153,16 @@ class HunYuanSparseMoeBlock(nn.Module):
             else config.moe_intermediate_size[layer_id]
         )
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True if top_k > 1 else False,
+        )
         self.experts = FusedMoE(
             num_experts=config.num_experts,
             top_k=top_k,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
-            renormalize=True if top_k > 1 else False,
             quant_config=quant_config,
         )

@@ -195,9 +199,8 @@ class HunYuanSparseMoeBlock(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
python/sglang/srt/models/llama4.py

@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -103,14 +104,17 @@ class Llama4MoE(nn.Module):
             prefix=add_prefix("router", prefix),
         )
+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=False,
+            custom_routing_function=Llama4MoE.custom_routing_function,
+        )
         self.experts = FusedMoE(
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            custom_routing_function=Llama4MoE.custom_routing_function,
             intermediate_size=intermediate_size_moe,
             reduce_results=False,
-            renormalize=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,
             prefix=add_prefix("experts", prefix),

@@ -147,10 +151,8 @@ class Llama4MoE(nn.Module):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
-        routed_out = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        routed_out = self.experts(hidden_states, topk_output)
         return shared_out, routed_out

     def _forward_core_shared_routed_overlap(self, hidden_states):

@@ -163,10 +165,8 @@ class Llama4MoE(nn.Module):
         with self.device_module.stream(alt_stream):
             # router_scores: [num_tokens, num_experts]
             router_logits, _ = self.router(hidden_states)
-            routed_out = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-            )
+            topk_output = self.topk(hidden_states, router_logits)
+            routed_out = self.experts(hidden_states, topk_output)
         self.device_module.current_stream().wait_stream(alt_stream)
         return shared_out, routed_out
python/sglang/srt/models/mixtral.py

@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -86,6 +87,12 @@ class MixtralMoE(nn.Module):
             quant_config=None,
             prefix=add_prefix("gate", prefix),
         )
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=True,
+        )

         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=num_experts,

@@ -93,7 +100,6 @@ class MixtralMoE(nn.Module):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=add_prefix("experts", prefix),

@@ -105,7 +111,8 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)
python/sglang/srt/models/olmoe.py

@@ -32,6 +32,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -76,13 +77,16 @@ class OlmoeMoE(nn.Module):
             prefix=add_prefix("gate", prefix),
         )
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
-            renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=add_prefix("experts", prefix),

@@ -94,9 +98,8 @@ class OlmoeMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
python/sglang/srt/models/phimoe.py

@@ -13,6 +13,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention

@@ -200,15 +201,19 @@
             quant_config=None,
         )
+        self.topk = TopK(
+            top_k=top_k,
+            renormalize=False,
+            custom_routing_function=phimoe_routing_function,
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
-            renormalize=False,
             quant_config=quant_config,
-            custom_routing_function=phimoe_routing_function,
             prefix=add_prefix("experts", prefix),
         )

@@ -219,7 +224,8 @@
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         return final_hidden_states.view(orig_shape)
python/sglang/srt/models/qwen2_moe.py

@@ -61,6 +61,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -134,13 +135,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=config.norm_topk_prob,
+        )
         self.experts = get_moe_impl_class()(
             layer_id=self.layer_id,
+            num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
             # Additional args for FusedMoE

@@ -189,9 +194,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
python/sglang/srt/models/qwen3_moe.py

@@ -56,8 +56,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -102,6 +101,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=False,
+        )
         self.experts = get_moe_impl_class()(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],

@@ -109,7 +114,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
             **(

@@ -143,7 +147,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
         )
         self.top_k = config.num_experts_per_tok
-        self.renormalize = config.norm_topk_prob

         self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
             group=parallel_state.get_tp_group().device_group,

@@ -180,9 +183,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

@@ -195,13 +197,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if is_non_idle_and_non_empty(forward_mode, hidden_states):
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
-            topk_weights, topk_idx = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,

@@ -267,12 +265,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.topk_weights_local, state.topk_idx_local = select_experts(
+            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=False,
-                renormalize=self.renormalize,
                 num_token_non_padded=state.forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
python/sglang/test/test_block_fp8.py

@@ -6,6 +6,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,

@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)

         with torch.inference_mode():
+            topk_output = select_experts(
+                hidden_states=a,
+                router_logits=score,
+                top_k=topk,
+                renormalize=False,
+            )
             out = fused_moe(
                 a,
                 w1,
                 w2,
-                score,
-                topk,
-                renormalize=False,
+                topk_output,
                 use_fp8_w8a8=True,
                 w1_scale=w1_s,
                 w2_scale=w2_s,
python/sglang/test/test_block_fp8_ep.py

@@ -40,7 +40,7 @@ def ep_moe(
     block_shape: Optional[List[int]] = None,
 ):
     use_blockwise_fp8 = block_shape is not None
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
python/sglang/test/test_cutlass_w4a8_moe.py

@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     s_strides2 = c_strides2

     score = torch.randn((M, E), dtype=dtype, device=device)
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
     expert_map[local_e:] = E
python/sglang/test/test_fp4_moe.py

@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )
     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
test/srt/test_block_int8.py

@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.test.test_utils import CustomTestCase

@@ -171,14 +172,18 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)

+        topk_output = select_experts(
+            hidden_states=a,
+            router_logits=score,
+            top_k=topk,
+        )
+
         with torch.inference_mode():
             out = fused_moe(
                 a,
                 w1,
                 w2,
-                score,
-                topk,
-                renormalize=False,
+                topk_output,
                 use_int8_w8a8=True,
                 w1_scale=w1_s,
                 w2_scale=w2_s,
test/srt/test_fused_moe.py

@@ -6,6 +6,7 @@ from tqdm import tqdm
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.utils import is_hip

@@ -132,13 +133,17 @@ class TestFusedMOE(CustomTestCase):
             input_scale=a2_scale,
         )

+        topk_output = select_experts(
+            hidden_states=a,
+            router_logits=score,
+            top_k=topk,
+        )
+
         sglang_output = fused_moe(
             a,
             w1,
             w2,
-            score,
-            topk,
-            renormalize=False,
+            topk_output,
             use_fp8_w8a8=True,
             w1_scale=w1_scale,
             w2_scale=w2_scale,

@@ -166,7 +171,13 @@ class TestFusedMOE(CustomTestCase):
         w2 = self.create_random_cuda_tensor((e, k, n), dtype)
         score = self.create_random_cuda_tensor((m, e), dtype)

-        triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
+        topk_output = select_experts(
+            hidden_states=a,
+            router_logits=score,
+            top_k=topk,
+        )
+        triton_output = fused_moe(a, w1, w2, topk_output)
         torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
         torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
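The test-side pattern above recurs in the remaining test files: select_experts is called once up front and its result replaces the old (score, topk, renormalize=False) arguments to fused_moe. Condensed, the new usage looks like the snippet below, mirroring test_fused_moe.py; the tensor shapes are illustrative and running it assumes a CUDA device with sglang installed.

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts

# Illustrative sizes: m tokens, e experts, hidden dim k, intermediate dim n.
m, e, k, n, topk = 32, 8, 128, 256, 2
a = torch.randn((m, k), device="cuda", dtype=torch.float16)
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.float16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.float16)
score = torch.randn((m, e), device="cuda", dtype=torch.float16)

# Routing is computed once; its result is handed straight to the fused kernel.
topk_output = select_experts(
    hidden_states=a,
    router_logits=score,
    top_k=topk,
)
out = fused_moe(a, w1, w2, topk_output)
print(out.shape)  # (m, k)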
test/srt/test_int8_kernel.py

@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.test.test_utils import CustomTestCase

@@ -114,13 +115,16 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
         with torch.inference_mode():
             ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
+            topk_output = select_experts(
+                hidden_states=a,
+                router_logits=score,
+                top_k=topk,
+            )
             out = fused_moe(
                 a,
                 w1,
                 w2,
-                score,
-                topk,
-                renormalize=False,
+                topk_output,
                 use_fp8_w8a8=False,  # Not using fp8
                 use_int8_w8a16=False,  # Not using int8-w8a16
                 use_int8_w8a8=True,  # Using int8-w8a8
test/srt/test_triton_moe_channel_fp8_kernel.py

@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.test.test_utils import CustomTestCase

@@ -126,13 +127,16 @@ class TestW8A8FP8FusedMoE(CustomTestCase):
         with torch.inference_mode():
             ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
+            topk_output = select_experts(
+                hidden_states=a,
+                router_logits=score,
+                top_k=topk,
+            )
             out = fused_moe(
                 a,
                 w1,
                 w2,
-                score,
-                topk,
-                renormalize=False,
+                topk_output,
                 use_fp8_w8a8=True,  # using fp8
                 use_int8_w8a16=False,
                 use_int8_w8a8=False,
test/srt/test_triton_moe_wna16.py

@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts

 NUM_EXPERTS = [8, 64]
 TOP_KS = [2, 6]

@@ -219,13 +220,17 @@ def test_fused_moe_wn16(
         if has_zp:
             w_qzeros[expert_id] = qzeros

+    topk_output = select_experts(
+        hidden_states=a,
+        router_logits=score,
+        top_k=topk,
+    )
+
     triton_output = fused_moe(
         a,
         w1_qweight,
         w2_qweight,
-        score,
-        topk,
-        renormalize=False,
+        topk_output,
         use_int4_w4a16=weight_bits == 4,
         use_int8_w8a16=weight_bits == 8,
         w1_scale=w1_scales,