sglang · commit f9bab3d5

qwen3moe support two batch overlap (#6598)

Authored May 26, 2025 by Yi Zhang; committed by GitHub May 25, 2025.
Parent commit: 16f69b1f

Showing 5 changed files with 351 additions and 28 deletions.
  python/sglang/srt/models/qwen2_moe.py       +17   -6
  python/sglang/srt/models/qwen3_moe.py       +200  -11
  python/sglang/srt/operations_strategy.py    +98   -7
  python/sglang/srt/two_batch_overlap.py      +8    -4
  test/srt/test_two_batch_overlap.py          +28   -0
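For context: two batch overlap (TBO) splits one forward batch into two sub-batches of roughly equal work and interleaves them, so that the expert-parallel (DeepEP) dispatch/combine communication of one sub-batch overlaps with the GPU compute of the other. This commit wires Qwen3-MoE into that machinery. The toy sketch below is illustrative only — the helper name and split rule are not from this commit (the real helper, compute_split_seq_index, lives in two_batch_overlap.py and is imported by the test file) — but it shows the kind of token-count split that keeps the two halves comparable in cost:

    # Toy illustration (not repo code): split a batch of sequences into two
    # sub-batches with similar total token counts.
    def toy_split_index(seq_lens):
        total = sum(seq_lens)
        best_i, best_gap, running = 0, total, 0
        for i, n in enumerate(seq_lens):
            running += n
            gap = abs(total - 2 * running)
            if gap < best_gap:
                best_i, best_gap = i + 1, gap
        return best_i

    seq_lens = [512, 64, 300, 200, 100]
    i = toy_split_index(seq_lens)
    print(seq_lens[:i], seq_lens[i:])  # -> [512, 64] [300, 200, 100]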
python/sglang/srt/models/qwen2_moe.py
@@ -68,6 +68,7 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import add_prefix, make_layers
 
 logger = logging.getLogger(__name__)
@@ -442,12 +443,22 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            with get_global_expert_distribution_recorder().with_current_layer(i):
-                layer = self.layers[i]
-                hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual
-                )
+        if forward_batch.can_run_tbo:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers,
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
+        else:
+            for i in range(self.start_layer, self.end_layer):
+                with get_global_expert_distribution_recorder().with_current_layer(i):
+                    layer = self.layers[i]
+                    hidden_states, residual = layer(
+                        positions, hidden_states, forward_batch, residual
+                    )
 
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
                 {
python/sglang/srt/models/qwen3_moe.py
@@ -68,6 +68,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -79,6 +82,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
@@ -137,7 +141,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
 
-            self.deepep_dispatcher = DeepEPDispatcher(
+            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
@@ -217,9 +221,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 masked_m,
                 expected_m,
             ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                 forward_mode=forward_mode,
             )
             final_hidden_states = self.experts(
@@ -235,13 +239,105 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )
 
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
             )
         return final_hidden_states
 
+    def op_gate(self, state):
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):
+            # router_logits: (num_tokens, n_experts)
+            state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_select_experts(self, state):
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+
+        if router_logits is not None:
+            state.topk_weights_local, state.topk_idx_local = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+
+    def op_dispatch_a(self, state):
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )
+
+    def op_combine_a(self, state):
+        if self.ep_size > 1:
+            self.deepep_dispatcher.combine_a(
+                hidden_states=state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_combine_b(self, state):
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_output(self, state):
+        state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")
+
 
 class Qwen3MoeAttention(nn.Module):
     def __init__(
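The op_* stages above never call each other directly; they communicate through a per-sub-batch state object that supports attribute assignment plus dict-style pop/get/update/clear, so each stage consumes exactly the tensors it needs and control can be yielded between stages. A minimal stand-in with those semantics (a sketch only, assuming this interface; the real container lives in sglang's operations machinery and may behave differently, e.g. in how clear() validates leftovers):

    # Minimal stand-in for the per-sub-batch state passed to the op_* stages above.
    # Illustrative only; not the container actually used by sglang.
    class OpState:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def update(self, values):
            self.__dict__.update(values)

        def get(self, key, default=None):
            return self.__dict__.get(key, default)

        def pop(self, key):
            return self.__dict__.pop(key)

        def clear(self, expect_keys):
            # Everything except the expected keys should already have been pop()'d.
            assert set(self.__dict__) == set(expect_keys), set(self.__dict__)
            self.__dict__.clear()

    state = OpState(hidden_states_mlp_input="h", forward_batch="fb")
    state.router_logits = None                      # attribute-style writes
    assert state.pop("router_logits") is None       # dict-style consumption
    assert state.get("tbo_subbatch_index") is None  # optional keys default to None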
@@ -339,20 +435,54 @@ class Qwen3MoeAttention(nn.Module):
         k = k_by_head.view(k.shape)
         return q, k
 
-    def forward(
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
+    def forward_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, forward_batch)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, forward_batch, inner_state = intermediate_state
+        if inner_state is None:
+            return hidden_states
+        attn_output = self.attn(*inner_state)
         output, _ = self.o_proj(attn_output)
         return output
 
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        return self.forward_core(s)
+
 
 class Qwen3MoeDecoderLayer(nn.Module):
     def __init__(
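The attention rewrite above splits forward into forward_prepare (QKV projection, QK norm, RoPE) and forward_core (the attention kernel plus o_proj), connected by an intermediate_state tuple; when a sub-batch is empty, inner_state is None and forward_core returns hidden_states untouched. That is what lets the decode strategy later in this commit put a yield point between op_prepare and op_core. A toy sketch of the contract (not repo code; the dummy tensors and the `* 2` stand in for the real kernels):

    # Toy illustration of the prepare/core contract: prepare stops right before the
    # heavy kernel and returns (passthrough, context, inner_state); core either
    # short-circuits on an empty sub-batch or finishes the computation.
    import torch

    def toy_forward_prepare(hidden_states):
        if hidden_states.shape[0] == 0:
            return hidden_states, None, None      # empty sub-batch: nothing to do
        inner_state = (hidden_states,)            # stand-in for (q, k, v, forward_batch)
        return None, None, inner_state

    def toy_forward_core(intermediate_state):
        passthrough, _, inner_state = intermediate_state
        if inner_state is None:
            return passthrough                    # skip the kernel entirely
        (h,) = inner_state
        return h * 2                              # stand-in for attn(...) + o_proj(...)

    print(toy_forward_core(toy_forward_prepare(torch.zeros(0, 4))).shape)  # torch.Size([0, 4])
    print(toy_forward_core(toy_forward_prepare(torch.ones(3, 4))).shape)   # torch.Size([3, 4])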
@@ -462,6 +592,65 @@ class Qwen3MoeDecoderLayer(nn.Module):
         return hidden_states, residual
 
+    def op_comm_prepare_attn(
+        self,
+        state,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        tbo_subbatch_index: Optional[int] = None,
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                tbo_subbatch_index=tbo_subbatch_index,
+            )
+        )
+
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
+
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
+        state.hidden_states_mlp_output = self.mlp(
+            hidden_states, state.forward_batch.forward_mode
+        )
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
+
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "tbo_subbatch_index",
+            }
+        )
+        return output
+
 
 class Qwen3MoeModel(Qwen2MoeModel):
     def __init__(
python/sglang/srt/operations_strategy.py
@@ -32,12 +32,27 @@ class OperationsStrategy:
         layers: torch.nn.ModuleList,
         forward_mode: ForwardMode,
     ) -> "OperationsStrategy":
-        return OperationsStrategy.concat(
-            [
-                _compute_layer_operations_strategy_tbo(layer, forward_mode)
-                for layer in layers
-            ]
-        )
+        layer_name = layers[0].__class__.__name__
+        if layer_name == "DeepseekV2DecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_deepseek_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        elif layer_name == "Qwen3MoeDecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_qwen3_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        else:
+            raise NotImplementedError
 
 
 def _assert_all_same(items: List):
@@ -45,8 +60,11 @@ def _assert_all_same(items: List):
     return items[0]
 
 
+# -------------------------------- Strategy for DeepSeek ---------------------------------------
+
+
 # TODO can refactor to make it more fancy if we have more complex strategies
-def _compute_layer_operations_strategy_tbo(
+def _compute_moe_deepseek_layer_operations_strategy_tbo(
     layer: torch.nn.Module,
     forward_mode: ForwardMode,
 ) -> OperationsStrategy:
@@ -114,3 +132,76 @@ def _compute_moe_deepseek_blog_decode(layer):
             operations.YieldOperation(),
         ],
     )
+
+
+# -------------------------------- Strategy for Qwen3 ---------------------------------------
+# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
+# convenience to adjust strategy
+def _compute_moe_qwen3_layer_operations_strategy_tbo(
+    layer: torch.nn.Module,
+    forward_mode: ForwardMode,
+) -> OperationsStrategy:
+    assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
+    if forward_mode == ForwardMode.EXTEND:
+        return _compute_moe_qwen3_prefill(layer)
+    elif forward_mode == ForwardMode.DECODE:
+        return _compute_moe_qwen3_decode(layer)
+    else:
+        raise NotImplementedError(f"Unsupported {forward_mode=}")
+
+
+def _compute_moe_qwen3_prefill(layer):
+    device_properties = torch.cuda.get_device_properties(device="cuda")
+    total_num_sms = device_properties.multi_processor_count
+    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+
+    return OperationsStrategy(
+        deep_gemm_num_sms=deep_gemm_num_sms,
+        tbo_delta_stages=0,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            layer.mlp.op_dispatch_a,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+        ],
+    )
+
+
+def _compute_moe_qwen3_decode(layer):
+    return OperationsStrategy(
+        deep_gemm_num_sms=None,
+        tbo_delta_stages=2,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            operations.YieldOperation(),
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_a,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+            operations.YieldOperation(),
+        ],
+    )
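The prefill and decode strategies above differ in where the YieldOperation markers sit and in tbo_delta_stages (0 for prefill, 2 for decode); the prefill strategy also reserves SMs for DeepEP by shrinking deep_gemm's SM budget. At runtime the executor advances one sub-batch's operation list until a yield point, then switches to the other, so dispatch/combine communication of one sub-batch can run under the compute of the other. The sketch below is illustrative only (the real executor lives in sglang's operations/two_batch_overlap code, and the exact semantics of tbo_delta_stages may differ from the simple stage-offset shown here):

    # Toy interleaving of two sub-batches over the same operation list, with the
    # second sub-batch held back by `delta_stages` yield points.
    def split_into_stages(operations):
        stages, current = [], []
        for op in operations:
            if op == "YIELD":
                stages.append(current)
                current = []
            else:
                current.append(op)
        stages.append(current)
        return stages

    decode_ops = ["prepare_attn", "attn_prepare", "YIELD", "attn_core", "prepare_mlp",
                  "gate", "select_experts", "YIELD", "dispatch_a", "YIELD",
                  "dispatch_b", "experts", "combine_a", "YIELD", "combine_b",
                  "output", "postprocess", "YIELD"]

    delta_stages = 2
    stages = split_into_stages(decode_ops)
    schedule = [stages, [[]] * delta_stages + stages]  # sub-batch 1 starts two stages later
    for tick in range(max(len(s) for s in schedule)):
        for sub, stage_list in enumerate(schedule):
            if tick < len(stage_list):
                for op in stage_list[tick]:
                    print(f"tick {tick}  sub-batch {sub}: {op}")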
python/sglang/srt/two_batch_overlap.py
@@ -356,14 +356,14 @@ def model_forward_maybe_tbo(
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
     residual: Optional[torch.Tensor],
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
 ):
     inputs = dict(
         positions=positions,
         hidden_states=hidden_states,
         forward_batch=forward_batch,
         residual=residual,
-        zero_allocator=zero_allocator,
+        **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
 
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
@@ -401,7 +401,7 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
 ) -> List[Dict]:
     return [
         dict(
@@ -412,7 +412,11 @@ def _model_forward_tbo_split_inputs(
                 output_forward_batch=output_forward_batch,
                 tbo_subbatch_index=tbo_subbatch_index,
             ),
-            zero_allocator=zero_allocator,
+            **(
+                dict(zero_allocator=zero_allocator)
+                if zero_allocator is not None
+                else {}
+            ),
         )
         for tbo_subbatch_index, output_forward_batch in enumerate(
             forward_batch.tbo_children
test/srt/test_two_batch_overlap.py
@@ -9,6 +9,7 @@ from sglang.srt.two_batch_overlap import compute_split_seq_index
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
+    DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
     DEFAULT_MLA_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
@@ -104,5 +105,32 @@ class TestTwoBatchOverlapUnitTest(unittest.TestCase):
         self.assertEqual(actual, expect)
 
 
+class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-1234"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--dp",
+                "2",
+                "--enable-dp-attention",
+                "--enable-deepep-moe",
+                "--deepep-mode",
+                "normal",
+                "--disable-cuda-graph",  # DeepEP normal does not support CUDA Graph
+                "--enable-two-batch-overlap",
+            ],
+            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
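TestQwen3TwoBatchOverlap only overrides setUpClass, so it reuses whatever test methods TestTwoBatchOverlap defines (the module imports run_eval for those), but launches the Qwen3-MoE thinking model with TP=2, DP=2, DP attention, DeepEP normal mode (hence CUDA graph disabled) and TBO enabled. One way to run just this class, assuming a machine with at least two GPUs and DeepEP installed (the dotted name below simply mirrors the file location shown above):

    # Run only the new Qwen3 TBO test class via unittest's loader.
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test.srt.test_two_batch_overlap.TestQwen3TwoBatchOverlap"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)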