Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8fc910db
Unverified
Commit
8fc910db
authored
Jul 05, 2025
by
Cheng Wan
Committed by
GitHub
Jul 05, 2025
Browse files
DP Attention with Auto DeepEP Dispatch (#7222)
parent
75354d9a
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
136 additions
and
90 deletions
+136
-90
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+1
-1
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+2
-2
python/sglang/srt/layers/moe/ep_moe/layer.py
python/sglang/srt/layers/moe/ep_moe/layer.py
+5
-3
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+15
-13
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-4
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+2
-0
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+7
-7
python/sglang/srt/models/qwen3_moe.py
python/sglang/srt/models/qwen3_moe.py
+7
-9
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+0
-4
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+7
-3
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+4
-4
test/srt/test_hybrid_dp_ep_tp_mtp.py
test/srt/test_hybrid_dp_ep_tp_mtp.py
+80
-40
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
8fc910db
...
@@ -772,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
...
@@ -772,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
self
.
last_batch_in_queue
=
last_batch_in_queue
self
.
last_batch_in_queue
=
last_batch_in_queue
def
_prepare_idle_batch_and_run
(
self
:
Scheduler
,
batch
,
delay_process
=
False
):
def
_prepare_idle_batch_and_run
(
self
:
Scheduler
,
batch
,
delay_process
=
False
):
batch
,
_
=
self
.
prepare_mlp_sync_batch
(
batch
)
batch
=
self
.
prepare_mlp_sync_batch
(
batch
)
result
=
None
result
=
None
if
batch
:
if
batch
:
result
=
self
.
run_batch
(
batch
)
result
=
self
.
run_batch
(
batch
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
8fc910db
...
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
batch
=
self
.
get_new_batch_prefill
()
batch
=
self
.
get_new_batch_prefill
()
if
require_mlp_sync
(
self
.
server_args
):
if
require_mlp_sync
(
self
.
server_args
):
batch
,
_
=
self
.
prepare_mlp_sync_batch
(
batch
)
batch
=
self
.
prepare_mlp_sync_batch
(
batch
)
self
.
cur_batch
=
batch
self
.
cur_batch
=
batch
if
batch
:
if
batch
:
...
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
batch
=
self
.
get_new_batch_prefill
()
batch
=
self
.
get_new_batch_prefill
()
if
require_mlp_sync
(
self
.
server_args
):
if
require_mlp_sync
(
self
.
server_args
):
batch
,
_
=
self
.
prepare_mlp_sync_batch
(
batch
)
batch
=
self
.
prepare_mlp_sync_batch
(
batch
)
self
.
cur_batch
=
batch
self
.
cur_batch
=
batch
if
batch
:
if
batch
:
result
=
self
.
run_batch
(
batch
)
result
=
self
.
run_batch
(
batch
)
...
...
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @
8fc910db
...
@@ -42,7 +42,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
...
@@ -42,7 +42,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
)
from
sglang.srt.layers.quantization.fp8_utils
import
normalize_e4m3fn_to_e4m3fnuz
from
sglang.srt.layers.quantization.fp8_utils
import
normalize_e4m3fn_to_e4m3fnuz
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
Forward
Mode
from
sglang.srt.model_executor.forward_batch_info
import
Forward
Batch
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
DeepEPMode
,
DeepEPMode
,
ceil_div
,
ceil_div
,
...
@@ -1178,12 +1178,14 @@ class DeepEPMoE(EPMoE):
...
@@ -1178,12 +1178,14 @@ class DeepEPMoE(EPMoE):
masked_m
:
torch
.
Tensor
,
masked_m
:
torch
.
Tensor
,
expected_m
:
int
,
expected_m
:
int
,
num_recv_tokens_per_expert
:
List
[
int
],
num_recv_tokens_per_expert
:
List
[
int
],
forward_
mode
:
Forward
Mode
,
forward_
batch
:
Forward
Batch
,
):
):
if
_use_aiter
:
if
_use_aiter
:
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return
self
.
forward_aiter
(
hidden_states
,
topk_idx
,
topk_weights
)
return
self
.
forward_aiter
(
hidden_states
,
topk_idx
,
topk_weights
)
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_batch
.
is_extend_in_batch
)
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
if
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
:
if
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
:
return
self
.
forward_deepgemm_contiguous
(
return
self
.
forward_deepgemm_contiguous
(
...
...
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
View file @
8fc910db
...
@@ -34,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
...
@@ -34,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_post_reorder_triton_kernel
,
deepep_post_reorder_triton_kernel
,
deepep_run_moe_deep_preprocess
,
deepep_run_moe_deep_preprocess
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
Forward
Mode
from
sglang.srt.model_executor.forward_batch_info
import
Forward
Batch
_use_aiter
=
get_bool_env_var
(
"SGLANG_USE_AITER"
)
and
is_hip
()
_use_aiter
=
get_bool_env_var
(
"SGLANG_USE_AITER"
)
and
is_hip
()
...
@@ -686,21 +686,21 @@ class DeepEPDispatcher:
...
@@ -686,21 +686,21 @@ class DeepEPDispatcher:
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
forward_
mode
:
Forward
Mode
=
None
,
forward_
batch
:
Forward
Batch
,
):
):
self
.
_update_stage
(
_Stage
.
INITIAL
,
_Stage
.
AFTER_DISPATCH_A
)
self
.
_update_stage
(
_Stage
.
INITIAL
,
_Stage
.
AFTER_DISPATCH_A
)
inner_state
=
self
.
_get_impl
(
forward_
mode
).
dispatch_a
(
inner_state
=
self
.
_get_impl
(
forward_
batch
).
dispatch_a
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
)
)
self
.
_dispatch_intermediate_state
=
forward_
mode
,
inner_state
self
.
_dispatch_intermediate_state
=
forward_
batch
,
inner_state
def
dispatch_b
(
self
):
def
dispatch_b
(
self
):
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_A
,
_Stage
.
AFTER_DISPATCH_B
)
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_A
,
_Stage
.
AFTER_DISPATCH_B
)
forward_
mode
,
inner_state
=
self
.
_dispatch_intermediate_state
forward_
batch
,
inner_state
=
self
.
_dispatch_intermediate_state
del
self
.
_dispatch_intermediate_state
del
self
.
_dispatch_intermediate_state
return
self
.
_get_impl
(
forward_
mode
).
dispatch_b
(
*
inner_state
)
return
self
.
_get_impl
(
forward_
batch
).
dispatch_b
(
*
inner_state
)
def
combine
(
self
,
*
args
,
**
kwargs
)
->
Tuple
:
def
combine
(
self
,
*
args
,
**
kwargs
)
->
Tuple
:
self
.
combine_a
(
*
args
,
**
kwargs
)
self
.
combine_a
(
*
args
,
**
kwargs
)
...
@@ -712,24 +712,26 @@ class DeepEPDispatcher:
...
@@ -712,24 +712,26 @@ class DeepEPDispatcher:
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
forward_
mode
:
Forward
Mode
,
forward_
batch
:
Forward
Batch
,
):
):
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_B
,
_Stage
.
AFTER_COMBINE_A
)
self
.
_update_stage
(
_Stage
.
AFTER_DISPATCH_B
,
_Stage
.
AFTER_COMBINE_A
)
inner_state
=
self
.
_get_impl
(
forward_
mode
).
combine_a
(
inner_state
=
self
.
_get_impl
(
forward_
batch
).
combine_a
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
)
)
self
.
_combine_intermediate_state
=
forward_
mode
,
inner_state
self
.
_combine_intermediate_state
=
forward_
batch
,
inner_state
def
combine_b
(
self
):
def
combine_b
(
self
):
self
.
_update_stage
(
_Stage
.
AFTER_COMBINE_A
,
_Stage
.
INITIAL
)
self
.
_update_stage
(
_Stage
.
AFTER_COMBINE_A
,
_Stage
.
INITIAL
)
forward_
mode
,
inner_state
=
self
.
_combine_intermediate_state
forward_
batch
,
inner_state
=
self
.
_combine_intermediate_state
del
self
.
_combine_intermediate_state
del
self
.
_combine_intermediate_state
return
self
.
_get_impl
(
forward_
mode
).
combine_b
(
*
inner_state
)
return
self
.
_get_impl
(
forward_
batch
).
combine_b
(
*
inner_state
)
def
_get_impl
(
self
,
forward_mode
:
ForwardMode
)
->
_DeepEPDispatcherImplBase
:
def
_get_impl
(
self
,
forward_batch
:
ForwardBatch
)
->
_DeepEPDispatcherImplBase
:
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_mode
)
resolved_deepep_mode
=
self
.
deepep_mode
.
resolve
(
forward_batch
.
is_extend_in_batch
)
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
if
resolved_deepep_mode
==
DeepEPMode
.
normal
:
return
self
.
_normal_dispatcher
return
self
.
_normal_dispatcher
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
elif
resolved_deepep_mode
==
DeepEPMode
.
low_latency
:
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
8fc910db
...
@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# For DP attention
# For DP attention
global_num_tokens
:
Optional
[
List
[
int
]]
=
None
global_num_tokens
:
Optional
[
List
[
int
]]
=
None
global_num_tokens_for_logprob
:
Optional
[
List
[
int
]]
=
None
global_num_tokens_for_logprob
:
Optional
[
List
[
int
]]
=
None
is_extend_in_batch
:
bool
=
False
can_run_dp_cuda_graph
:
bool
=
False
can_run_dp_cuda_graph
:
bool
=
False
is_extend_in_batch
:
bool
=
False
is_extend_in_batch
:
bool
=
False
tbo_split_seq_index
:
Optional
[
int
]
=
None
tbo_split_seq_index
:
Optional
[
int
]
=
None
...
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
token_ids_logprobs
=
self
.
token_ids_logprobs
,
token_ids_logprobs
=
self
.
token_ids_logprobs
,
global_num_tokens
=
self
.
global_num_tokens
,
global_num_tokens
=
self
.
global_num_tokens
,
global_num_tokens_for_logprob
=
self
.
global_num_tokens_for_logprob
,
global_num_tokens_for_logprob
=
self
.
global_num_tokens_for_logprob
,
is_extend_in_batch
=
self
.
is_extend_in_batch
,
can_run_dp_cuda_graph
=
self
.
can_run_dp_cuda_graph
,
can_run_dp_cuda_graph
=
self
.
can_run_dp_cuda_graph
,
tbo_split_seq_index
=
self
.
tbo_split_seq_index
,
tbo_split_seq_index
=
self
.
tbo_split_seq_index
,
global_forward_mode
=
self
.
global_forward_mode
,
global_forward_mode
=
self
.
global_forward_mode
,
...
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
...
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
# For DP attention
# For DP attention
global_num_tokens
:
Optional
[
List
[
int
]]
global_num_tokens
:
Optional
[
List
[
int
]]
global_num_tokens_for_logprob
:
Optional
[
List
[
int
]]
global_num_tokens_for_logprob
:
Optional
[
List
[
int
]]
is_extend_in_batch
:
bool
can_run_dp_cuda_graph
:
bool
can_run_dp_cuda_graph
:
bool
tbo_split_seq_index
:
Optional
[
int
]
tbo_split_seq_index
:
Optional
[
int
]
global_forward_mode
:
Optional
[
ForwardMode
]
global_forward_mode
:
Optional
[
ForwardMode
]
...
...
python/sglang/srt/managers/scheduler.py
View file @
8fc910db
...
@@ -1490,7 +1490,7 @@ class Scheduler(
...
@@ -1490,7 +1490,7 @@ class Scheduler(
if
need_dp_attn_preparation
and
not
self
.
spec_algorithm
.
is_none
():
if
need_dp_attn_preparation
and
not
self
.
spec_algorithm
.
is_none
():
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
new_batch
,
_
=
self
.
prepare_mlp_sync_batch
(
new_batch
)
new_batch
=
self
.
prepare_mlp_sync_batch
(
new_batch
)
need_dp_attn_preparation
=
new_batch
is
None
need_dp_attn_preparation
=
new_batch
is
None
if
new_batch
is
not
None
:
if
new_batch
is
not
None
:
...
@@ -1506,7 +1506,7 @@ class Scheduler(
...
@@ -1506,7 +1506,7 @@ class Scheduler(
# Handle DP attention
# Handle DP attention
if
need_dp_attn_preparation
:
if
need_dp_attn_preparation
:
ret
,
_
=
self
.
prepare_mlp_sync_batch
(
ret
)
ret
=
self
.
prepare_mlp_sync_batch
(
ret
)
return
ret
return
ret
...
@@ -1923,8 +1923,7 @@ class Scheduler(
...
@@ -1923,8 +1923,7 @@ class Scheduler(
if
not
disable_cuda_graph
:
if
not
disable_cuda_graph
:
local_batch
.
can_run_dp_cuda_graph
=
can_cuda_graph
local_batch
.
can_run_dp_cuda_graph
=
can_cuda_graph
# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
return
local_batch
return
local_batch
,
any
(
is_extend_in_batch
)
def
get_idle_batch
(
self
):
def
get_idle_batch
(
self
):
idle_batch
=
ScheduleBatch
.
init_new
(
idle_batch
=
ScheduleBatch
.
init_new
(
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
8fc910db
...
@@ -254,6 +254,7 @@ class ForwardBatch:
...
@@ -254,6 +254,7 @@ class ForwardBatch:
dp_local_start_pos
:
Optional
[
torch
.
Tensor
]
=
None
# cached info at runtime
dp_local_start_pos
:
Optional
[
torch
.
Tensor
]
=
None
# cached info at runtime
dp_local_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
# cached info at runtime
dp_local_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
# cached info at runtime
gathered_buffer
:
Optional
[
torch
.
Tensor
]
=
None
gathered_buffer
:
Optional
[
torch
.
Tensor
]
=
None
is_extend_in_batch
:
bool
=
False
can_run_dp_cuda_graph
:
bool
=
False
can_run_dp_cuda_graph
:
bool
=
False
global_forward_mode
:
Optional
[
ForwardMode
]
=
None
global_forward_mode
:
Optional
[
ForwardMode
]
=
None
...
@@ -299,6 +300,7 @@ class ForwardBatch:
...
@@ -299,6 +300,7 @@ class ForwardBatch:
return_logprob
=
batch
.
return_logprob
,
return_logprob
=
batch
.
return_logprob
,
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
token_ids_logprobs
=
batch
.
token_ids_logprobs
,
token_ids_logprobs
=
batch
.
token_ids_logprobs
,
is_extend_in_batch
=
batch
.
is_extend_in_batch
,
can_run_dp_cuda_graph
=
batch
.
can_run_dp_cuda_graph
,
can_run_dp_cuda_graph
=
batch
.
can_run_dp_cuda_graph
,
global_forward_mode
=
batch
.
global_forward_mode
,
global_forward_mode
=
batch
.
global_forward_mode
,
lora_paths
=
batch
.
lora_paths
,
lora_paths
=
batch
.
lora_paths
,
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
8fc910db
...
@@ -558,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -558,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
forward_
mode
=
forward_
mode
,
forward_
batch
=
forward_
batch
,
)
)
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -569,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
...
@@ -569,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
masked_m
=
masked_m
,
masked_m
=
masked_m
,
expected_m
=
expected_m
,
expected_m
=
expected_m
,
num_recv_tokens_per_expert
=
num_recv_tokens_per_expert
,
num_recv_tokens_per_expert
=
num_recv_tokens_per_expert
,
forward_
mode
=
forward_
mode
,
forward_
batch
=
forward_
batch
,
)
)
if
self
.
ep_size
>
1
:
if
self
.
ep_size
>
1
:
final_hidden_states
=
self
.
deepep_dispatcher
.
combine
(
final_hidden_states
=
self
.
deepep_dispatcher
.
combine
(
hidden_states
=
final_hidden_states
,
hidden_states
=
final_hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
forward_
mode
=
forward_
mode
,
forward_
batch
=
forward_
batch
,
)
)
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
...
@@ -651,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -651,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
hidden_states
=
state
.
hidden_states_mlp_input
,
hidden_states
=
state
.
hidden_states_mlp_input
,
topk_idx
=
state
.
pop
(
"topk_idx_local"
),
topk_idx
=
state
.
pop
(
"topk_idx_local"
),
topk_weights
=
state
.
pop
(
"topk_weights_local"
),
topk_weights
=
state
.
pop
(
"topk_weights_local"
),
forward_
mode
=
state
.
forward_batch
.
forward_mode
,
forward_
batch
=
state
.
forward_batch
,
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
)
)
...
@@ -683,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -683,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
masked_m
=
state
.
pop
(
"masked_m"
),
masked_m
=
state
.
pop
(
"masked_m"
),
expected_m
=
state
.
pop
(
"expected_m"
),
expected_m
=
state
.
pop
(
"expected_m"
),
num_recv_tokens_per_expert
=
state
.
pop
(
"num_recv_tokens_per_expert"
),
num_recv_tokens_per_expert
=
state
.
pop
(
"num_recv_tokens_per_expert"
),
forward_
mode
=
state
.
forward_batch
.
forward_mode
,
forward_
batch
=
state
.
forward_batch
,
)
)
def
op_combine_a
(
self
,
state
):
def
op_combine_a
(
self
,
state
):
...
@@ -692,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -692,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
hidden_states
=
state
.
pop
(
"hidden_states_experts_output"
),
hidden_states
=
state
.
pop
(
"hidden_states_experts_output"
),
topk_idx
=
state
.
pop
(
"topk_idx_dispatched"
),
topk_idx
=
state
.
pop
(
"topk_idx_dispatched"
),
topk_weights
=
state
.
pop
(
"topk_weights_dispatched"
),
topk_weights
=
state
.
pop
(
"topk_weights_dispatched"
),
forward_
mode
=
state
.
forward_batch
.
forward_mode
,
forward_
batch
=
state
.
forward_batch
,
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
)
)
...
@@ -1881,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1881,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
and
hidden_states
.
shape
[
0
]
==
0
and
hidden_states
.
shape
[
0
]
==
0
):
):
state
.
hidden_states_mlp_output
=
self
.
mlp
(
state
.
hidden_states_mlp_output
=
self
.
mlp
(
hidden_states
,
state
.
forward_batch
.
forward_mode
hidden_states
,
state
.
forward_batch
)
)
else
:
else
:
state
.
hidden_states_mlp_output
=
hidden_states
state
.
hidden_states_mlp_output
=
hidden_states
...
...
python/sglang/srt/models/qwen3_moe.py
View file @
8fc910db
...
@@ -229,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -229,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
forward_
mode
=
forward_
mode
,
forward_
batch
=
forward_
batch
,
)
)
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -240,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -240,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
masked_m
=
masked_m
,
masked_m
=
masked_m
,
expected_m
=
expected_m
,
expected_m
=
expected_m
,
num_recv_tokens_per_expert
=
num_recv_tokens_per_expert
,
num_recv_tokens_per_expert
=
num_recv_tokens_per_expert
,
forward_
mode
=
forward_
mode
,
forward_
batch
=
forward_
batch
,
)
)
if
self
.
ep_size
>
1
:
if
self
.
ep_size
>
1
:
final_hidden_states
=
self
.
deepep_dispatcher
.
combine
(
final_hidden_states
=
self
.
deepep_dispatcher
.
combine
(
hidden_states
=
final_hidden_states
,
hidden_states
=
final_hidden_states
,
topk_idx
=
topk_idx
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
forward_
mode
=
forward_
mode
,
forward_
batch
=
forward_
batch
,
)
)
return
final_hidden_states
return
final_hidden_states
...
@@ -293,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -293,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
hidden_states
=
state
.
pop
(
"hidden_states_mlp_input"
),
hidden_states
=
state
.
pop
(
"hidden_states_mlp_input"
),
topk_idx
=
state
.
pop
(
"topk_idx_local"
),
topk_idx
=
state
.
pop
(
"topk_idx_local"
),
topk_weights
=
state
.
pop
(
"topk_weights_local"
),
topk_weights
=
state
.
pop
(
"topk_weights_local"
),
forward_
mode
=
state
.
forward_batch
.
forward_mode
,
forward_
batch
=
state
.
forward_batch
,
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
)
)
...
@@ -325,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -325,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
masked_m
=
state
.
pop
(
"masked_m"
),
masked_m
=
state
.
pop
(
"masked_m"
),
expected_m
=
state
.
pop
(
"expected_m"
),
expected_m
=
state
.
pop
(
"expected_m"
),
num_recv_tokens_per_expert
=
state
.
pop
(
"num_recv_tokens_per_expert"
),
num_recv_tokens_per_expert
=
state
.
pop
(
"num_recv_tokens_per_expert"
),
forward_
mode
=
state
.
forward_batch
.
forward_mode
,
forward_
batch
=
state
.
forward_batch
,
)
)
def
op_combine_a
(
self
,
state
):
def
op_combine_a
(
self
,
state
):
...
@@ -334,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -334,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
hidden_states
=
state
.
pop
(
"hidden_states_experts_output"
),
hidden_states
=
state
.
pop
(
"hidden_states_experts_output"
),
topk_idx
=
state
.
pop
(
"topk_idx_dispatched"
),
topk_idx
=
state
.
pop
(
"topk_idx_dispatched"
),
topk_weights
=
state
.
pop
(
"topk_weights_dispatched"
),
topk_weights
=
state
.
pop
(
"topk_weights_dispatched"
),
forward_
mode
=
state
.
forward_batch
.
forward_mode
,
forward_
batch
=
state
.
forward_batch
,
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
tbo_subbatch_index
=
state
.
get
(
"tbo_subbatch_index"
),
)
)
...
@@ -647,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
...
@@ -647,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
def
op_mlp
(
self
,
state
):
def
op_mlp
(
self
,
state
):
hidden_states
=
state
.
pop
(
"hidden_states_mlp_input"
)
hidden_states
=
state
.
pop
(
"hidden_states_mlp_input"
)
state
.
hidden_states_mlp_output
=
self
.
mlp
(
state
.
hidden_states_mlp_output
=
self
.
mlp
(
hidden_states
,
state
.
forward_batch
)
hidden_states
,
state
.
forward_batch
.
forward_mode
)
def
op_comm_postprocess_layer
(
self
,
state
):
def
op_comm_postprocess_layer
(
self
,
state
):
hidden_states
,
residual
=
self
.
layer_communicator
.
postprocess_layer
(
hidden_states
,
residual
=
self
.
layer_communicator
.
postprocess_layer
(
...
...
python/sglang/srt/server_args.py
View file @
8fc910db
...
@@ -418,10 +418,6 @@ class ServerArgs:
...
@@ -418,10 +418,6 @@ class ServerArgs:
# DeepEP MoE
# DeepEP MoE
if
self
.
enable_deepep_moe
:
if
self
.
enable_deepep_moe
:
if
self
.
deepep_mode
==
"auto"
:
assert
(
not
self
.
enable_dp_attention
),
"DeepEP MoE `auto` mode is not supported with DP Attention."
if
self
.
deepep_mode
==
"normal"
:
if
self
.
deepep_mode
==
"normal"
:
logger
.
warning
(
"Cuda graph is disabled because deepep_mode=`normal`"
)
logger
.
warning
(
"Cuda graph is disabled because deepep_mode=`normal`"
)
self
.
disable_cuda_graph
=
True
self
.
disable_cuda_graph
=
True
...
...
python/sglang/srt/two_batch_overlap.py
View file @
8fc910db
...
@@ -13,7 +13,7 @@ from sglang.srt.layers.communicator import (
...
@@ -13,7 +13,7 @@ from sglang.srt.layers.communicator import (
)
)
from
sglang.srt.layers.moe.ep_moe.token_dispatcher
import
DeepEPDispatcher
from
sglang.srt.layers.moe.ep_moe.token_dispatcher
import
DeepEPDispatcher
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.operations_strategy
import
OperationsStrategy
...
@@ -272,7 +272,11 @@ class TboCudaGraphRunnerPlugin:
...
@@ -272,7 +272,11 @@ class TboCudaGraphRunnerPlugin:
class
TboDPAttentionPreparer
:
class
TboDPAttentionPreparer
:
def
prepare_all_gather
(
def
prepare_all_gather
(
self
,
local_batch
,
deepep_mode
,
enable_deepep_moe
,
enable_two_batch_overlap
self
,
local_batch
:
ScheduleBatch
,
deepep_mode
:
DeepEPMode
,
enable_deepep_moe
:
bool
,
enable_two_batch_overlap
:
bool
,
):
):
self
.
enable_two_batch_overlap
=
enable_two_batch_overlap
self
.
enable_two_batch_overlap
=
enable_two_batch_overlap
...
@@ -294,7 +298,7 @@ class TboDPAttentionPreparer:
...
@@ -294,7 +298,7 @@ class TboDPAttentionPreparer:
extend_lens
=
local_batch
.
extend_lens
,
extend_lens
=
local_batch
.
extend_lens
,
token_num_per_seq
=
token_num_per_seq
,
token_num_per_seq
=
token_num_per_seq
,
)
)
resolved_deepep_mode
=
deepep_mode
.
resolve
(
local_batch
.
forward_mode
)
resolved_deepep_mode
=
deepep_mode
.
resolve
(
local_batch
.
is_extend_in_batch
)
local_can_run_tbo
=
(
self
.
local_tbo_split_seq_index
is
not
None
)
and
not
(
local_can_run_tbo
=
(
self
.
local_tbo_split_seq_index
is
not
None
)
and
not
(
(
(
local_batch
.
forward_mode
.
is_extend
()
local_batch
.
forward_mode
.
is_extend
()
...
...
python/sglang/srt/utils.py
View file @
8fc910db
...
@@ -2202,14 +2202,14 @@ class DeepEPMode(Enum):
...
@@ -2202,14 +2202,14 @@ class DeepEPMode(Enum):
def
enable_low_latency
(
self
):
def
enable_low_latency
(
self
):
return
self
in
[
DeepEPMode
.
low_latency
,
DeepEPMode
.
auto
]
return
self
in
[
DeepEPMode
.
low_latency
,
DeepEPMode
.
auto
]
def
resolve
(
self
,
forward_mode
):
def
resolve
(
self
,
is_extend_in_batch
:
bool
):
if
self
!=
DeepEPMode
.
auto
:
if
self
!=
DeepEPMode
.
auto
:
return
self
return
self
if
forward_mode
.
is_decode
():
if
is_extend_in_batch
:
return
DeepEPMode
.
low_latency
else
:
return
DeepEPMode
.
normal
return
DeepEPMode
.
normal
else
:
return
DeepEPMode
.
low_latency
def
is_non_idle_and_non_empty
(
forward_mode
,
hidden_states
):
def
is_non_idle_and_non_empty
(
forward_mode
,
hidden_states
):
...
...
test/srt/test_hybrid_dp_ep_tp_mtp.py
View file @
8fc910db
...
@@ -539,8 +539,9 @@ class Test10(CustomTestCase):
...
@@ -539,8 +539,9 @@ class Test10(CustomTestCase):
"8"
,
"8"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -593,8 +594,9 @@ class Test11(CustomTestCase):
...
@@ -593,8 +594,9 @@ class Test11(CustomTestCase):
"4"
,
"4"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -647,8 +649,9 @@ class Test12(CustomTestCase):
...
@@ -647,8 +649,9 @@ class Test12(CustomTestCase):
"8"
,
"8"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -700,8 +703,9 @@ class Test13(CustomTestCase):
...
@@ -700,8 +703,9 @@ class Test13(CustomTestCase):
"1"
,
"1"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -756,8 +760,9 @@ class Test14(CustomTestCase):
...
@@ -756,8 +760,9 @@ class Test14(CustomTestCase):
"1"
,
"1"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -812,8 +817,9 @@ class Test15(CustomTestCase):
...
@@ -812,8 +817,9 @@ class Test15(CustomTestCase):
"1"
,
"1"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -867,8 +873,9 @@ class Test16(CustomTestCase):
...
@@ -867,8 +873,9 @@ class Test16(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -922,8 +929,9 @@ class Test17(CustomTestCase):
...
@@ -922,8 +929,9 @@ class Test17(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -979,8 +987,9 @@ class Test18(CustomTestCase):
...
@@ -979,8 +987,9 @@ class Test18(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -1036,8 +1045,9 @@ class Test19(CustomTestCase):
...
@@ -1036,8 +1045,9 @@ class Test19(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"128"
,
],
],
)
)
...
@@ -2213,8 +2223,11 @@ class Test40(CustomTestCase):
...
@@ -2213,8 +2223,11 @@ class Test40(CustomTestCase):
"8"
,
"8"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2277,8 +2290,11 @@ class Test41(CustomTestCase):
...
@@ -2277,8 +2290,11 @@ class Test41(CustomTestCase):
"4"
,
"4"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2341,8 +2357,11 @@ class Test42(CustomTestCase):
...
@@ -2341,8 +2357,11 @@ class Test42(CustomTestCase):
"8"
,
"8"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2404,8 +2423,11 @@ class Test43(CustomTestCase):
...
@@ -2404,8 +2423,11 @@ class Test43(CustomTestCase):
"1"
,
"1"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2470,8 +2492,11 @@ class Test44(CustomTestCase):
...
@@ -2470,8 +2492,11 @@ class Test44(CustomTestCase):
"1"
,
"1"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2536,8 +2561,11 @@ class Test45(CustomTestCase):
...
@@ -2536,8 +2561,11 @@ class Test45(CustomTestCase):
"1"
,
"1"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2601,8 +2629,11 @@ class Test46(CustomTestCase):
...
@@ -2601,8 +2629,11 @@ class Test46(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2666,8 +2697,11 @@ class Test47(CustomTestCase):
...
@@ -2666,8 +2697,11 @@ class Test47(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2733,8 +2767,11 @@ class Test48(CustomTestCase):
...
@@ -2733,8 +2767,11 @@ class Test48(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
@@ -2800,8 +2837,11 @@ class Test49(CustomTestCase):
...
@@ -2800,8 +2837,11 @@ class Test49(CustomTestCase):
"--enable-dp-lm-head"
,
"--enable-dp-lm-head"
,
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
"--deepep-mode"
,
"--deepep-mode"
,
"normal"
,
"auto"
,
"--disable-cuda-graph"
,
"--cuda-graph-max-bs"
,
"32"
,
"--max-running-requests"
,
"32"
,
"--speculative-algo"
,
"--speculative-algo"
,
"NEXTN"
,
"NEXTN"
,
"--speculative-draft"
,
"--speculative-draft"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment