change / sglang / Commits / 29589512

Commit 29589512 (Unverified)
[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)
Authored Aug 14, 2025 by Cheng Wan; committed via GitHub on Aug 14, 2025
Parent: 584e1ab2

Showing 20 changed files with 162 additions and 268 deletions (+162 -268)
python/sglang/srt/models/deepseek_v2.py    +11  -49
python/sglang/srt/models/ernie4.py          +2   -2
python/sglang/srt/models/glm4_moe.py       +18  -59
python/sglang/srt/models/glm4v_moe.py       +2  -11
python/sglang/srt/models/gpt_oss.py        +16  -34
python/sglang/srt/models/granitemoe.py      +0   -1
python/sglang/srt/models/grok.py            +0   -1
python/sglang/srt/models/interns1.py        +2   -1
python/sglang/srt/models/internvl.py        +2   -2
python/sglang/srt/models/llama4.py          +0   -2
python/sglang/srt/models/minicpm3.py        +0   -1
python/sglang/srt/models/mixtral.py         +0   -2
python/sglang/srt/models/olmoe.py           +0   -1
python/sglang/srt/models/qwen2_moe.py       +3  -18
python/sglang/srt/models/qwen3_moe.py       +9  -38
python/sglang/srt/models/step3_vl.py        +2   -1
python/sglang/srt/models/xverse_moe.py     +11   -5
python/sglang/srt/server_args.py           +66  -27
python/sglang/srt/two_batch_overlap.py     +16  -11
python/sglang/srt/utils.py                  +2   -2
python/sglang/srt/models/deepseek_v2.py  (+11 -49)

@@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm

@@ -61,9 +60,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (

@@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
         )

         self.shared_experts_is_int8 = False

@@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module):
                 prefix=add_prefix("shared_experts", prefix),
                 **(
                     dict(tp_rank=0, tp_size=1)
-                    if global_server_args_dict["moe_a2a_backend"].is_deepep()
+                    if get_moe_a2a_backend().is_deepep()
                     else {}
                 ),
             )

@@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
         self.top_k = config.num_experts_per_tok

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (

@@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )

-        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

     def get_moe_weights(self):
         return [

@@ -484,13 +460,7 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-        if should_use_flashinfer_trtllm_moe():
-            kwargs["topk_output"] = (self.topk, router_logits)
-        else:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:

@@ -520,13 +490,7 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-        if should_use_flashinfer_trtllm_moe():
-            kwargs["topk_output"] = (self.topk, router_logits)
-        else:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:

@@ -2478,17 +2442,15 @@ class DeepseekV2ForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
-            expert_params_mapping += (
-                get_moe_impl_class().make_expert_input_scale_params_mapping(
-                    num_experts=self.config.n_routed_experts
-                )
-            )
+            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
+                num_experts=self.config.n_routed_experts
+            )

         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
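The recurring change in this file is replacing raw global_server_args_dict["moe_a2a_backend"] and ["deepep_mode"] lookups with the accessors get_moe_a2a_backend() and get_deepep_mode() imported from sglang.srt.layers.moe. The diff only shows the call sites, not the accessor module itself, so here is a minimal hedged sketch of how such module-level accessors could be wired to a process-wide config; the names _MoeGlobalConfig, _global_moe_config, and initialize_moe_config are illustrative assumptions, not the actual sglang internals:

# Hypothetical sketch of the accessor pattern; not the actual sglang implementation.
from dataclasses import dataclass, field
from enum import Enum


class MoeA2ABackend(Enum):
    NONE = "none"
    DEEPEP = "deepep"

    def is_deepep(self) -> bool:
        return self is MoeA2ABackend.DEEPEP


@dataclass
class _MoeGlobalConfig:
    moe_a2a_backend: MoeA2ABackend = MoeA2ABackend.NONE
    deepep_mode: str = "auto"


_global_moe_config = _MoeGlobalConfig()


def initialize_moe_config(moe_a2a_backend: str, deepep_mode: str) -> None:
    # Called once at server start-up with values parsed from ServerArgs.
    _global_moe_config.moe_a2a_backend = MoeA2ABackend(moe_a2a_backend)
    _global_moe_config.deepep_mode = deepep_mode


def get_moe_a2a_backend() -> MoeA2ABackend:
    return _global_moe_config.moe_a2a_backend


def get_deepep_mode() -> str:
    return _global_moe_config.deepep_mode

Call sites then read "if get_moe_a2a_backend().is_deepep(): ..." instead of indexing a raw dict, which is the shape used throughout this commit.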
python/sglang/srt/models/ernie4.py  (+2 -2)

@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP

@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
 class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
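Several models in this commit switch from get_moe_impl_class().make_expert_params_mapping(...) to calling the classmethod on FusedMoE directly. The diff's own comment gives the tuple layout (param_name, weight_name, expert_id, shard_id). Below is a hedged sketch of how such a mapping is typically built and consumed in a load_weights loop; the helper and the w13_weight/w2_weight parameter names are illustrative, not the actual FusedMoE implementation:

from typing import List, Tuple


def make_expert_params_mapping_sketch(
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
) -> List[Tuple[str, str, int, str]]:
    # Illustrative only: map checkpoint tensor names of every expert onto fused
    # parameters (a gate+up weight and a down weight per expert).
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_name, param_name, shard_id in [
            (ckpt_gate_proj_name, "experts.w13_weight", "w1"),
            (ckpt_up_proj_name, "experts.w13_weight", "w3"),
            (ckpt_down_proj_name, "experts.w2_weight", "w2"),
        ]:
            mapping.append(
                (param_name, f"experts.{expert_id}.{ckpt_name}.weight", expert_id, shard_id)
            )
    return mapping


# A load_weights loop would then substitute the checkpoint name and route each
# loaded tensor to the right expert slice, roughly:
#   for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
#       if weight_name not in checkpoint_tensor_name: continue
#       param = params_dict[checkpoint_tensor_name.replace(weight_name, param_name)]
#       param.weight_loader(param, tensor, checkpoint_tensor_name,
#                           shard_id=shard_id, expert_id=expert_id)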
python/sglang/srt/models/glm4_moe.py  (+18 -59)

@@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm

@@ -51,9 +50,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,

@@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2Model,
     DeepseekV2MoE,
 )
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,

@@ -414,19 +411,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
-        )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )

         self.experts = get_moe_impl_class()(

@@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
         )

         self.shared_experts_is_int8 = False

@@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self.top_k = config.num_experts_per_tok

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (

@@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )

-        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

     def forward_normal_dual_stream(
         self,

@@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["router_logits"] = router_logits
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor

@@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["router_logits"] = router_logits
+        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here

@@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.dp_size = get_local_attention_dp_size()

 class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):

@@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {

@@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
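With the TRT-LLM special case removed, self.topk is always a TopK module here and the forward path always builds a topk_output before calling the experts. As a minimal sketch of that router, top-k, experts data flow, the toy function below stands in for the sglang TopK module; shapes and the renormalize step are illustrative assumptions:

import torch


def moe_topk_sketch(
    hidden_states: torch.Tensor,   # (num_tokens, hidden_size)
    gate_weight: torch.Tensor,     # (n_experts, hidden_size)
    top_k: int,
    renormalize: bool = True,
):
    # Router scores for every token/expert pair: (num_tokens, n_experts).
    router_logits = hidden_states @ gate_weight.t()

    # Top-k selection is roughly what TopK(...) encapsulates in sglang.
    topk_weights, topk_ids = torch.topk(
        torch.softmax(router_logits, dim=-1), k=top_k, dim=-1
    )
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # The experts module then consumes hidden_states plus this "topk_output".
    return topk_weights, topk_ids


if __name__ == "__main__":
    x = torch.randn(4, 16)
    w_gate = torch.randn(8, 16)
    weights, ids = moe_topk_sketch(x, w_gate, top_k=2)
    print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])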
python/sglang/srt/models/glm4v_moe.py  (+2 -11)

@@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

@@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         config.moe_layer_freq = 1
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.dp_size = get_local_attention_dp_size()
         self.quant_config = quant_config
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (

@@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
python/sglang/srt/models/gpt_oss.py  (+16 -34)

@@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm

@@ -50,9 +49,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention

@@ -110,16 +110,13 @@ class GptOssSparseMoeBlock(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
         self.activation = config.hidden_act
-        self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
-        self.swiglu_limit = config.swiglu_limit
+        self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
+        self.gemm1_clamp_limit = config.swiglu_limit
-        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
-            self.topk = None
-        else:
-            self.topk = TopK(
-                top_k=config.num_experts_per_tok,
-                renormalize=True,
-            )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=True,
+        )
         self.top_k = config.num_experts_per_tok
         experts_type = get_moe_impl_class()

@@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module):
             quant_config.get_name() if quant_config is not None else None
         )
         extra_kwargs = {
-            "enable_flashinfer_cutlass_moe": global_server_args_dict["enable_flashinfer_cutlass_moe"],
             # for moe gate_up_proj and down_proj and their bias loading
-            "use_weight_loader_fused": quant_config_name != "mxfp4",
+            "use_weight_loader_fused": quant_config_name != "mxfp4"
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts

@@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module):
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             activation=self.activation,
-            activation_alpha=self.activation_alpha,
-            swiglu_limit=self.swiglu_limit,
+            gemm1_alpha=self.gemm1_alpha,
+            gemm1_clamp_limit=self.gemm1_clamp_limit,
             with_bias=True,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
             **extra_kwargs,
         )

@@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")

@@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module):
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["topk_output"] = (self.top_k, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)

         if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

@@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module):
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True

@@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
             ckpt_gate_up_proj_name="gate_up_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
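The block parameters activation_alpha and swiglu_limit are renamed to gemm1_alpha and gemm1_clamp_limit and plumbed into the experts layer. As a rough illustration of what an alpha-scaled, clamped SwiGLU ("gemm1") activation can look like, here is a hedged sketch; the exact formula used by the GPT-OSS kernels is not shown in this diff, so treat the math below as an assumption rather than the reference implementation:

import torch


def clamped_swiglu_sketch(
    gate: torch.Tensor,
    up: torch.Tensor,
    alpha: float = 1.702,  # default mirrored from getattr(config, "hidden_act_alpha", 1.702)
    limit: float = 7.0,    # stand-in for config.swiglu_limit; the real value comes from the config
) -> torch.Tensor:
    # Clamp both halves, then apply an alpha-scaled sigmoid gate to the gated half
    # and combine with the (shifted) linear half.
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (gate * torch.sigmoid(alpha * gate)) * (up + 1.0)


if __name__ == "__main__":
    g, u = torch.randn(2, 8), torch.randn(2, 8)
    print(clamped_swiglu_sketch(g, u).shape)  # torch.Size([2, 8])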
python/sglang/srt/models/granitemoe.py  (+0 -1)

@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=f"{prefix}.experts",
         )
python/sglang/srt/models/grok.py  (+0 -1)

@@ -135,7 +135,6 @@ class Grok1MoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             activation="gelu",
             **kwargs,
         )
python/sglang/srt/models/interns1.py  (+2 -1)

@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import parallel_state
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,

@@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         ]
         expert_params_mapping = []
         if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
-            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
python/sglang/srt/models/internvl.py  (+2 -2)

@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
 from sglang.srt.distributed import parallel_state
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,

@@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
python/sglang/srt/models/llama4.py  (+0 -2)

@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm

@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
python/sglang/srt/models/minicpm3.py  (+0 -1)

@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda
python/sglang/srt/models/mixtral.py  (+0 -2)

@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers

@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
         )
python/sglang/srt/models/olmoe.py  (+0 -1)

@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
             intermediate_size=intermediate_size,
             reduce_results=True,
             quant_config=quant_config,
             tp_size=tp_size,
             layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
python/sglang/srt/models/qwen2_moe.py  (+3 -18)

@@ -17,8 +17,6 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 import logging
-from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch

@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (

@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm

@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )

         self.gate = ReplicatedLinear(

@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # Qwen2MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
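Note that the per-model "**(dict(enable_flashinfer_cutlass_moe=True) if ... else {})" plumbing disappears here; backend choice is intended to be resolved centrally through the new moe_runner_backend in server_args.py further below, rather than threaded through every constructor. A hedged sketch of that design shift follows; the resolver name and the auto-resolution rule are hypothetical, the diff only shows the removal of the per-call kwargs:

# Before (per-model plumbing): every MoE block forwards whatever backend flags it knows about.
#   FusedMoE(..., **({"enable_flashinfer_cutlass_moe": True} if flag else {}))
# After (centralized): the layer asks one source of truth for the runner backend.


def select_runner_backend_sketch(moe_runner_backend: str, quantization: str) -> str:
    """Toy resolver: collapse 'auto' into a concrete backend name."""
    if moe_runner_backend != "auto":
        return moe_runner_backend
    # Purely illustrative heuristic; the real selection logic lives in ServerArgs.
    if quantization == "modelopt_fp4":
        return "flashinfer_cutlass"
    return "triton"


assert select_runner_backend_sketch("auto", "fp8") == "triton"
assert select_runner_backend_sketch("flashinfer_trtllm", "fp8") == "flashinfer_trtllm"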
python/sglang/srt/models/qwen3_moe.py  (+9 -38)

@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import get_layer_id
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty

 Qwen3MoeConfig = None

@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )

         self.gate = ReplicatedLinear(

@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (

@@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)

@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # Qwen3MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True

@@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
python/sglang/srt/models/step3_vl.py  (+2 -1)

@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK

@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
python/sglang/srt/models/xverse_moe.py  (+11 -5)

@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope

@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
             ]
         )
         self.pack_params()
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)

         self.router = ReplicatedLinear(
             config.hidden_size,

@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
             quant_config=None,
             prefix=add_prefix("router", prefix),
         )
+        self.topk = TopK(
+            top_k=self.top_k,
+            renormalize=getattr(self.config, "norm_topk_prob", False),
+        )

         if config.num_shared_experts is not None:
             intermediate_size = config.intermediate_size * config.num_shared_experts

@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.w1,
             self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=getattr(self.config, "norm_topk_prob", False),
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )

         if self.config.num_shared_experts is not None:
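xverse now builds a TopK module and a MoeRunnerConfig(inplace=True) up front, then passes topk_output plus that config into fused_moe instead of raw router logits and loose keyword arguments. The sketch below mirrors that call contract with stub objects; the stub classes and shapes are illustrative, not the real sglang types:

# Sketch of the new call contract (stubs stand in for the real sglang objects).
from dataclasses import dataclass
from typing import NamedTuple

import torch


@dataclass
class MoeRunnerConfigSketch:
    inplace: bool = True


class TopKOutputSketch(NamedTuple):
    weights: torch.Tensor  # (num_tokens, top_k)
    ids: torch.Tensor      # (num_tokens, top_k)


def fused_moe_sketch(hidden_states, w1, w2, topk_output, runner_config):
    # The real kernel scatters tokens to experts, runs grouped GEMMs with w1/w2,
    # and gathers the weighted results, optionally writing into hidden_states
    # when runner_config.inplace is set. Here we only validate the interface.
    assert isinstance(runner_config, MoeRunnerConfigSketch)
    assert topk_output.ids.shape == topk_output.weights.shape
    return hidden_states


x = torch.randn(4, 16)
w1 = torch.randn(8, 32, 16)  # (n_experts, 2 * moe_intermediate, hidden)
w2 = torch.randn(8, 16, 16)  # (n_experts, hidden, moe_intermediate)
topk = TopKOutputSketch(torch.rand(4, 2), torch.randint(0, 8, (4, 2)))
out = fused_moe_sketch(x, w1, w2, topk, MoeRunnerConfigSketch(inplace=True))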
python/sglang/srt/server_args.py  (+66 -27)

@@ -37,6 +37,7 @@ from sglang.srt.utils import (
     is_hip,
     is_port_available,
     is_remote_url,
+    is_triton_kernels_available,
     is_valid_ipv6_address,
     nullable_str,
 )

@@ -175,9 +176,15 @@ class ServerArgs:
     # Expert parallelism
     ep_size: int = 1
-    moe_a2a_backend: Optional[Literal["deepep"]] = None
-    enable_flashinfer_cutlass_moe: bool = False
-    enable_flashinfer_trtllm_moe: bool = False
+    moe_a2a_backend: Literal["none", "deepep"] = "none"
+    moe_runner_backend: Literal[
+        "auto",
+        "triton",
+        "triton_kernel",
+        "flashinfer_trtllm",
+        "flashinfer_cutlass",
+        "flashinfer_mxfp4",
+    ] = "auto"
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
     ep_num_redundant_experts: int = 0

@@ -250,8 +257,6 @@ class ServerArgs:
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
-    enable_triton_kernel_moe: bool = False
-    enable_flashinfer_mxfp4_moe: bool = False
     scheduler_recv_interval: int = 1

     # Debug tensor dumps

@@ -282,6 +287,9 @@ class ServerArgs:
     # Deprecated arguments
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
+    enable_triton_kernel_moe: bool = False

     def __post_init__(self):
         # Check deprecated arguments

@@ -298,6 +306,21 @@ class ServerArgs:
             print_deprecated_warning(
                 "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
             )
+        if self.enable_triton_kernel_moe:
+            self.moe_runner_backend = "triton_kernel"
+            print_deprecated_warning(
+                "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
+            )
+        if self.enable_flashinfer_cutlass_moe:
+            self.moe_runner_backend = "flashinfer_cutlass"
+            print_deprecated_warning(
+                "NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
+            )
+        if self.enable_flashinfer_trtllm_moe:
+            self.moe_runner_backend = "flashinfer_trtllm"
+            print_deprecated_warning(
+                "NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
+            )

         # Set missing default values
         if self.tokenizer_path is None:

@@ -517,7 +540,7 @@ class ServerArgs:
         ), "Please enable dp attention when setting enable_dp_lm_head. "

         # MoE kernel
-        if self.enable_flashinfer_cutlass_moe:
+        if self.moe_runner_backend == "flashinfer_cutlass":
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"

@@ -527,7 +550,7 @@ class ServerArgs:
                 self.tp_size,
             ], "The expert parallel size must be 1 or the same as the tensor parallel size"

-        if self.enable_flashinfer_trtllm_moe:
+        if self.moe_runner_backend == "flashinfer_trtllm":
             if not self.disable_shared_experts_fusion:
                 self.disable_shared_experts_fusion = True
                 logger.warning(

@@ -556,7 +579,7 @@ class ServerArgs:
             self.ep_dispatch_algorithm = "static"

         if self.enable_eplb:
-            assert self.ep_size > 1 or self.moe_a2a_backend is not None
+            assert self.ep_size > 1

         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None

@@ -1446,19 +1469,22 @@ class ServerArgs:
         parser.add_argument(
             "--moe-a2a-backend",
             type=str,
-            choices=["deepep"],
+            choices=["none", "deepep"],
             default=ServerArgs.moe_a2a_backend,
             help="Choose the backend for MoE A2A.",
         )
         parser.add_argument(
-            "--enable-flashinfer-cutlass-moe",
-            action="store_true",
-            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
-        )
-        parser.add_argument(
-            "--enable-flashinfer-trtllm-moe",
-            action="store_true",
-            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
+            "--moe-runner-backend",
+            type=str,
+            choices=[
+                "auto",
+                "triton",
+                "triton_kernel",
+                "flashinfer_trtllm",
+                "flashinfer_cutlass",
+            ],
+            default=ServerArgs.moe_runner_backend,
+            help="Choose the runner backend for MoE.",
         )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",

@@ -1825,11 +1851,6 @@ class ServerArgs:
             action="store_true",
             help="Enable returning hidden states with responses.",
         )
-        parser.add_argument(
-            "--enable-triton-kernel-moe",
-            action="store_true",
-            help="Use triton moe grouped gemm kernel.",
-        )
         parser.add_argument(
             "--enable-flashinfer-mxfp4-moe",
             action="store_true",

@@ -1965,6 +1986,21 @@ class ServerArgs:
             action="store_true",
             help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-cutlass-moe",
+            action="store_true",
+            help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
+        )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
+        )
+        parser.add_argument(
+            "--enable-triton-kernel-moe",
+            action="store_true",
+            help="(Deprecated) Use triton moe grouped gemm kernel.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):

@@ -2143,18 +2179,21 @@ class ServerArgs:
             )
             if is_sm100_supported() and is_mxfp4_quant_format:
-                self.enable_flashinfer_mxfp4_moe = True
-                self.enable_triton_kernel_moe = False
+                self.moe_runner_backend = "flashinfer_mxfp4"
                 logger.warning(
                     "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                 )
             else:
-                if self.enable_triton_kernel_moe:
+                if self.moe_runner_backend == "triton_kernel":
                     assert (
                         self.ep_size == 1
                     ), "Triton kernel MoE is only supported when ep_size == 1"
-                if not self.enable_triton_kernel_moe and self.ep_size == 1:
-                    self.enable_triton_kernel_moe = True
+                if (
+                    self.moe_runner_backend == "auto"
+                    and self.ep_size == 1
+                    and is_triton_kernels_available()
+                ):
+                    self.moe_runner_backend = "triton_kernel"
                     logger.warning(
                         "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
                     )
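The old boolean switches become two enums (moe_a2a_backend and moe_runner_backend), and __post_init__ translates the deprecated flags onto them. A condensed, self-contained sketch of that mapping logic, mirroring the branches shown above; the _warn helper stands in for sglang's print_deprecated_warning:

from dataclasses import dataclass


def _warn(msg: str) -> None:  # stand-in for print_deprecated_warning
    print(msg)


@dataclass
class MoeArgsSketch:
    moe_a2a_backend: str = "none"     # {"none", "deepep"}
    moe_runner_backend: str = "auto"  # {"auto", "triton", "triton_kernel", ...}
    # Deprecated flags kept only so old command lines keep working.
    enable_deepep_moe: bool = False
    enable_triton_kernel_moe: bool = False
    enable_flashinfer_cutlass_moe: bool = False
    enable_flashinfer_trtllm_moe: bool = False

    def __post_init__(self):
        if self.enable_deepep_moe:
            self.moe_a2a_backend = "deepep"
            _warn("--enable-deepep-moe is deprecated; use --moe-a2a-backend deepep")
        if self.enable_triton_kernel_moe:
            self.moe_runner_backend = "triton_kernel"
            _warn("--enable-triton-kernel-moe is deprecated; use --moe-runner-backend triton_kernel")
        if self.enable_flashinfer_cutlass_moe:
            self.moe_runner_backend = "flashinfer_cutlass"
            _warn("--enable-flashinfer-cutlass-moe is deprecated; use --moe-runner-backend flashinfer_cutlass")
        if self.enable_flashinfer_trtllm_moe:
            self.moe_runner_backend = "flashinfer_trtllm"
            _warn("--enable-flashinfer-trtllm-moe is deprecated; use --moe-runner-backend flashinfer_trtllm")


args = MoeArgsSketch(enable_flashinfer_trtllm_moe=True)
assert args.moe_runner_backend == "flashinfer_trtllm"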
python/sglang/srt/two_batch_overlap.py  (+16 -11)

@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_tbo_token_distribution_threshold,
+    is_tbo_enabled,
+)
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (

@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
     vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
     left_sum = sum(extend_lens[:vanilla_split_seq_index])
     overall_sum = sum(extend_lens)
-    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
+    threshold = get_tbo_token_distribution_threshold()
     assert threshold <= 0.5, f"{threshold=}"
     return left_sum < overall_sum * threshold or left_sum > overall_sum * (1 - threshold)

@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
         self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)

     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
-        if not global_server_args_dict["enable_two_batch_overlap"]:
+        if not is_tbo_enabled():
             return
         token_num_per_seq = get_token_num_per_seq(
             forward_mode=batch.forward_mode, spec_info=batch.spec_info

@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
     def prepare_all_gather(
         self,
         local_batch: ScheduleBatch,
-        deepep_mode: DeepEPMode,
-        enable_deepep_moe: bool,
-        enable_two_batch_overlap: bool,
     ):
+        deepep_mode = get_deepep_mode()
+        enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        enable_two_batch_overlap = is_tbo_enabled()
+
         self.enable_two_batch_overlap = enable_two_batch_overlap

         if local_batch is not None:

@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
                     and not local_batch.forward_mode.is_target_verify()
                 )
                 and enable_deepep_moe
-                and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
+                and (resolved_deepep_mode.is_low_latency())
             )
         else:
             self.local_tbo_split_seq_index = 0

@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
             "req_to_token_pool",
             "token_to_kv_pool",
             "can_run_dp_cuda_graph",
+            "dp_padding_mode",
             "global_forward_mode",
             "spec_algorithm",
             "capture_hidden_mode",

@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
             tbo_children=None,
             global_num_tokens_gpu=None,
             global_num_tokens_cpu=None,
-            dp_padding_mode=None,
             global_dp_buffer_len=global_dp_buffer_len,
             global_num_tokens_for_logprob_gpu=None,
             global_num_tokens_for_logprob_cpu=None,

@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
 class MaybeTboDeepEPDispatcher:
     def __init__(self, **kwargs):
-        num_inner_dispatchers = (
-            2 if global_server_args_dict["enable_two_batch_overlap"] else 1
-        )
+        num_inner_dispatchers = 2 if is_tbo_enabled() else 1
         self._inners = [
             DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
         ]
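_is_two_chunk_split_enabled now reads its threshold from get_tbo_token_distribution_threshold() instead of the server-args dict; the decision itself is unchanged: the split counts as imbalanced when the left side holds less than threshold, or more than 1 - threshold, of all tokens. A standalone sketch of that check (the split-index computation from _split_array_by_balanced_sum is omitted here):

from typing import Sequence


def is_two_chunk_split_enabled_sketch(
    extend_lens: Sequence[int], split_index: int, threshold: float
) -> bool:
    # Mirrors the check in two_batch_overlap.py: two-chunk splitting kicks in
    # when the balanced split would still be lopsided beyond the threshold.
    assert threshold <= 0.5, f"{threshold=}"
    left_sum = sum(extend_lens[:split_index])
    overall_sum = sum(extend_lens)
    return left_sum < overall_sum * threshold or left_sum > overall_sum * (1 - threshold)


# With threshold 0.1, a 50/50 split is balanced, a 2/98 split is not.
assert not is_two_chunk_split_enabled_sketch([50, 50], 1, 0.1)
assert is_two_chunk_split_enabled_sketch([2, 98], 1, 0.1)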
python/sglang/srt/utils.py  (+2 -2)

@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args):
         return True
     elif not server_args.enable_dp_lm_head:
         return True
-    elif server_args.moe_a2a_backend is None:
+    elif server_args.moe_a2a_backend == "none":
         return True
     else:
         return (

@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args):
     Check if the input of attention is scattered.
     """
     assert server_args.moe_dense_tp_size in [1, None]
-    if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
+    if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
         if server_args.enable_dp_attention:
             return server_args.dp_size < server_args.tp_size
         else:
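Because moe_a2a_backend is now a plain string defaulting to "none" rather than an Optional, these helpers compare against the "none" sentinel. A tiny sketch of the changed predicate, using a stub args object:

from types import SimpleNamespace


def uses_moe_a2a_sketch(server_args) -> bool:
    # Old check: server_args.moe_a2a_backend is not None
    # New check: the string sentinel "none" means "no A2A backend configured".
    return server_args.moe_a2a_backend != "none"


assert not uses_moe_a2a_sketch(SimpleNamespace(moe_a2a_backend="none"))
assert uses_moe_a2a_sketch(SimpleNamespace(moe_a2a_backend="deepep"))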