Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6aca5834
Unverified
Commit
6aca5834
authored
Apr 15, 2025
by
Cheng Wan
Committed by
GitHub
Apr 15, 2025
Browse files
Fix several minor issues in PD disaggregation (#5444)
parent
5b5c7237
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
69 deletions
+67
-69
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+32
-0
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+31
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-68
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
6aca5834
...
...
@@ -419,6 +419,38 @@ class ScheduleBatchDisaggregationDecodeMixin:
class
SchedulerDisaggregationDecodeMixin
:
    @torch.no_grad()
    def event_loop_normal_disagg_decode(self):
        """A normal scheduler loop for decode worker in disaggregation mode."""
        while True:
            # Receive newly arrived requests and dispatch them to the scheduler.
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch

            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(batch.reqs, False)
                else:
                    result = self.run_batch(batch)
                    self.process_batch_result(batch, result)

            # Idle condition: no runnable batch AND both disaggregation queues
            # (KV transfer queue + prealloc queue) are empty.
            if batch is None and (
                len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
def
get_next_disagg_decode_batch_to_run
(
self
:
Scheduler
,
)
->
Optional
[
Tuple
[
ScheduleBatch
,
bool
]]:
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
6aca5834
...
...
@@ -171,6 +171,36 @@ class SchedulerDisaggregationPrefillMixin:
Mixin for Scheduler to handle disaggregation prefill
"""
    @torch.no_grad()
    def event_loop_normal_disagg_prefill(self):
        """A normal scheduler loop for prefill worker in disaggregation mode."""
        while True:
            # Receive newly arrived requests and dispatch them to the scheduler.
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # Requests that finished bootstrapping move into the waiting queue.
            self.waiting_queue.extend(
                self.disagg_prefill_pending_queue.pop_bootstrapped()
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result_disagg_prefill(batch, result)

            # Drive completion of prefill requests whose KV cache is still
            # being sent to the decode worker.
            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()

            # When the server is idle (no batch and nothing in flight),
            # do self-check and re-init some states.
            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
            # Otherwise, it hangs under high concurrency
            self.running_batch.batch_is_full = False
def
process_batch_result_disagg_prefill
(
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
)
->
None
:
...
...
@@ -210,7 +240,7 @@ class SchedulerDisaggregationPrefillMixin:
polls
=
poll_and_all_reduce
(
[
req
.
disagg_kv_sender
for
req
in
self
.
disagg_prefill_inflight_queue
],
self
.
tp_worker
.
get
_tp_cpu_group
()
,
self
.
attn
_tp_cpu_group
,
)
undone_reqs
:
List
[
Req
]
=
[]
...
...
python/sglang/srt/managers/scheduler.py
View file @
6aca5834
...
...
@@ -484,7 +484,7 @@ class Scheduler(
self
.
tree_cache
=
HiRadixCache
(
req_to_token_pool
=
self
.
req_to_token_pool
,
token_to_kv_pool_allocator
=
self
.
token_to_kv_pool_allocator
,
tp_cache_group
=
self
.
tp_
worker
.
get_tp_
cpu_group
()
,
tp_cache_group
=
self
.
tp_cpu_group
,
page_size
=
self
.
page_size
,
hicache_ratio
=
server_args
.
hicache_ratio
,
)
...
...
@@ -553,7 +553,7 @@ class Scheduler(
# The decode requests polling kv cache
self
.
disagg_decode_transfer_queue
=
DecodeTransferQueue
(
gloo_group
=
self
.
tp_worker
.
get_attentio
n_tp_cpu_group
()
,
gloo_group
=
self
.
att
n_tp_cpu_group
,
req_to_metadata_buffer_idx_allocator
=
req_to_metadata_buffer_idx_allocator
,
metadata_buffers
=
metadata_buffers
,
)
...
...
@@ -568,7 +568,7 @@ class Scheduler(
scheduler
=
self
,
transfer_queue
=
self
.
disagg_decode_transfer_queue
,
tree_cache
=
self
.
tree_cache
,
gloo_group
=
self
.
tp_worker
.
get_attentio
n_tp_cpu_group
()
,
gloo_group
=
self
.
att
n_tp_cpu_group
,
tp_rank
=
self
.
tp_rank
,
tp_size
=
self
.
tp_size
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
...
...
@@ -597,7 +597,7 @@ class Scheduler(
tp_rank
=
self
.
tp_rank
,
tp_size
=
self
.
tp_size
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
gloo_group
=
self
.
tp_worker
.
get_attentio
n_tp_cpu_group
()
,
gloo_group
=
self
.
att
n_tp_cpu_group
,
transfer_backend
=
self
.
transfer_backend
,
scheduler
=
self
,
)
...
...
@@ -664,70 +664,6 @@ class Scheduler(
self
.
last_batch
=
batch
@
torch
.
no_grad
()
def
event_loop_normal_disagg_prefill
(
self
):
"""A normal scheduler loop for prefill worker in disaggregation mode."""
while
True
:
recv_reqs
=
self
.
recv_requests
()
self
.
process_input_requests
(
recv_reqs
)
self
.
waiting_queue
.
extend
(
self
.
disagg_prefill_pending_queue
.
pop_bootstrapped
()
)
self
.
process_prefill_chunk
()
batch
=
self
.
get_new_batch_prefill
()
self
.
cur_batch
=
batch
if
batch
:
result
=
self
.
run_batch
(
batch
)
self
.
process_batch_result_disagg_prefill
(
batch
,
result
)
if
len
(
self
.
disagg_prefill_inflight_queue
)
>
0
:
self
.
process_disagg_prefill_inflight_queue
()
if
batch
is
None
and
len
(
self
.
disagg_prefill_inflight_queue
)
==
0
:
self
.
check_memory
()
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
last_batch
=
batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
# Otherwise, it hangs under high concurrency
self
.
running_batch
.
batch_is_full
=
False
@
torch
.
no_grad
()
def
event_loop_normal_disagg_decode
(
self
):
"""A normal scheduler loop for decode worker in disaggregation mode."""
while
True
:
recv_reqs
=
self
.
recv_requests
()
self
.
process_input_requests
(
recv_reqs
)
# polling and allocating kv cache
self
.
process_decode_queue
()
batch
=
self
.
get_next_disagg_decode_batch_to_run
()
self
.
cur_batch
=
batch
if
batch
:
# Generate fake extend output.
if
batch
.
forward_mode
.
is_extend
():
# Note: Logprobs should be handled on the prefill engine.
self
.
stream_output
(
batch
.
reqs
,
[
False
for
_
in
range
(
len
(
batch
.
reqs
))]
)
else
:
result
=
self
.
run_batch
(
batch
)
self
.
process_batch_result
(
batch
,
result
)
if
batch
is
None
and
(
len
(
self
.
disagg_decode_transfer_queue
.
queue
)
+
len
(
self
.
disagg_decode_prealloc_queue
.
queue
)
==
0
):
# When the server is idle, do self-check and re-init some states
self
.
check_memory
()
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
last_batch
=
batch
def
recv_requests
(
self
)
->
List
[
Req
]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
if
self
.
attn_tp_rank
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment