change / sglang / Commits / f9bab3d5
"examples/mxnet/vscode:/vscode.git/clone" did not exist on "fdd0fe651dbc3e19fb3a3b1a2e4df81f74884ae2"
Unverified commit f9bab3d5, authored May 26, 2025 by Yi Zhang, committed by GitHub May 25, 2025.
qwen3moe support two batch overlap (#6598)
Parent: 16f69b1f
Showing 5 changed files with 351 additions and 28 deletions.
python/sglang/srt/models/qwen2_moe.py (+17, -6)
python/sglang/srt/models/qwen3_moe.py (+200, -11)
python/sglang/srt/operations_strategy.py (+98, -7)
python/sglang/srt/two_batch_overlap.py (+8, -4)
test/srt/test_two_batch_overlap.py (+28, -0)
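Two batch overlap (TBO) splits a forward batch into two sub-batches and interleaves their per-layer operations, so that while one sub-batch waits on expert-parallel (DeepEP) communication the other can run compute. The diffs below add the per-layer `op_*` stages and the Qwen3-specific operation schedules that enable this for Qwen3 MoE. As rough orientation, here is a minimal, hypothetical sketch of the interleaving idea; the names are illustrative and are not sglang APIs:

```python
# Hypothetical sketch of two-batch overlap: two sub-batches walk the same list
# of per-layer operations, and each generator yields at the points where it
# would otherwise block on communication, letting the other sub-batch compute.
from typing import Iterable, List

YIELD = object()  # stands in for operations.YieldOperation() in the real code


def run_sub_batch(ops: Iterable, state: dict):
    """Run one sub-batch's operations, yielding control at YIELD markers."""
    for op in ops:
        if op is YIELD:
            yield  # hand control to the other sub-batch
        else:
            op(state)


def interleave_two_sub_batches(ops: List, state_a: dict, state_b: dict) -> None:
    """Round-robin the two sub-batches until both have finished all ops."""
    runners = [run_sub_batch(ops, state_a), run_sub_batch(ops, state_b)]
    done = [False, False]
    while not all(done):
        for i, runner in enumerate(runners):
            if not done[i]:
                try:
                    next(runner)
                except StopIteration:
                    done[i] = True


# Example: "compute" ops run immediately; YIELD marks where a real schedule
# would be waiting on DeepEP dispatch/combine communication.
ops = [
    lambda s: s.setdefault("log", []).append("attn"),
    YIELD,
    lambda s: s["log"].append("experts"),
]
a, b = {}, {}
interleave_two_sub_batches(ops, a, b)
print(a["log"], b["log"])  # ['attn', 'experts'] ['attn', 'experts']
```

In the actual code, the yield points are expressed as `operations.YieldOperation()` entries inside an `OperationsStrategy` (see python/sglang/srt/operations_strategy.py below).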
python/sglang/srt/models/qwen2_moe.py
```diff
@@ -68,6 +68,7 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import add_prefix, make_layers
 
 logger = logging.getLogger(__name__)
@@ -442,12 +443,22 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            with get_global_expert_distribution_recorder().with_current_layer(i):
-                layer = self.layers[i]
-                hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual
-                )
+        if forward_batch.can_run_tbo:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers,
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
+        else:
+            for i in range(self.start_layer, self.end_layer):
+                with get_global_expert_distribution_recorder().with_current_layer(i):
+                    layer = self.layers[i]
+                    hidden_states, residual = layer(
+                        positions, hidden_states, forward_batch, residual
+                    )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
                 {
```
python/sglang/srt/models/qwen3_moe.py
```diff
@@ -68,6 +68,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -79,6 +82,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
```
```diff
@@ -137,7 +141,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.top_k = config.num_experts_per_tok
         self.renormalize = config.norm_topk_prob
 
-        self.deepep_dispatcher = DeepEPDispatcher(
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
             group=parallel_state.get_tp_group().device_group,
             router_topk=self.top_k,
             permute_fusion=True,
```
```diff
@@ -217,9 +221,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 masked_m,
                 expected_m,
             ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                 forward_mode=forward_mode,
             )
         final_hidden_states = self.experts(
```
```diff
@@ -235,13 +239,105 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
             )
         return final_hidden_states
 
+    def op_gate(self, state):
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):
+            # router_logits: (num_tokens, n_experts)
+            state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_select_experts(self, state):
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+
+        if router_logits is not None:
+            state.topk_weights_local, state.topk_idx_local = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+
+    def op_dispatch_a(self, state):
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )
+
+    def op_combine_a(self, state):
+        if self.ep_size > 1:
+            self.deepep_dispatcher.combine_a(
+                hidden_states=state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_combine_b(self, state):
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_output(self, state):
+        state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")
+
 
 class Qwen3MoeAttention(nn.Module):
     def __init__(
```
```diff
@@ -339,20 +435,54 @@ class Qwen3MoeAttention(nn.Module):
             k = k_by_head.view(k.shape)
         return q, k
 
-    def forward(
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
+    def forward_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, forward_batch)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, forward_batch, inner_state = intermediate_state
+        if inner_state is None:
+            return hidden_states
+
+        attn_output = self.attn(*inner_state)
         output, _ = self.o_proj(attn_output)
         return output
 
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        return self.forward_core(s)
+
 
 class Qwen3MoeDecoderLayer(nn.Module):
     def __init__(
```
```diff
@@ -462,6 +592,65 @@ class Qwen3MoeDecoderLayer(nn.Module):
         return hidden_states, residual
 
+    def op_comm_prepare_attn(
+        self,
+        state,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        tbo_subbatch_index: Optional[int] = None,
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                tbo_subbatch_index=tbo_subbatch_index,
+            )
+        )
+
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
+
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
+        state.hidden_states_mlp_output = self.mlp(
+            hidden_states, state.forward_batch.forward_mode
+        )
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
+
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "tbo_subbatch_index",
+            }
+        )
+        return output
+
 
 class Qwen3MoeModel(Qwen2MoeModel):
     def __init__(
```
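The new `op_*` methods split one decoder layer's forward pass into small stages that communicate through named fields on a shared per-sub-batch state object: each stage `pop`s the inputs it consumes and assigns the outputs it produces, and `op_comm_postprocess_layer` finally `clear`s the state while checking that only the expected keys remain. Below is a minimal sketch of what such a state container could look like; this is an assumption for illustration, not sglang's actual state class:

```python
# Hypothetical sketch of the per-sub-batch "state" container implied by the
# op_* methods above: attribute-style writes, dict-style pop/get, and a
# clear() that verifies only the expected keys are left over.
class OpState(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e

    def __setattr__(self, name, value):
        self[name] = value

    def clear(self, expect_keys=None):
        if expect_keys is not None:
            assert set(self.keys()) == set(expect_keys), set(self.keys())
        super().clear()


# Usage: each op reads its inputs and leaves its outputs on the state.
state = OpState()
state.hidden_states_mlp_input = "tensor-placeholder"
x = state.pop("hidden_states_mlp_input")  # consumed exactly once
state.update(dict(positions=0, forward_batch=None, tbo_subbatch_index=None))
state.clear(expect_keys={"positions", "forward_batch", "tbo_subbatch_index"})
```

Popping each tensor exactly once keeps only the live intermediates of the two in-flight sub-batches resident, and the final `clear(expect_keys=...)` catches stages that forgot to consume something.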
python/sglang/srt/operations_strategy.py
```diff
@@ -32,12 +32,27 @@ class OperationsStrategy:
         layers: torch.nn.ModuleList,
         forward_mode: ForwardMode,
     ) -> "OperationsStrategy":
-        return OperationsStrategy.concat(
-            [
-                _compute_layer_operations_strategy_tbo(layer, forward_mode)
-                for layer in layers
-            ]
-        )
+        layer_name = layers[0].__class__.__name__
+        if layer_name == "DeepseekV2DecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_deepseek_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        elif layer_name == "Qwen3MoeDecoderLayer":
+            return OperationsStrategy.concat(
+                [
+                    _compute_moe_qwen3_layer_operations_strategy_tbo(
+                        layer, forward_mode
+                    )
+                    for layer in layers
+                ]
+            )
+        else:
+            raise NotImplementedError
 
 
 def _assert_all_same(items: List):
```
```diff
@@ -45,8 +60,11 @@ def _assert_all_same(items: List):
     return items[0]
 
 
+# -------------------------------- Strategy for DeepSeek ---------------------------------------
 # TODO can refactor to make it more fancy if we have more complex strategies
-def _compute_layer_operations_strategy_tbo(
+def _compute_moe_deepseek_layer_operations_strategy_tbo(
     layer: torch.nn.Module,
     forward_mode: ForwardMode,
 ) -> OperationsStrategy:
```
```diff
@@ -114,3 +132,76 @@ def _compute_moe_deepseek_blog_decode(layer):
             operations.YieldOperation(),
         ],
     )
+
+
+# -------------------------------- Strategy for Qwen3 ---------------------------------------
+# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
+# convenience to adjust strategy
+def _compute_moe_qwen3_layer_operations_strategy_tbo(
+    layer: torch.nn.Module,
+    forward_mode: ForwardMode,
+) -> OperationsStrategy:
+    assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
+    if forward_mode == ForwardMode.EXTEND:
+        return _compute_moe_qwen3_prefill(layer)
+    elif forward_mode == ForwardMode.DECODE:
+        return _compute_moe_qwen3_decode(layer)
+    else:
+        raise NotImplementedError(f"Unsupported {forward_mode=}")
+
+
+def _compute_moe_qwen3_prefill(layer):
+    device_properties = torch.cuda.get_device_properties(device="cuda")
+    total_num_sms = device_properties.multi_processor_count
+    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+
+    return OperationsStrategy(
+        deep_gemm_num_sms=deep_gemm_num_sms,
+        tbo_delta_stages=0,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            layer.mlp.op_dispatch_a,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+        ],
+    )
+
+
+def _compute_moe_qwen3_decode(layer):
+    return OperationsStrategy(
+        deep_gemm_num_sms=None,
+        tbo_delta_stages=2,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            operations.YieldOperation(),
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_a,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+            operations.YieldOperation(),
+        ],
+    )
```
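In the prefill strategy above, `deep_gemm_num_sms` reserves the SMs that DeepEP keeps for its communication kernels and hands the rest to the grouped GEMMs, while the `YieldOperation()` entries sit right after `op_dispatch_a` and `op_combine_a`, i.e. right after asynchronous communication has been launched for one sub-batch. A tiny numeric illustration of that SM budgeting (the figures are made up):

```python
# Rough illustration of the SM budgeting used by the prefill strategy above:
# DeepEP keeps a fixed number of SMs for communication kernels, and DeepGEMM
# gets whatever is left. Both numbers below are hypothetical.
total_num_sms = 132      # e.g. what torch.cuda reports as multi_processor_count
deepep_num_sms = 20      # stand-in for DeepEPConfig.get_instance().num_sms
deep_gemm_num_sms = total_num_sms - deepep_num_sms
assert deep_gemm_num_sms > 0
print(deep_gemm_num_sms)  # 112 SMs left for the grouped GEMM kernels
```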
python/sglang/srt/two_batch_overlap.py
```diff
@@ -356,14 +356,14 @@ def model_forward_maybe_tbo(
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
     residual: Optional[torch.Tensor],
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
 ):
     inputs = dict(
         positions=positions,
         hidden_states=hidden_states,
         forward_batch=forward_batch,
         residual=residual,
-        zero_allocator=zero_allocator,
+        **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
@@ -401,7 +401,7 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: BumpAllocator,
+    zero_allocator: Optional[BumpAllocator] = None,
 ) -> List[Dict]:
     return [
         dict(
@@ -412,7 +412,11 @@ def _model_forward_tbo_split_inputs(
                 output_forward_batch=output_forward_batch,
                 tbo_subbatch_index=tbo_subbatch_index,
             ),
-            zero_allocator=zero_allocator,
+            **(
+                dict(zero_allocator=zero_allocator)
+                if zero_allocator is not None
+                else {}
+            ),
         )
         for tbo_subbatch_index, output_forward_batch in enumerate(
             forward_batch.tbo_children
```
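Making `zero_allocator` optional keeps the existing call sites that pass one working while letting the Qwen3 path omit it entirely; the `**(dict(...) if ... else {})` idiom forwards the keyword only when a value is present, so layers that do not accept the argument never receive it. A generic sketch of the idiom, with hypothetical function names:

```python
# Generic illustration of the optional-kwarg idiom used above: only forward
# the keyword when a value was actually provided, so callees that do not
# accept it are never handed an unexpected argument.
def call_layer(layer_fn, hidden_states, zero_allocator=None):
    kwargs = dict(
        hidden_states=hidden_states,
        **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
    )
    return layer_fn(**kwargs)


def qwen3_like_layer(hidden_states):                     # does not take zero_allocator
    return hidden_states


def deepseek_like_layer(hidden_states, zero_allocator):  # requires it
    return hidden_states


call_layer(qwen3_like_layer, "h")                        # kwarg omitted
call_layer(deepseek_like_layer, "h", zero_allocator=0)   # kwarg forwarded
```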
test/srt/test_two_batch_overlap.py
```diff
@@ -9,6 +9,7 @@ from sglang.srt.two_batch_overlap import compute_split_seq_index
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
+    DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
     DEFAULT_MLA_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
@@ -104,5 +105,32 @@ class TestTwoBatchOverlapUnitTest(unittest.TestCase):
         self.assertEqual(actual, expect)
 
 
+class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-1234"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--dp",
+                "2",
+                "--enable-dp-attention",
+                "--enable-deepep-moe",
+                "--deepep-mode",
+                "normal",
+                "--disable-cuda-graph",  # DeepEP normal does not support CUDA Graph
+                "--enable-two-batch-overlap",
+            ],
+            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
```
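`TestQwen3TwoBatchOverlap` overrides only `setUpClass`, so every test method inherited from `TestTwoBatchOverlap` reruns against a Qwen3 MoE server launched with TP 2, DP 2 with DP attention, DeepEP MoE in normal mode, CUDA graphs disabled, and two batch overlap enabled. A generic sketch of this unittest reuse pattern (all names below are illustrative):

```python
# Generic illustration of the test-reuse pattern: the subclass only swaps the
# fixture in setUpClass, and every test_* method defined on the parent class
# runs again against the new server configuration.
import unittest


class BaseServerTest(unittest.TestCase):
    model = "base-model"  # hypothetical placeholder

    @classmethod
    def setUpClass(cls):
        cls.server = f"server-for-{cls.model}"

    def test_eval(self):
        self.assertTrue(self.server.startswith("server-for-"))


class Qwen3ServerTest(BaseServerTest):
    @classmethod
    def setUpClass(cls):
        cls.model = "qwen3-moe"  # hypothetical placeholder
        super().setUpClass()


if __name__ == "__main__":
    unittest.main()
```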