sglang · Commit bfadb5ea

Adjust overlap event loop (#11507)

Unverified commit, authored Oct 14, 2025 by Liangsheng Yin; committed via GitHub on Oct 14, 2025.
Parent: 9cc1e065

Showing 5 changed files with 44 additions and 45 deletions (+44, -45):
python/sglang/srt/disaggregation/decode.py    +9  -7
python/sglang/srt/disaggregation/prefill.py   +6  -4
python/sglang/srt/managers/schedule_batch.py  +0  -3
python/sglang/srt/managers/scheduler.py       +16 -28
python/sglang/srt/managers/tp_worker.py       +13 -3
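In short (as read from the diffs below): the commit reworks how the overlap scheduler defers sampling. The `delay_sample_launch` flag on `ModelWorkerBatch`/`GenerationBatchResult` is replaced by a `delay_sample_func` closure that the TP worker attaches itself when overlap scheduling is enabled and grammar-constrained sampling is active, and the deferred sample is now launched at the end of the same event-loop iteration (`launch_batch_sample_if_needed(batch_result)`) instead of at the top of the next one (`launch_last_batch_sample_if_needed()`).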
python/sglang/srt/disaggregation/decode.py

@@ -752,7 +752,6 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

         while True:
-            self.launch_last_batch_sample_if_needed()
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

@@ -764,6 +763,7 @@ class SchedulerDisaggregationDecodeMixin:
             prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

+            batch_result = None
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():

@@ -772,25 +772,25 @@ class SchedulerDisaggregationDecodeMixin:
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
                     if prepare_mlp_sync_flag:
-                        batch_, result = self._prepare_idle_batch_and_run(
+                        batch_, batch_result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
                         if batch_:
-                            self.result_queue.append((batch_.copy(), result))
+                            self.result_queue.append((batch_.copy(), batch_result))
                             last_batch_in_queue = True
                 else:
                     if prepare_mlp_sync_flag:
                         self.prepare_mlp_sync_batch(batch)
-                    result = self.run_batch(batch)
-                    self.result_queue.append((batch.copy(), result))
+                    batch_result = self.run_batch(batch)
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True
             elif prepare_mlp_sync_flag:
-                batch, result = self._prepare_idle_batch_and_run(
+                batch, batch_result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
                 if batch:
-                    self.result_queue.append((batch.copy(), result))
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend

@@ -798,6 +798,8 @@ class SchedulerDisaggregationDecodeMixin:
                 tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)

+            self.launch_batch_sample_if_needed(batch_result)
+
             queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
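For orientation, here is a minimal sketch of one iteration of the overlap event loop after this change. Everything below is an illustrative stand-in for the mixin above, not sglang's actual API; the point is the ordering of the steps.

# A minimal sketch of one overlap event-loop iteration after this commit.
from collections import deque


class SchedulerSketch:
    def __init__(self):
        self.result_queue = deque()
        self.last_batch = None

    # --- stand-in stubs for the real scheduler methods ---
    def run_batch(self, batch):
        return f"result-of-{batch}"

    def process_batch_result(self, batch, result):
        print("processed", batch, result)

    def launch_batch_sample_if_needed(self, batch_result):
        if batch_result is not None:
            print("deferred sample launched for", batch_result)

    # --- the loop body, mirroring the diff ---
    def step(self, batch):
        batch_result = None
        if batch:
            # 1) Launch the forward pass for the current batch.
            batch_result = self.run_batch(batch)
            self.result_queue.append((batch, batch_result))

        if self.last_batch:
            # 2) Overlap: process the *previous* batch's results while the
            #    current forward pass is still in flight.
            prev_batch, prev_result = self.result_queue.popleft()
            self.process_batch_result(prev_batch, prev_result)

        # 3) New placement: launch the current batch's deferred sample (if
        #    any) here, rather than at the top of the next iteration as
        #    launch_last_batch_sample_if_needed() used to.
        self.launch_batch_sample_if_needed(batch_result)
        self.last_batch = batch


sched = SchedulerSketch()
for b in ["batch0", "batch1", None]:
    sched.step(b)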
python/sglang/srt/disaggregation/prefill.py

@@ -321,8 +321,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.result_queue = deque()

         while True:
-            self.launch_last_batch_sample_if_needed()
-
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(

@@ -334,9 +332,11 @@ class SchedulerDisaggregationPrefillMixin:
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch
+
+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))

             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()

@@ -345,6 +345,8 @@ class SchedulerDisaggregationPrefillMixin:
             if len(self.disagg_prefill_inflight_queue) > 0:
                 self.process_disagg_prefill_inflight_queue()

+            self.launch_batch_sample_if_needed(batch_result)
+
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                 self.self_check_during_idle()
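The prefill loop gets the same restructuring as the decode loop above: `batch_result` is reset to `None` each iteration, captured from `run_batch`, and handed to `launch_batch_sample_if_needed` once the inflight queue has been processed, so a deferred sample no longer waits for the next iteration to be launched.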
python/sglang/srt/managers/schedule_batch.py

@@ -1907,8 +1907,5 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1

-    # Overlap scheduler related
-    delay_sample_launch: bool = False
-
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
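This field can go away because the worker now decides for itself whether to defer sampling (see the `self.enable_overlap` / grammars check added in `tp_worker.py` below); the scheduler no longer needs to pass a `delay_sample_launch` hint through `ModelWorkerBatch`.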
python/sglang/srt/managers/scheduler.py

@@ -148,7 +148,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.eagle_info import EagleDraftInput

@@ -212,8 +212,7 @@ class GenerationBatchResult:
     # For overlap scheduling
     copy_done: Optional[torch.cuda.Event] = None
-    delay_sample_launch: bool = False
-    forward_batch: Optional[ForwardBatch] = None
+    delay_sample_func: Optional[callable] = None
     future_indices: Optional[FutureIndices] = None

     # FIXME(lsyin): maybe move to <BetterPlace> ?

@@ -1036,17 +1035,16 @@ class Scheduler(
         self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
-            self.launch_last_batch_sample_if_needed()
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))

             if self.last_batch:
                 # Process the results of the last batch

@@ -1056,6 +1054,7 @@ class Scheduler(
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()

+            self.launch_batch_sample_if_needed(batch_result)
             self.last_batch = batch

     @DynamicGradMode()

@@ -2207,8 +2206,6 @@ class Scheduler(
         with self.forward_stream_ctx:
             self.forward_stream.wait_stream(self.default_stream)
             self.future_map.resolve_future(model_worker_batch)
-            if batch.sampling_info.grammars is not None:
-                model_worker_batch.delay_sample_launch = True
             batch_result = self.model_worker.forward_batch_generation(
                 model_worker_batch
             )

@@ -2216,7 +2213,7 @@ class Scheduler(
            batch_result.copy_done = torch.get_device_module(
                self.device
            ).Event()
-           if not model_worker_batch.delay_sample_launch:
+           if batch_result.delay_sample_func is None:
                self.future_map.store_to_map(future_indices, batch_result)
                batch_result.copy_to_cpu()
            else:

@@ -2280,29 +2277,20 @@ class Scheduler(
         ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret

-    def launch_last_batch_sample_if_needed(
+    def launch_batch_sample_if_needed(
         self,
+        batch_result: GenerationBatchResult,
     ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
-        if len(self.result_queue) == 0:
-            return
-
-        tmp_batch, tmp_result = self.result_queue.popleft()
-        tmp_result: GenerationBatchResult
-        if not tmp_result.delay_sample_launch:
-            self.result_queue.appendleft((tmp_batch, tmp_result))
+        # TODO(lsyin): make the delayed sample a default behavior after
+        # unifying the forward_batch_generation interface (related to spec V2).
+        if batch_result is None or batch_result.delay_sample_func is None:
             return

         with self.forward_stream_ctx:
             self.forward_stream.wait_stream(self.default_stream)
-            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
-                tmp_result.logits_output,
-                tmp_result.forward_batch,
-            )
-            future_indices = tmp_result.future_indices
-            self.future_map.store_to_map(future_indices, tmp_result)
-            tmp_result.copy_to_cpu()
-            self.result_queue.appendleft((tmp_batch, tmp_result))
+            _batch_result = batch_result.delay_sample_func()
+            assert _batch_result is batch_result
+            self.future_map.store_to_map(batch_result.future_indices, batch_result)
+            batch_result.copy_to_cpu()

     def process_batch_result(
         self,
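Two things fall out of the rewritten method: `launch_batch_sample_if_needed` now acts on the just-launched `batch_result` instead of popping and re-queueing the head of `result_queue`, and since the deferred work is a closure that already captures its logits and forward batch, `GenerationBatchResult` can drop its `forward_batch` field (hence the trimmed `ForwardBatch` import at the top of the file).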
python/sglang/srt/managers/tp_worker.py

@@ -168,6 +168,7 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.hicache_layer_transfer_counter = None

     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):

@@ -266,9 +267,18 @@ class TpModelWorker:
             # Skip sampling and return logits for target forward
             return batch_result

-        if model_worker_batch.delay_sample_launch:
-            batch_result.delay_sample_launch = True
-            batch_result.forward_batch = forward_batch
+        if (
+            self.enable_overlap
+            and model_worker_batch.sampling_info.grammars is not None
+        ):
+
+            def sample_batch_func():
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+                return batch_result
+
+            batch_result.delay_sample_func = sample_batch_func
             return batch_result

         if model_worker_batch.is_prefill_only:
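The worker-side change is the classic deferred-computation-as-closure pattern: rather than flagging the result and stashing its inputs on it, the worker attaches a zero-argument function that captures `logits_output` and `forward_batch` and returns the same result object (which the scheduler's `assert _batch_result is batch_result` verifies). A self-contained sketch of the pattern, with hypothetical stand-in names (`Result`, `forward_pass`, `argmax_sample` are illustrative, not sglang APIs):

# Deferred-sampling-as-closure, reduced to its essentials.
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Result:
    logits: Optional[List[float]] = None
    next_token_ids: Optional[List[int]] = None
    delay_sample_func: Optional[Callable[[], "Result"]] = None


def argmax_sample(logits: List[float]) -> List[int]:
    return [max(range(len(logits)), key=logits.__getitem__)]


def forward_pass(defer_sampling: bool) -> Result:
    result = Result(logits=[0.1, 0.7, 0.2])  # stand-in forward output
    if defer_sampling:
        def sample_batch_func():
            # The closure captures `result` (and, in sglang, the forward
            # batch), so nothing extra must be stored on the result object.
            result.next_token_ids = argmax_sample(result.logits)
            return result
        result.delay_sample_func = sample_batch_func
    else:
        result.next_token_ids = argmax_sample(result.logits)
    return result


# Scheduler side: run the deferred sample later, if one was attached.
res = forward_pass(defer_sampling=True)
if res.delay_sample_func is not None:
    assert res.delay_sample_func() is res
print(res.next_token_ids)  # -> [1]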