Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
6aca5834
"examples/instruct_pix2pix/train_instruct_pix2pix.py" did not exist on "f20c8f5a1aba27f5972cad50516f18ba516e4d9e"
Unverified
Commit
6aca5834
authored
Apr 15, 2025
by
Cheng Wan
Committed by
GitHub
Apr 15, 2025
Browse files
Fix several minor issues in PD disaggregation (#5444)
parent
5b5c7237
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
69 deletions
+67
-69
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+32
-0
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+31
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-68
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
6aca5834
...
@@ -419,6 +419,38 @@ class ScheduleBatchDisaggregationDecodeMixin:
...
@@ -419,6 +419,38 @@ class ScheduleBatchDisaggregationDecodeMixin:
class
SchedulerDisaggregationDecodeMixin
:
class
SchedulerDisaggregationDecodeMixin
:
@torch.no_grad()
def event_loop_normal_disagg_decode(self):
    """A normal scheduler loop for decode worker in disaggregation mode."""
    while True:
        incoming = self.recv_requests()
        self.process_input_requests(incoming)

        # Poll the disaggregation queues and allocate KV cache for ready requests.
        self.process_decode_queue()

        batch = self.get_next_disagg_decode_batch_to_run()
        self.cur_batch = batch

        if batch:
            if batch.forward_mode.is_extend():
                # Generate fake extend output; logprobs are handled on the
                # prefill engine, not here.
                self.stream_output(batch.reqs, False)
            else:
                self.process_batch_result(batch, self.run_batch(batch))

        # Idle when there is no batch and both disaggregation queues are empty.
        queues_drained = (
            not self.disagg_decode_transfer_queue.queue
            and not self.disagg_decode_prealloc_queue.queue
        )
        if batch is None and queues_drained:
            # When the server is idle, do self-check and re-init some states.
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = batch
def
get_next_disagg_decode_batch_to_run
(
def
get_next_disagg_decode_batch_to_run
(
self
:
Scheduler
,
self
:
Scheduler
,
)
->
Optional
[
Tuple
[
ScheduleBatch
,
bool
]]:
)
->
Optional
[
Tuple
[
ScheduleBatch
,
bool
]]:
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
6aca5834
...
@@ -171,6 +171,36 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -171,6 +171,36 @@ class SchedulerDisaggregationPrefillMixin:
Mixin for Scheduler to handle disaggregation prefill
Mixin for Scheduler to handle disaggregation prefill
"""
"""
@torch.no_grad()
def event_loop_normal_disagg_prefill(self):
    """A normal scheduler loop for prefill worker in disaggregation mode."""
    while True:
        self.process_input_requests(self.recv_requests())

        # Requests that finished bootstrapping become schedulable.
        bootstrapped = self.disagg_prefill_pending_queue.pop_bootstrapped()
        self.waiting_queue.extend(bootstrapped)

        self.process_prefill_chunk()
        batch = self.get_new_batch_prefill()
        self.cur_batch = batch

        if batch:
            outcome = self.run_batch(batch)
            self.process_batch_result_disagg_prefill(batch, outcome)

        # Advance requests whose prefill is done but whose KV transfer
        # to the decode worker is still in flight.
        if len(self.disagg_prefill_inflight_queue) > 0:
            self.process_disagg_prefill_inflight_queue()

        no_inflight = len(self.disagg_prefill_inflight_queue) == 0
        if batch is None and no_inflight:
            # Idle: self-check and reset the adaptive token ratio.
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = batch
        # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
        # Otherwise, it hangs under high concurrency
        self.running_batch.batch_is_full = False
def
process_batch_result_disagg_prefill
(
def
process_batch_result_disagg_prefill
(
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
)
->
None
:
)
->
None
:
...
@@ -210,7 +240,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -210,7 +240,7 @@ class SchedulerDisaggregationPrefillMixin:
polls
=
poll_and_all_reduce
(
polls
=
poll_and_all_reduce
(
[
req
.
disagg_kv_sender
for
req
in
self
.
disagg_prefill_inflight_queue
],
[
req
.
disagg_kv_sender
for
req
in
self
.
disagg_prefill_inflight_queue
],
self
.
tp_worker
.
get
_tp_cpu_group
()
,
self
.
attn
_tp_cpu_group
,
)
)
undone_reqs
:
List
[
Req
]
=
[]
undone_reqs
:
List
[
Req
]
=
[]
...
...
python/sglang/srt/managers/scheduler.py
View file @
6aca5834
...
@@ -484,7 +484,7 @@ class Scheduler(
...
@@ -484,7 +484,7 @@ class Scheduler(
self
.
tree_cache
=
HiRadixCache
(
self
.
tree_cache
=
HiRadixCache
(
req_to_token_pool
=
self
.
req_to_token_pool
,
req_to_token_pool
=
self
.
req_to_token_pool
,
token_to_kv_pool_allocator
=
self
.
token_to_kv_pool_allocator
,
token_to_kv_pool_allocator
=
self
.
token_to_kv_pool_allocator
,
tp_cache_group
=
self
.
tp_
worker
.
get_tp_
cpu_group
()
,
tp_cache_group
=
self
.
tp_cpu_group
,
page_size
=
self
.
page_size
,
page_size
=
self
.
page_size
,
hicache_ratio
=
server_args
.
hicache_ratio
,
hicache_ratio
=
server_args
.
hicache_ratio
,
)
)
...
@@ -553,7 +553,7 @@ class Scheduler(
...
@@ -553,7 +553,7 @@ class Scheduler(
# The decode requests polling kv cache
# The decode requests polling kv cache
self
.
disagg_decode_transfer_queue
=
DecodeTransferQueue
(
self
.
disagg_decode_transfer_queue
=
DecodeTransferQueue
(
gloo_group
=
self
.
tp_worker
.
get_attentio
n_tp_cpu_group
()
,
gloo_group
=
self
.
att
n_tp_cpu_group
,
req_to_metadata_buffer_idx_allocator
=
req_to_metadata_buffer_idx_allocator
,
req_to_metadata_buffer_idx_allocator
=
req_to_metadata_buffer_idx_allocator
,
metadata_buffers
=
metadata_buffers
,
metadata_buffers
=
metadata_buffers
,
)
)
...
@@ -568,7 +568,7 @@ class Scheduler(
...
@@ -568,7 +568,7 @@ class Scheduler(
scheduler
=
self
,
scheduler
=
self
,
transfer_queue
=
self
.
disagg_decode_transfer_queue
,
transfer_queue
=
self
.
disagg_decode_transfer_queue
,
tree_cache
=
self
.
tree_cache
,
tree_cache
=
self
.
tree_cache
,
gloo_group
=
self
.
tp_worker
.
get_attentio
n_tp_cpu_group
()
,
gloo_group
=
self
.
att
n_tp_cpu_group
,
tp_rank
=
self
.
tp_rank
,
tp_rank
=
self
.
tp_rank
,
tp_size
=
self
.
tp_size
,
tp_size
=
self
.
tp_size
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
...
@@ -597,7 +597,7 @@ class Scheduler(
...
@@ -597,7 +597,7 @@ class Scheduler(
tp_rank
=
self
.
tp_rank
,
tp_rank
=
self
.
tp_rank
,
tp_size
=
self
.
tp_size
,
tp_size
=
self
.
tp_size
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
gloo_group
=
self
.
tp_worker
.
get_attentio
n_tp_cpu_group
()
,
gloo_group
=
self
.
att
n_tp_cpu_group
,
transfer_backend
=
self
.
transfer_backend
,
transfer_backend
=
self
.
transfer_backend
,
scheduler
=
self
,
scheduler
=
self
,
)
)
...
@@ -664,70 +664,6 @@ class Scheduler(
...
@@ -664,70 +664,6 @@ class Scheduler(
self
.
last_batch
=
batch
self
.
last_batch
=
batch
@torch.no_grad()
def event_loop_normal_disagg_prefill(self):
    """A normal scheduler loop for prefill worker in disaggregation mode."""
    while True:
        # Drain pending inputs and apply them to scheduler state.
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        # Requests popped here have presumably completed the KV-transfer
        # bootstrap handshake and can now be scheduled — TODO confirm
        # pop_bootstrapped()'s contract against its definition.
        self.waiting_queue.extend(
            self.disagg_prefill_pending_queue.pop_bootstrapped()
        )
        self.process_prefill_chunk()
        batch = self.get_new_batch_prefill()
        self.cur_batch = batch
        if batch:
            result = self.run_batch(batch)
            # Disaggregation-specific result handling (not the generic
            # process_batch_result used by the decode loop).
            self.process_batch_result_disagg_prefill(batch, result)
        # Requests whose prefill finished wait in the inflight queue until
        # their KV cache has been shipped to the decode worker.
        if len(self.disagg_prefill_inflight_queue) > 0:
            self.process_disagg_prefill_inflight_queue()
        if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
            # Idle server: run self-checks and reset the adaptive token ratio.
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
        self.last_batch = batch
        # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
        # Otherwise, it hangs under high concurrency
        self.running_batch.batch_is_full = False
@torch.no_grad()
def event_loop_normal_disagg_decode(self):
    """A normal scheduler loop for decode worker in disaggregation mode.

    Each iteration: receive and apply new requests, poll the disaggregation
    queues to allocate KV cache, then either run the next decode batch or
    stream a fake extend output. Loops forever; never returns.
    """
    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        # polling and allocating kv cache
        self.process_decode_queue()
        batch = self.get_next_disagg_decode_batch_to_run()
        self.cur_batch = batch
        if batch:
            # Generate fake extend output.
            if batch.forward_mode.is_extend():
                # Note: Logprobs should be handled on the prefill engine.
                # [False] * n replaces the equivalent but noisier
                # [False for _ in range(len(batch.reqs))].
                self.stream_output(batch.reqs, [False] * len(batch.reqs))
            else:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
        if batch is None and (
            len(self.disagg_decode_transfer_queue.queue)
            + len(self.disagg_decode_prealloc_queue.queue)
            == 0
        ):
            # When the server is idle, do self-check and re-init some states
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
        self.last_batch = batch
def
recv_requests
(
self
)
->
List
[
Req
]:
def
recv_requests
(
self
)
->
List
[
Req
]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
if
self
.
attn_tp_rank
==
0
:
if
self
.
attn_tp_rank
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment