Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
711efe78
Unverified
Commit
711efe78
authored
Apr 23, 2025
by
Cheng Wan
Committed by
GitHub
Apr 23, 2025
Browse files
Integrating PD disaggregation with DP attention and DeepEP (#5435)
Co-authored-by:
Byron Hsu
<
byronhsu1230@gmail.com
>
parent
fbb5f229
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
8 deletions
+72
-8
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+46
-5
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+16
-0
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/managers/data_parallel_controller.py
+10
-3
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
711efe78
...
...
@@ -444,6 +444,15 @@ class ScheduleBatchDisaggregationDecodeMixin:
class
SchedulerDisaggregationDecodeMixin
:
def
_prepare_idle_batch_and_run
(
self
,
batch
,
delay_process
=
False
):
batch
,
_
=
self
.
prepare_dp_attn_batch
(
batch
)
result
=
None
if
batch
:
result
=
self
.
run_batch
(
batch
)
if
not
delay_process
:
self
.
process_batch_result
(
batch
,
result
)
return
batch
,
result
@
torch
.
no_grad
()
def
event_loop_normal_disagg_decode
(
self
):
"""A normal scheduler loop for decode worker in disaggregation mode."""
...
...
@@ -456,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin:
batch
=
self
.
get_next_disagg_decode_batch_to_run
()
self
.
cur_batch
=
batch
prepare_dp_attn_flag
=
(
self
.
server_args
.
enable_dp_attention
or
self
.
server_args
.
enable_sp_layernorm
)
if
batch
:
# Generate fake extend output.
if
batch
.
forward_mode
.
is_extend
():
# Note: Logprobs should be handled on the prefill engine.
self
.
stream_output
(
batch
.
reqs
,
False
)
if
prepare_dp_attn_flag
:
self
.
_prepare_idle_batch_and_run
(
None
)
else
:
if
prepare_dp_attn_flag
:
self
.
prepare_dp_attn_batch
(
batch
)
result
=
self
.
run_batch
(
batch
)
self
.
process_batch_result
(
batch
,
result
)
elif
prepare_dp_attn_flag
:
batch
,
_
=
self
.
_prepare_idle_batch_and_run
(
None
)
if
batch
is
None
and
(
len
(
self
.
disagg_decode_transfer_queue
.
queue
)
...
...
@@ -480,7 +500,7 @@ class SchedulerDisaggregationDecodeMixin:
def
event_loop_overlap_disagg_decode
(
self
):
result_queue
=
deque
()
self
.
last_batch
:
Optional
[
ScheduleBatch
]
=
None
self
.
last_batch_i
s_extend
=
False
# last batch is modified in-place, so we need another variable to track if it's extend
self
.
last_batch_i
n_queue
=
False
# last batch is modified in-place, so we need another variable to track if it's still in the result queue
while
True
:
recv_reqs
=
self
.
recv_requests
()
...
...
@@ -489,20 +509,41 @@ class SchedulerDisaggregationDecodeMixin:
self
.
process_decode_queue
()
batch
=
self
.
get_next_disagg_decode_batch_to_run
()
self
.
cur_batch
=
batch
last_batch_is_extend
=
False
last_batch_in_queue
=
False
prepare_dp_attn_flag
=
(
self
.
server_args
.
enable_dp_attention
or
self
.
server_args
.
enable_sp_layernorm
)
if
batch
:
# Generate fake extend output.
if
batch
.
forward_mode
.
is_extend
():
# Note: Logprobs should be handled on the prefill engine.
self
.
stream_output
(
batch
.
reqs
,
False
)
last_batch_is_extend
=
True
if
prepare_dp_attn_flag
:
batch_
,
result
=
self
.
_prepare_idle_batch_and_run
(
None
,
delay_process
=
True
)
if
batch_
:
result_queue
.
append
((
batch_
.
copy
(),
result
))
last_batch_in_queue
=
True
else
:
if
prepare_dp_attn_flag
:
self
.
prepare_dp_attn_batch
(
batch
)
result
=
self
.
run_batch
(
batch
)
result_queue
.
append
((
batch
.
copy
(),
result
))
last_batch_in_queue
=
True
elif
prepare_dp_attn_flag
:
batch
,
result
=
self
.
_prepare_idle_batch_and_run
(
None
,
delay_process
=
True
)
if
batch
:
result_queue
.
append
((
batch
.
copy
(),
result
))
last_batch_in_queue
=
True
# Process the results of the previous batch but skip if the last batch is extend
if
self
.
last_batch
and
not
self
.
last_batch_i
s_extend
:
if
self
.
last_batch
and
self
.
last_batch_i
n_queue
:
tmp_batch
,
tmp_result
=
result_queue
.
popleft
()
self
.
process_batch_result
(
tmp_batch
,
tmp_result
)
...
...
@@ -516,7 +557,7 @@ class SchedulerDisaggregationDecodeMixin:
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
last_batch
=
batch
self
.
last_batch_i
s_extend
=
last_batch_i
s_extend
self
.
last_batch_i
n_queue
=
last_batch_i
n_queue
def
get_next_disagg_decode_batch_to_run
(
self
:
Scheduler
,
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
711efe78
...
...
@@ -187,6 +187,14 @@ class SchedulerDisaggregationPrefillMixin:
)
self
.
process_prefill_chunk
()
batch
=
self
.
get_new_batch_prefill
()
# Handle DP attention
if
(
self
.
server_args
.
enable_dp_attention
or
self
.
server_args
.
enable_sp_layernorm
):
batch
,
_
=
self
.
prepare_dp_attn_batch
(
batch
)
self
.
cur_batch
=
batch
if
batch
:
...
...
@@ -217,6 +225,14 @@ class SchedulerDisaggregationPrefillMixin:
)
self
.
process_prefill_chunk
()
batch
=
self
.
get_new_batch_prefill
()
# Handle DP attention
if
(
self
.
server_args
.
enable_dp_attention
or
self
.
server_args
.
enable_sp_layernorm
):
batch
,
_
=
self
.
prepare_dp_attn_batch
(
batch
)
self
.
cur_batch
=
batch
if
batch
:
...
...
python/sglang/srt/managers/data_parallel_controller.py
View file @
711efe78
...
...
@@ -23,11 +23,13 @@ import psutil
import
setproctitle
import
zmq
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.layers.dp_attention
import
compute_dp_attention_world_info
from
sglang.srt.managers.io_struct
import
(
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
)
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.torch_memory_saver_adapter
import
TorchMemorySaverAdapter
...
...
@@ -226,9 +228,14 @@ class DataParallelController:
self
.
max_total_num_tokens
=
scheduler_info
[
0
][
"max_total_num_tokens"
]
self
.
max_req_input_len
=
scheduler_info
[
0
][
"max_req_input_len"
]
def round_robin_scheduler(self, req: Req):
    """Dispatch a tokenized request to one of the DP worker schedulers.

    In normal serving (``disaggregation_mode == "null"``) the target worker
    rotates round-robin. In disaggregation mode the worker index is derived
    from the request's ``bootstrap_room`` so that the prefill and decode
    sides deterministically pick the same DP rank for a given request.
    """
    num_workers = len(self.workers)
    if self.server_args.disaggregation_mode == "null":
        idx = self.round_robin_counter
        self.workers[idx].send_pyobj(req)
        self.round_robin_counter = (idx + 1) % num_workers
    else:
        idx = req.bootstrap_room % num_workers
        self.workers[idx].send_pyobj(req)
def shortest_queue_scheduler(self, input_requests):
    """Dispatch requests to the worker with the shortest queue.

    Placeholder load-balancing policy; not implemented yet.

    Raises:
        NotImplementedError: always, until this policy is implemented.
    """
    raise NotImplementedError()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment