Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
40d9b8ac
Unverified
Commit
40d9b8ac
authored
Apr 28, 2025
by
Liangsheng Yin
Committed by
GitHub
Apr 28, 2025
Browse files
Improve overlap scheduling (#5788)
parent
f0365820
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
61 additions
and
23 deletions
+61
-23
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+6
-2
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+8
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+10
-5
python/sglang/srt/managers/scheduler_output_processor_mixin.py
...n/sglang/srt/managers/scheduler_output_processor_mixin.py
+25
-9
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+3
-3
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+9
-4
No files found.
python/sglang/srt/disaggregation/prefill.py
View file @
40d9b8ac
...
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
...
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
logging
import
threading
from
collections
import
deque
from
collections
import
deque
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
...
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
self
.
running_batch
.
batch_is_full
=
False
self
.
running_batch
.
batch_is_full
=
False
def
process_batch_result_disagg_prefill
(
def
process_batch_result_disagg_prefill
(
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
self
:
Scheduler
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
)
->
None
:
)
->
None
:
"""
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
...
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
if
self
.
enable_overlap
:
if
self
.
enable_overlap
:
# wait
# wait
_
,
next_token_ids
=
self
.
tp_worker
.
resolve_batch_result
(
bid
)
_
,
next_token_ids
=
self
.
tp_worker
.
resolve_
last_
batch_result
(
launch_done
)
else
:
else
:
next_token_ids
=
result
.
next_token_ids
.
tolist
()
next_token_ids
=
result
.
next_token_ids
.
tolist
()
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
40d9b8ac
...
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
...
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import
copy
import
copy
import
dataclasses
import
dataclasses
import
logging
import
logging
import
threading
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# This is an optimization to reduce the overhead of the prefill check.
# This is an optimization to reduce the overhead of the prefill check.
batch_is_full
:
bool
=
False
batch_is_full
:
bool
=
False
# Events
launch_done
:
Optional
[
threading
.
Event
]
=
None
# Sampling info
# Sampling info
sampling_info
:
SamplingBatchInfo
=
None
sampling_info
:
SamplingBatchInfo
=
None
next_batch_sampling_info
:
SamplingBatchInfo
=
None
next_batch_sampling_info
:
SamplingBatchInfo
=
None
...
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
)
),
),
extend_input_logprob_token_ids
=
self
.
extend_input_logprob_token_ids
,
extend_input_logprob_token_ids
=
self
.
extend_input_logprob_token_ids
,
launch_done
=
self
.
launch_done
,
)
)
def
copy
(
self
):
def
copy
(
self
):
...
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
...
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
# If set, the output of the batch contains the hidden states of the run.
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode
:
CaptureHiddenMode
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
# Overlap event
launch_done
:
Optional
[
threading
.
Event
]
=
None
@
triton
.
jit
@
triton
.
jit
def
write_req_to_token_pool_triton
(
def
write_req_to_token_pool_triton
(
...
...
python/sglang/srt/managers/scheduler.py
View file @
40d9b8ac
...
@@ -645,6 +645,7 @@ class Scheduler(
...
@@ -645,6 +645,7 @@ class Scheduler(
self
.
cur_batch
=
batch
self
.
cur_batch
=
batch
if
batch
:
if
batch
:
batch
.
launch_done
=
threading
.
Event
()
result
=
self
.
run_batch
(
batch
)
result
=
self
.
run_batch
(
batch
)
self
.
result_queue
.
append
((
batch
.
copy
(),
result
))
self
.
result_queue
.
append
((
batch
.
copy
(),
result
))
...
@@ -656,7 +657,7 @@ class Scheduler(
...
@@ -656,7 +657,7 @@ class Scheduler(
forward_mode
=
ForwardMode
.
DUMMY_FIRST
,
forward_mode
=
ForwardMode
.
DUMMY_FIRST
,
next_batch_sampling_info
=
self
.
tp_worker
.
cur_sampling_info
,
next_batch_sampling_info
=
self
.
tp_worker
.
cur_sampling_info
,
)
)
self
.
process_batch_result
(
tmp_batch
,
None
)
self
.
process_batch_result
(
tmp_batch
,
None
,
batch
.
launch_done
)
if
self
.
last_batch
:
if
self
.
last_batch
:
# Process the results of the last batch
# Process the results of the last batch
...
@@ -664,7 +665,10 @@ class Scheduler(
...
@@ -664,7 +665,10 @@ class Scheduler(
tmp_batch
.
next_batch_sampling_info
=
(
tmp_batch
.
next_batch_sampling_info
=
(
self
.
tp_worker
.
cur_sampling_info
if
batch
else
None
self
.
tp_worker
.
cur_sampling_info
if
batch
else
None
)
)
self
.
process_batch_result
(
tmp_batch
,
tmp_result
)
# NOTE: we should use current launched batch's launch_done event Instead of the last batch's
self
.
process_batch_result
(
tmp_batch
,
tmp_result
,
batch
.
launch_done
if
batch
else
None
)
elif
batch
is
None
:
elif
batch
is
None
:
# When the server is idle, do self-check and re-init some states
# When the server is idle, do self-check and re-init some states
self
.
check_memory
()
self
.
check_memory
()
...
@@ -1417,14 +1421,15 @@ class Scheduler(
...
@@ -1417,14 +1421,15 @@ class Scheduler(
self
,
self
,
batch
:
ScheduleBatch
,
batch
:
ScheduleBatch
,
result
:
Union
[
GenerationBatchResult
,
EmbeddingBatchResult
],
result
:
Union
[
GenerationBatchResult
,
EmbeddingBatchResult
],
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
):
):
if
batch
.
forward_mode
.
is_decode
():
if
batch
.
forward_mode
.
is_decode
():
self
.
process_batch_result_decode
(
batch
,
result
)
self
.
process_batch_result_decode
(
batch
,
result
,
launch_done
)
elif
batch
.
forward_mode
.
is_extend
():
elif
batch
.
forward_mode
.
is_extend
():
self
.
process_batch_result_prefill
(
batch
,
result
)
self
.
process_batch_result_prefill
(
batch
,
result
,
launch_done
)
elif
batch
.
forward_mode
.
is_idle
():
elif
batch
.
forward_mode
.
is_idle
():
if
self
.
enable_overlap
:
if
self
.
enable_overlap
:
self
.
tp_worker
.
resolve_batch_result
(
result
.
bid
)
self
.
tp_worker
.
resolve_
last_
batch_result
(
launch_done
)
if
batch
.
next_batch_sampling_info
:
if
batch
.
next_batch_sampling_info
:
batch
.
next_batch_sampling_info
.
update_regex_vocab_mask
()
batch
.
next_batch_sampling_info
.
update_regex_vocab_mask
()
self
.
current_stream
.
synchronize
()
self
.
current_stream
.
synchronize
()
...
...
python/sglang/srt/managers/scheduler_output_processor_mixin.py
View file @
40d9b8ac
from
__future__
import
annotations
from
__future__
import
annotations
import
threading
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
...
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
...
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
EmbeddingBatchResult
,
EmbeddingBatchResult
,
GenerationBatchResult
,
GenerationBatchResult
,
ScheduleBatch
,
ScheduleBatch
,
Scheduler
,
)
)
...
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
...
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
"""
"""
def
process_batch_result_prefill
(
def
process_batch_result_prefill
(
self
,
self
:
Scheduler
,
batch
:
ScheduleBatch
,
batch
:
ScheduleBatch
,
result
:
Union
[
GenerationBatchResult
,
EmbeddingBatchResult
],
result
:
Union
[
GenerationBatchResult
,
EmbeddingBatchResult
],
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
):
):
skip_stream_req
=
None
skip_stream_req
=
None
...
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
...
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
)
)
if
self
.
enable_overlap
:
if
self
.
enable_overlap
:
logits_output
,
next_token_ids
=
self
.
tp_worker
.
resolve_batch_result
(
bid
)
logits_output
,
next_token_ids
=
(
self
.
tp_worker
.
resolve_last_batch_result
(
launch_done
,
)
)
else
:
else
:
# Move next_token_ids and logprobs to cpu
# Move next_token_ids and logprobs to cpu
next_token_ids
=
next_token_ids
.
tolist
()
next_token_ids
=
next_token_ids
.
tolist
()
...
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
...
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
self
.
stream_output
(
batch
.
reqs
,
batch
.
return_logprob
,
skip_stream_req
)
self
.
stream_output
(
batch
.
reqs
,
batch
.
return_logprob
,
skip_stream_req
)
def
process_batch_result_decode
(
def
process_batch_result_decode
(
self
,
self
:
Scheduler
,
batch
:
ScheduleBatch
,
batch
:
ScheduleBatch
,
result
:
GenerationBatchResult
,
result
:
GenerationBatchResult
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
):
):
logits_output
,
next_token_ids
,
bid
=
(
logits_output
,
next_token_ids
,
bid
=
(
result
.
logits_output
,
result
.
logits_output
,
...
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
...
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
self
.
num_generated_tokens
+=
len
(
batch
.
reqs
)
self
.
num_generated_tokens
+=
len
(
batch
.
reqs
)
if
self
.
enable_overlap
:
if
self
.
enable_overlap
:
logits_output
,
next_token_ids
=
self
.
tp_worker
.
resolve_batch_result
(
bid
)
logits_output
,
next_token_ids
=
self
.
tp_worker
.
resolve_last_batch_result
(
launch_done
)
next_token_logprobs
=
logits_output
.
next_token_logprobs
next_token_logprobs
=
logits_output
.
next_token_logprobs
elif
batch
.
spec_algorithm
.
is_none
():
elif
batch
.
spec_algorithm
.
is_none
():
# spec decoding handles output logprobs inside verify process.
# spec decoding handles output logprobs inside verify process.
...
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
...
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
self
.
log_decode_stats
()
self
.
log_decode_stats
()
def
add_input_logprob_return_values
(
def
add_input_logprob_return_values
(
self
,
self
:
Scheduler
,
i
:
int
,
i
:
int
,
req
:
Req
,
req
:
Req
,
output
:
LogitsProcessorOutput
,
output
:
LogitsProcessorOutput
,
...
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
...
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
assert
len
(
req
.
input_token_ids_logprobs_idx
)
==
relevant_tokens_len
assert
len
(
req
.
input_token_ids_logprobs_idx
)
==
relevant_tokens_len
def
add_logprob_return_values
(
def
add_logprob_return_values
(
self
,
self
:
Scheduler
,
i
:
int
,
i
:
int
,
req
:
Req
,
req
:
Req
,
pt
:
int
,
pt
:
int
,
...
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
...
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
return
num_input_logprobs
return
num_input_logprobs
def
stream_output
(
def
stream_output
(
self
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
self
:
Scheduler
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
,
):
):
"""Stream the output to detokenizer."""
"""Stream the output to detokenizer."""
if
self
.
is_generation
:
if
self
.
is_generation
:
...
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
...
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
self
.
stream_output_embedding
(
reqs
)
self
.
stream_output_embedding
(
reqs
)
def
stream_output_generation
(
def
stream_output_generation
(
self
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
self
:
Scheduler
,
reqs
:
List
[
Req
],
return_logprob
:
bool
,
skip_req
:
Optional
[
Req
]
=
None
,
):
):
rids
=
[]
rids
=
[]
finished_reasons
:
List
[
BaseFinishReason
]
=
[]
finished_reasons
:
List
[
BaseFinishReason
]
=
[]
...
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
...
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
)
)
)
)
def
stream_output_embedding
(
self
,
reqs
:
List
[
Req
]):
def
stream_output_embedding
(
self
:
Scheduler
,
reqs
:
List
[
Req
]):
rids
=
[]
rids
=
[]
finished_reasons
:
List
[
BaseFinishReason
]
=
[]
finished_reasons
:
List
[
BaseFinishReason
]
=
[]
...
...
python/sglang/srt/managers/tp_worker.py
View file @
40d9b8ac
...
@@ -170,13 +170,13 @@ class TpModelWorker:
...
@@ -170,13 +170,13 @@ class TpModelWorker:
def
forward_batch_generation
(
def
forward_batch_generation
(
self
,
self
,
model_worker_batch
:
ModelWorkerBatch
,
model_worker_batch
:
ModelWorkerBatch
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
skip_sample
:
bool
=
False
,
skip_sample
:
bool
=
False
,
)
->
Tuple
[
LogitsProcessorOutput
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
LogitsProcessorOutput
,
Optional
[
torch
.
Tensor
]]:
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
if
launch_done
:
launch_done
.
set
()
if
model_worker_batch
.
launch_done
is
not
None
:
model_worker_batch
.
launch_done
.
set
()
if
skip_sample
:
if
skip_sample
:
next_token_ids
=
None
next_token_ids
=
None
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
40d9b8ac
...
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
...
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
batch_pt
+=
1
batch_pt
+=
1
# Create event
# Create event
self
.
launch_done
=
threading
.
Event
()
copy_done
=
torch
.
get_device_module
(
self
.
device
).
Event
()
copy_done
=
torch
.
get_device_module
(
self
.
device
).
Event
()
# Resolve future tokens in the input
# Resolve future tokens in the input
...
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
...
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
# Run forward
# Run forward
logits_output
,
next_token_ids
=
self
.
worker
.
forward_batch_generation
(
logits_output
,
next_token_ids
=
self
.
worker
.
forward_batch_generation
(
model_worker_batch
,
self
.
launch_done
model_worker_batch
)
)
# Update the future token ids map
# Update the future token ids map
...
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
...
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
self
.
output_queue
.
put
((
copy_done
,
logits_output
,
next_token_ids
))
self
.
output_queue
.
put
((
copy_done
,
logits_output
,
next_token_ids
))
def
resolve_batch_result
(
self
,
bid
:
int
):
def
resolve_last_batch_result
(
self
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
):
"""
This function is called to resolve the last batch result and
wait for the current batch to be launched. Used in overlap mode.
"""
copy_done
,
logits_output
,
next_token_ids
=
self
.
output_queue
.
get
()
copy_done
,
logits_output
,
next_token_ids
=
self
.
output_queue
.
get
()
if
launch_done
is
not
None
:
launch_done
.
wait
()
copy_done
.
synchronize
()
copy_done
.
synchronize
()
self
.
launch_done
.
wait
()
if
logits_output
.
next_token_logprobs
is
not
None
:
if
logits_output
.
next_token_logprobs
is
not
None
:
logits_output
.
next_token_logprobs
=
(
logits_output
.
next_token_logprobs
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment