Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
40d9b8ac
Unverified
Commit
40d9b8ac
authored
Apr 28, 2025
by
Liangsheng Yin
Committed by
GitHub
Apr 28, 2025
Browse files
Improve overlap scheduling (#5788)
parent
f0365820
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
61 additions
and
23 deletions
+61
-23
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+6
-2
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+8
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+10
-5
python/sglang/srt/managers/scheduler_output_processor_mixin.py
...n/sglang/srt/managers/scheduler_output_processor_mixin.py
+25
-9
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+3
-3
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+9
-4
No files found.
python/sglang/srt/disaggregation/prefill.py
View file @
40d9b8ac
...
...
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
from
__future__
import
annotations
import
logging
import
threading
from
collections
import
deque
from
typing
import
TYPE_CHECKING
,
List
,
Optional
...
...
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
self
.
running_batch
.
batch_is_full
=
False
def
process_batch_result_disagg_prefill
(
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
)
->
None
:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
...
...
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
if
self
.
enable_overlap
:
# wait
_
,
next_token_ids
=
self
.
tp_worker
.
resolve_batch_result
(
bid
)
_
,
next_token_ids
=
self
.
tp_worker
.
resolve_
last_
batch_result
(
launch_done
)
else
:
next_token_ids
=
result
.
next_token_ids
.
tolist
()
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
40d9b8ac
...
...
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import
copy
import
dataclasses
import
logging
import
threading
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# This is an optimization to reduce the overhead of the prefill check.
batch_is_full
:
bool
=
False
# Events
launch_done
:
Optional
[
threading
.
Event
]
=
None
# Sampling info
sampling_info
:
SamplingBatchInfo
=
None
next_batch_sampling_info
:
SamplingBatchInfo
=
None
...
...
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
),
extend_input_logprob_token_ids
=
self
.
extend_input_logprob_token_ids
,
launch_done
=
self
.
launch_done
,
)
def
copy
(
self
):
...
...
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode
:
CaptureHiddenMode
=
None
# Overlap event
launch_done
:
Optional
[
threading
.
Event
]
=
None
@
triton
.
jit
def
write_req_to_token_pool_triton
(
...
...
python/sglang/srt/managers/scheduler.py
View file @
40d9b8ac
...
...
@@ -645,6 +645,7 @@ class Scheduler(
self
.
cur_batch
=
batch
if
batch
:
batch
.
launch_done
=
threading
.
Event
()
result
=
self
.
run_batch
(
batch
)
self
.
result_queue
.
append
((
batch
.
copy
(),
result
))
...
...
@@ -656,7 +657,7 @@ class Scheduler(
forward_mode
=
ForwardMode
.
DUMMY_FIRST
,
next_batch_sampling_info
=
self
.
tp_worker
.
cur_sampling_info
,
)
self
.
process_batch_result
(
tmp_batch
,
None
)
self
.
process_batch_result
(
tmp_batch
,
None
,
batch
.
launch_done
)
if
self
.
last_batch
:
# Process the results of the last batch
...
...
@@ -664,7 +665,10 @@ class Scheduler(
tmp_batch
.
next_batch_sampling_info
=
(
self
.
tp_worker
.
cur_sampling_info
if
batch
else
None
)
self
.
process_batch_result
(
tmp_batch
,
tmp_result
)
# NOTE: we should use current launched batch's launch_done event Instead of the last batch's
self
.
process_batch_result
(
tmp_batch
,
tmp_result
,
batch
.
launch_done
if
batch
else
None
)
elif
batch
is
None
:
# When the server is idle, do self-check and re-init some states
self
.
check_memory
()
...
...
@@ -1417,14 +1421,15 @@ class Scheduler(
self
,
batch
:
ScheduleBatch
,
result
:
Union
[
GenerationBatchResult
,
EmbeddingBatchResult
],
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
):
if
batch
.
forward_mode
.
is_decode
():
self
.
process_batch_result_decode
(
batch
,
result
)
self
.
process_batch_result_decode
(
batch
,
result
,
launch_done
)
elif
batch
.
forward_mode
.
is_extend
():
self
.
process_batch_result_prefill
(
batch
,
result
)
self
.
process_batch_result_prefill
(
batch
,
result
,
launch_done
)
elif
batch
.
forward_mode
.
is_idle
():
if
self
.
enable_overlap
:
self
.
tp_worker
.
resolve_batch_result
(
result
.
bid
)
self
.
tp_worker
.
resolve_
last_
batch_result
(
launch_done
)
if
batch
.
next_batch_sampling_info
:
batch
.
next_batch_sampling_info
.
update_regex_vocab_mask
()
self
.
current_stream
.
synchronize
()
...
...
python/sglang/srt/managers/scheduler_output_processor_mixin.py
View file @
40d9b8ac
from
__future__
import
annotations
import
threading
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
...
...
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
EmbeddingBatchResult
,
GenerationBatchResult
,
ScheduleBatch
,
Scheduler
,
)
...
...
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
"""
def
process_batch_result_prefill
(
self
,
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
Union
[
GenerationBatchResult
,
EmbeddingBatchResult
],
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
):
skip_stream_req
=
None
...
...
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
)
if
self
.
enable_overlap
:
logits_output
,
next_token_ids
=
self
.
tp_worker
.
resolve_batch_result
(
bid
)
logits_output
,
next_token_ids
=
(
self
.
tp_worker
.
resolve_last_batch_result
(
launch_done
,
)
)
else
:
# Move next_token_ids and logprobs to cpu
next_token_ids
=
next_token_ids
.
tolist
()
...
...
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
self
.
stream_output
(
batch
.
reqs
,
batch
.
return_logprob
,
skip_stream_req
)
def
process_batch_result_decode
(
self
,
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
):
logits_output
,
next_token_ids
,
bid
=
(
result
.
logits_output
,
...
...
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
self
.
num_generated_tokens
+=
len
(
batch
.
reqs
)
if
self
.
enable_overlap
:
logits_output
,
next_token_ids
=
self
.
tp_worker
.
resolve_batch_result
(
bid
)
logits_output
,
next_token_ids
=
self
.
tp_worker
.
resolve_last_batch_result
(
launch_done
)
next_token_logprobs
=
logits_output
.
next_token_logprobs
elif
batch
.
spec_algorithm
.
is_none
():
# spec decoding handles output logprobs inside verify process.
...
...
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
self
.
log_decode_stats
()
def
add_input_logprob_return_values
(
self
,
self
:
Scheduler
,
i
:
int
,
req
:
Req
,
output
:
LogitsProcessorOutput
,
...
...
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
assert
len
(
req
.
input_token_ids_logprobs_idx
)
==
relevant_tokens_len
def
add_logprob_return_values
(
self
,
self
:
Scheduler
,
i
:
int
,
req
:
Req
,
pt
:
int
,
...
...
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
return
num_input_logprobs
def
stream_output
(
self
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
self
:
Scheduler
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
,
):
"""Stream the output to detokenizer."""
if
self
.
is_generation
:
...
...
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
self
.
stream_output_embedding
(
reqs
)
def
stream_output_generation
(
self
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
self
:
Scheduler
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
,
):
rids
=
[]
finished_reasons
:
List
[
BaseFinishReason
]
=
[]
...
...
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
)
)
def
stream_output_embedding
(
self
,
reqs
:
List
[
Req
]):
def
stream_output_embedding
(
self
:
Scheduler
,
reqs
:
List
[
Req
]):
rids
=
[]
finished_reasons
:
List
[
BaseFinishReason
]
=
[]
...
...
python/sglang/srt/managers/tp_worker.py
View file @
40d9b8ac
...
...
@@ -170,13 +170,13 @@ class TpModelWorker:
def
forward_batch_generation
(
self
,
model_worker_batch
:
ModelWorkerBatch
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
skip_sample
:
bool
=
False
,
)
->
Tuple
[
LogitsProcessorOutput
,
Optional
[
torch
.
Tensor
]]:
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
if
launch_done
:
launch_done
.
set
()
if
model_worker_batch
.
launch_done
is
not
None
:
model_worker_batch
.
launch_done
.
set
()
if
skip_sample
:
next_token_ids
=
None
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
40d9b8ac
...
...
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
batch_pt
+=
1
# Create event
self
.
launch_done
=
threading
.
Event
()
copy_done
=
torch
.
get_device_module
(
self
.
device
).
Event
()
# Resolve future tokens in the input
...
...
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
# Run forward
logits_output
,
next_token_ids
=
self
.
worker
.
forward_batch_generation
(
model_worker_batch
,
self
.
launch_done
model_worker_batch
)
# Update the future token ids map
...
...
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
self
.
output_queue
.
put
((
copy_done
,
logits_output
,
next_token_ids
))
def
resolve_batch_result
(
self
,
bid
:
int
):
def
resolve_last_batch_result
(
self
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
):
"""
This function is called to resolve the last batch result and
wait for the current batch to be launched. Used in overlap mode.
"""
copy_done
,
logits_output
,
next_token_ids
=
self
.
output_queue
.
get
()
if
launch_done
is
not
None
:
launch_done
.
wait
()
copy_done
.
synchronize
()
self
.
launch_done
.
wait
()
if
logits_output
.
next_token_logprobs
is
not
None
:
logits_output
.
next_token_logprobs
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment