sglang · Commits · bfadb5ea

Adjust overlap event loop (#11507)

Unverified commit bfadb5ea, authored Oct 14, 2025 by Liangsheng Yin; committed by GitHub on Oct 14, 2025.
Parent: 9cc1e065

Showing 5 changed files with 44 additions and 45 deletions (+44 −45).
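Summary (as read from the diffs below): the overlap schedulers no longer call `launch_last_batch_sample_if_needed()` at the top of every event-loop iteration and peek at the result queue. Instead, each loop keeps the current iteration's result in a `batch_result` variable (the old local `result` is renamed) and calls the new `launch_batch_sample_if_needed(batch_result)` after the previous batch's results are processed. The per-batch `delay_sample_launch` flag on `ModelWorkerBatch`, together with the `delay_sample_launch`/`forward_batch` fields on `GenerationBatchResult`, is replaced by a single `delay_sample_func` closure, which `TpModelWorker.forward_batch_generation` installs when overlap scheduling is enabled and the batch carries grammar-constrained sampling.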
python/sglang/srt/disaggregation/decode.py    +9  −7
python/sglang/srt/disaggregation/prefill.py   +6  −4
python/sglang/srt/managers/schedule_batch.py  +0  −3
python/sglang/srt/managers/scheduler.py       +16 −28
python/sglang/srt/managers/tp_worker.py       +13 −3
python/sglang/srt/disaggregation/decode.py (view file @ bfadb5ea)

```diff
@@ -752,7 +752,6 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

         while True:
-            self.launch_last_batch_sample_if_needed()
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -764,6 +763,7 @@ class SchedulerDisaggregationDecodeMixin:
             prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

+            batch_result = None
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
@@ -772,25 +772,25 @@ class SchedulerDisaggregationDecodeMixin:
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
                     if prepare_mlp_sync_flag:
-                        batch_, result = self._prepare_idle_batch_and_run(
+                        batch_, batch_result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
                         if batch_:
-                            self.result_queue.append((batch_.copy(), result))
+                            self.result_queue.append((batch_.copy(), batch_result))
                             last_batch_in_queue = True
                 else:
                     if prepare_mlp_sync_flag:
                         self.prepare_mlp_sync_batch(batch)
-                    result = self.run_batch(batch)
-                    self.result_queue.append((batch.copy(), result))
+                    batch_result = self.run_batch(batch)
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True
             elif prepare_mlp_sync_flag:
-                batch, result = self._prepare_idle_batch_and_run(
+                batch, batch_result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
                 if batch:
-                    self.result_queue.append((batch.copy(), result))
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend
@@ -798,6 +798,8 @@ class SchedulerDisaggregationDecodeMixin:
                 tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)

+            self.launch_batch_sample_if_needed(batch_result)
+
             queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
```
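All three event loops touched by this commit — disaggregated decode (above), disaggregated prefill and the regular scheduler (below) — converge on the same iteration shape. A simplified sketch of that shape, using the names from the diffs but eliding disaggregation bookkeeping, MLP-sync handling, and the idle-batch paths (an illustration, not the verbatim loop):

```python
# One overlap-scheduler iteration after this commit (simplified sketch).
while True:
    recv_reqs = self.recv_requests()            # ingest new work
    self.process_input_requests(recv_reqs)

    batch = self.get_next_batch_to_run()
    self.cur_batch = batch

    batch_result = None                         # this iteration's forward result
    if batch:
        batch_result = self.run_batch(batch)    # may carry a delay_sample_func
        self.result_queue.append((batch.copy(), batch_result))

    if self.last_batch:
        tmp_batch, tmp_result = self.result_queue.popleft()
        self.process_batch_result(tmp_batch, tmp_result)

    # New placement: launch the (possibly deferred) sampling for the batch
    # submitted above, only after the previous batch's results are consumed.
    self.launch_batch_sample_if_needed(batch_result)

    self.last_batch = batch
```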
python/sglang/srt/disaggregation/prefill.py (view file @ bfadb5ea)

```diff
@@ -321,8 +321,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.result_queue = deque()

         while True:
-            self.launch_last_batch_sample_if_needed()
-
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
@@ -334,9 +332,11 @@ class SchedulerDisaggregationPrefillMixin:
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch

+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))

             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
@@ -345,6 +345,8 @@ class SchedulerDisaggregationPrefillMixin:
             if len(self.disagg_prefill_inflight_queue) > 0:
                 self.process_disagg_prefill_inflight_queue()

+            self.launch_batch_sample_if_needed(batch_result)
+
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                 self.self_check_during_idle()
```
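Note the placement in the prefill loop: `launch_batch_sample_if_needed(batch_result)` runs after the inflight KV-transfer queue is drained but before the idle self-check, mirroring the decode loop above.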
python/sglang/srt/managers/schedule_batch.py (view file @ bfadb5ea)

```diff
@@ -1907,8 +1907,5 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1

-    # Overlap scheduler related
-    delay_sample_launch: bool = False
-
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
```
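With `delay_sample_launch` removed from `ModelWorkerBatch`, nothing upstream marks a batch for deferred sampling anymore; the decision moves entirely into `TpModelWorker.forward_batch_generation` (tp_worker.py below), keyed off `self.enable_overlap` and whether the batch's `sampling_info` carries grammars.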
python/sglang/srt/managers/scheduler.py (view file @ bfadb5ea)

```diff
@@ -148,7 +148,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.eagle_info import EagleDraftInput
@@ -212,8 +212,7 @@ class GenerationBatchResult:
     # For overlap scheduling
     copy_done: Optional[torch.cuda.Event] = None
-    delay_sample_launch: bool = False
-    forward_batch: Optional[ForwardBatch] = None
+    delay_sample_func: Optional[callable] = None
     future_indices: Optional[FutureIndices] = None

     # FIXME(lsyin): maybe move to <BetterPlace> ?
@@ -1036,17 +1035,16 @@ class Scheduler(
         self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
-            self.launch_last_batch_sample_if_needed()
-
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))

             if self.last_batch:
                 # Process the results of the last batch
@@ -1056,6 +1054,7 @@ class Scheduler(
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()

+            self.launch_batch_sample_if_needed(batch_result)
             self.last_batch = batch

     @DynamicGradMode()
@@ -2207,8 +2206,6 @@ class Scheduler(
             with self.forward_stream_ctx:
                 self.forward_stream.wait_stream(self.default_stream)
                 self.future_map.resolve_future(model_worker_batch)
-                if batch.sampling_info.grammars is not None:
-                    model_worker_batch.delay_sample_launch = True
                 batch_result = self.model_worker.forward_batch_generation(
                     model_worker_batch
                 )
@@ -2216,7 +2213,7 @@ class Scheduler(
                 batch_result.copy_done = torch.get_device_module(self.device).Event()
-                if not model_worker_batch.delay_sample_launch:
+                if batch_result.delay_sample_func is None:
                     self.future_map.store_to_map(future_indices, batch_result)
                     batch_result.copy_to_cpu()
                 else:
@@ -2280,29 +2277,20 @@ class Scheduler(
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

-    def launch_last_batch_sample_if_needed(
-        self,
-    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
-        if len(self.result_queue) == 0:
-            return
-
-        tmp_batch, tmp_result = self.result_queue.popleft()
-        tmp_result: GenerationBatchResult
-        if not tmp_result.delay_sample_launch:
-            self.result_queue.appendleft((tmp_batch, tmp_result))
-            return
-
-        with self.forward_stream_ctx:
-            self.forward_stream.wait_stream(self.default_stream)
-            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
-                tmp_result.logits_output,
-                tmp_result.forward_batch,
-            )
-            future_indices = tmp_result.future_indices
-            self.future_map.store_to_map(future_indices, tmp_result)
-            tmp_result.copy_to_cpu()
-            self.result_queue.appendleft((tmp_batch, tmp_result))
+    def launch_batch_sample_if_needed(
+        self, batch_result: GenerationBatchResult
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        # TODO(lsyin): make the delayed sample a default behavior after
+        # unifying the forward_batch_generation interface (related to spec V2).
+        if batch_result is None or batch_result.delay_sample_func is None:
+            return
+
+        with self.forward_stream_ctx:
+            self.forward_stream.wait_stream(self.default_stream)
+            _batch_result = batch_result.delay_sample_func()
+            assert _batch_result is batch_result
+            self.future_map.store_to_map(batch_result.future_indices, batch_result)
+            batch_result.copy_to_cpu()

     def process_batch_result(
         self,
```
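The core pattern here is deferring sampling by carrying a closure on the result object rather than a flag plus a saved `ForwardBatch`. A minimal self-contained sketch of that pattern (plain Python, no SGLang imports; `BatchResult`, `forward_batch_generation`, and the argmax "sampler" are illustrative stand-ins for the real scheduler/worker types, not the actual API):

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class BatchResult:
    # Stand-in for GenerationBatchResult: a closure on the result object
    # replaces the old delay_sample_launch flag plus stored forward batch.
    next_token_ids: Optional[List[int]] = None
    delay_sample_func: Optional[Callable[[], "BatchResult"]] = None


def forward_batch_generation(logits: List[List[float]], needs_grammar: bool) -> BatchResult:
    """Worker side: run the forward pass; defer sampling when grammars are involved."""
    result = BatchResult()

    def sample() -> BatchResult:
        # Argmax stands in for the real sampler; the closure captures `logits`
        # (and, in SGLang, the forward batch), so nothing extra is stored on the result.
        result.next_token_ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
        return result

    if needs_grammar:
        result.delay_sample_func = sample   # caller decides when to sample
    else:
        sample()                            # eager path: sample immediately
    return result


def launch_batch_sample_if_needed(batch_result: Optional[BatchResult]) -> None:
    """Scheduler side: run the deferred sampling closure, if one was installed."""
    if batch_result is None or batch_result.delay_sample_func is None:
        return
    out = batch_result.delay_sample_func()
    assert out is batch_result              # the closure mutates the same object


# The event loop calls this once per iteration, after the previous
# batch's results have been processed.
res = forward_batch_generation([[0.1, 0.9], [0.7, 0.3]], needs_grammar=True)
launch_batch_sample_if_needed(res)
print(res.next_token_ids)  # [1, 0]
```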
python/sglang/srt/managers/tp_worker.py (view file @ bfadb5ea)

```diff
@@ -168,6 +168,7 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.hicache_layer_transfer_counter = None

     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
@@ -266,9 +267,18 @@ class TpModelWorker:
             # Skip sampling and return logits for target forward
             return batch_result

-        if model_worker_batch.delay_sample_launch:
-            batch_result.delay_sample_launch = True
-            batch_result.forward_batch = forward_batch
+        if (
+            self.enable_overlap
+            and model_worker_batch.sampling_info.grammars is not None
+        ):
+
+            def sample_batch_func():
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+                return batch_result
+
+            batch_result.delay_sample_func = sample_batch_func
             return batch_result

         if model_worker_batch.is_prefill_only:
```
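Design note: because `sample_batch_func` closes over `logits_output` and `forward_batch`, `GenerationBatchResult` no longer needs its own `forward_batch` field, which is what lets scheduler.py drop `ForwardBatch` from its imports. The defer decision (overlap enabled and grammars present) also moves from the scheduler's `run_batch` into the worker, keeping both sides of the deferred-sample handshake in one place.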