Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e65b9f21
Unverified
Commit
e65b9f21
authored
Apr 21, 2025
by
Byron Hsu
Committed by
GitHub
Apr 21, 2025
Browse files
[PD] Support decode overlap schedule (#5608)
parent
4dce1cc6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
5 deletions
+49
-5
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+43
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-4
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
e65b9f21
...
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
...
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
logging
from
collections
import
deque
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
...
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
...
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
self
.
last_batch
=
batch
self
.
last_batch
=
batch
@
torch
.
no_grad
()
def
event_loop_overlap_disagg_decode
(
self
):
result_queue
=
deque
()
self
.
last_batch
:
Optional
[
ScheduleBatch
]
=
None
self
.
last_batch_is_extend
=
False
# last batch is modifed in-place, so we need another variable to track if it's extend
while
True
:
recv_reqs
=
self
.
recv_requests
()
self
.
process_input_requests
(
recv_reqs
)
# polling and allocating kv cache
self
.
process_decode_queue
()
batch
=
self
.
get_next_disagg_decode_batch_to_run
()
self
.
cur_batch
=
batch
last_batch_is_extend
=
False
if
batch
:
# Generate fake extend output.
if
batch
.
forward_mode
.
is_extend
():
# Note: Logprobs should be handled on the prefill engine.
self
.
stream_output
(
batch
.
reqs
,
False
)
last_batch_is_extend
=
True
else
:
result
=
self
.
run_batch
(
batch
)
result_queue
.
append
((
batch
.
copy
(),
result
))
# Process the results of the previous batch but skip if the last batch is extend
if
self
.
last_batch
and
not
self
.
last_batch_is_extend
:
tmp_batch
,
tmp_result
=
result_queue
.
popleft
()
self
.
process_batch_result
(
tmp_batch
,
tmp_result
)
if
batch
is
None
and
(
len
(
self
.
disagg_decode_transfer_queue
.
queue
)
+
len
(
self
.
disagg_decode_prealloc_queue
.
queue
)
==
0
):
# When the server is idle, do self-check and re-init some states
self
.
check_memory
()
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
last_batch
=
batch
self
.
last_batch_is_extend
=
last_batch_is_extend
def
get_next_disagg_decode_batch_to_run
(
def
get_next_disagg_decode_batch_to_run
(
self
:
Scheduler
,
self
:
Scheduler
,
)
->
Optional
[
Tuple
[
ScheduleBatch
,
bool
]]:
)
->
Optional
[
Tuple
[
ScheduleBatch
,
bool
]]:
...
...
python/sglang/srt/managers/scheduler.py
View file @
e65b9f21
...
@@ -2016,7 +2016,10 @@ def run_scheduler_process(
...
@@ -2016,7 +2016,10 @@ def run_scheduler_process(
elif
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
elif
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
scheduler
.
event_loop_normal_disagg_prefill
()
scheduler
.
event_loop_normal_disagg_prefill
()
elif
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
elif
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
scheduler
.
event_loop_normal_disagg_decode
()
if
scheduler
.
enable_overlap
:
scheduler
.
event_loop_overlap_disagg_decode
()
else
:
scheduler
.
event_loop_normal_disagg_decode
()
except
Exception
:
except
Exception
:
traceback
=
get_exception_traceback
()
traceback
=
get_exception_traceback
()
...
...
python/sglang/srt/server_args.py
View file @
e65b9f21
...
@@ -387,14 +387,12 @@ class ServerArgs:
...
@@ -387,14 +387,12 @@ class ServerArgs:
# PD disaggregation
# PD disaggregation
if
self
.
disaggregation_mode
==
"prefill"
:
if
self
.
disaggregation_mode
==
"prefill"
:
self
.
disable_cuda_graph
=
True
self
.
disable_cuda_graph
=
True
logger
.
warning
(
"
KV cache is forced as chunk cache for decode
server"
)
logger
.
warning
(
"
Cuda graph is disabled for prefill
server"
)
self
.
disable_overlap_schedule
=
True
self
.
disable_overlap_schedule
=
True
logger
.
warning
(
"Overlap scheduler is disabled for prefill server"
)
logger
.
warning
(
"Overlap scheduler is disabled for prefill server"
)
elif
self
.
disaggregation_mode
==
"decode"
:
elif
self
.
disaggregation_mode
==
"decode"
:
self
.
disable_radix_cache
=
True
self
.
disable_radix_cache
=
True
logger
.
warning
(
"Cuda graph is disabled for prefill server"
)
logger
.
warning
(
"KV cache is forced as chunk cache for decode server"
)
self
.
disable_overlap_schedule
=
True
logger
.
warning
(
"Overlap scheduler is disabled for decode server"
)
os
.
environ
[
"SGLANG_ENABLE_TORCH_COMPILE"
]
=
(
os
.
environ
[
"SGLANG_ENABLE_TORCH_COMPILE"
]
=
(
"1"
if
self
.
enable_torch_compile
else
"0"
"1"
if
self
.
enable_torch_compile
else
"0"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment