Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e65b9f21
Unverified
Commit
e65b9f21
authored
Apr 21, 2025
by
Byron Hsu
Committed by
GitHub
Apr 21, 2025
Browse files
[PD] Support decode overlap schedule (#5608)
parent
4dce1cc6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
5 deletions
+49
-5
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+43
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-4
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
e65b9f21
...
...
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from
__future__
import
annotations
import
logging
from
collections
import
deque
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
...
...
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
self
.
last_batch
=
batch
@torch.no_grad()
def event_loop_overlap_disagg_decode(self):
    """Event loop for the disaggregated decode server with overlap scheduling.

    Each iteration launches the current batch and defers processing of its
    result by one iteration: ``run_batch`` results are pushed onto
    ``result_queue`` and popped on the *next* loop pass, overlapping GPU
    execution of batch N with CPU result-processing of batch N-1.

    Extend batches are special-cased: they only stream fake extend output
    (logprobs are expected to be handled on the prefill engine) and never
    enter ``result_queue``, so the next iteration must skip the deferred
    ``popleft`` for them — tracked via ``last_batch_is_extend``.
    """
    # Queue of (batch_copy, result) pairs; each entry is consumed one loop
    # iteration after it was produced (the overlap).
    result_queue = deque()
    self.last_batch: Optional[ScheduleBatch] = None
    # last batch is modified in-place, so we need another variable to track
    # whether it was an extend batch
    self.last_batch_is_extend = False

    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        # polling and allocating kv cache
        self.process_decode_queue()
        batch = self.get_next_disagg_decode_batch_to_run()
        self.cur_batch = batch
        last_batch_is_extend = False

        if batch:
            # Generate fake extend output.
            if batch.forward_mode.is_extend():
                # Note: Logprobs should be handled on the prefill engine.
                self.stream_output(batch.reqs, False)
                last_batch_is_extend = True
            else:
                result = self.run_batch(batch)
                # Copy the batch so later in-place mutation of `batch`
                # does not corrupt the deferred result processing.
                result_queue.append((batch.copy(), result))

        # Process the results of the previous batch but skip if the last
        # batch is extend (extend batches never pushed onto result_queue)
        if self.last_batch and not self.last_batch_is_extend:
            tmp_batch, tmp_result = result_queue.popleft()
            self.process_batch_result(tmp_batch, tmp_result)

        if batch is None and (
            len(self.disagg_decode_transfer_queue.queue)
            + len(self.disagg_decode_prealloc_queue.queue)
            == 0
        ):
            # When the server is idle, do self-check and re-init some states
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = batch
        self.last_batch_is_extend = last_batch_is_extend
def
get_next_disagg_decode_batch_to_run
(
self
:
Scheduler
,
)
->
Optional
[
Tuple
[
ScheduleBatch
,
bool
]]:
...
...
python/sglang/srt/managers/scheduler.py
View file @
e65b9f21
...
...
@@ -2016,7 +2016,10 @@ def run_scheduler_process(
elif
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
scheduler
.
event_loop_normal_disagg_prefill
()
elif
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
scheduler
.
event_loop_normal_disagg_decode
()
if
scheduler
.
enable_overlap
:
scheduler
.
event_loop_overlap_disagg_decode
()
else
:
scheduler
.
event_loop_normal_disagg_decode
()
except
Exception
:
traceback
=
get_exception_traceback
()
...
...
python/sglang/srt/server_args.py
View file @
e65b9f21
...
...
@@ -387,14 +387,12 @@ class ServerArgs:
# PD disaggregation
if
self
.
disaggregation_mode
==
"prefill"
:
self
.
disable_cuda_graph
=
True
logger
.
warning
(
"
KV cache is forced as chunk cache for decode
server"
)
logger
.
warning
(
"
Cuda graph is disabled for prefill
server"
)
self
.
disable_overlap_schedule
=
True
logger
.
warning
(
"Overlap scheduler is disabled for prefill server"
)
elif
self
.
disaggregation_mode
==
"decode"
:
self
.
disable_radix_cache
=
True
logger
.
warning
(
"Cuda graph is disabled for prefill server"
)
self
.
disable_overlap_schedule
=
True
logger
.
warning
(
"Overlap scheduler is disabled for decode server"
)
logger
.
warning
(
"KV cache is forced as chunk cache for decode server"
)
os
.
environ
[
"SGLANG_ENABLE_TORCH_COMPILE"
]
=
(
"1"
if
self
.
enable_torch_compile
else
"0"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment