sglang commit 40d9b8ac (unverified)
Authored Apr 28, 2025 by Liangsheng Yin; committed by GitHub on Apr 28, 2025
Parent: f0365820

Improve overlap scheduling (#5788)
Showing 6 changed files with 61 additions and 23 deletions (+61 -23):

  python/sglang/srt/disaggregation/prefill.py                       +6  -2
  python/sglang/srt/managers/schedule_batch.py                      +8  -0
  python/sglang/srt/managers/scheduler.py                           +10 -5
  python/sglang/srt/managers/scheduler_output_processor_mixin.py    +25 -9
  python/sglang/srt/managers/tp_worker.py                           +3  -3
  python/sglang/srt/managers/tp_worker_overlap_thread.py            +9  -4
python/sglang/srt/disaggregation/prefill.py

@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_batch.batch_is_full = False
 
     def process_batch_result_disagg_prefill(
-        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
 
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
python/sglang/srt/managers/schedule_batch.py

@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import copy
 import dataclasses
 import logging
+import threading
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False
 
+    # Events
+    launch_done: Optional[threading.Event] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+            launch_done=self.launch_done,
         )
 
     def copy(self):
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
 
+    # Overlap event
+    launch_done: Optional[threading.Event] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
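For context, a minimal sketch of the pattern these hunks introduce: the scheduler-side batch carries an `Optional[threading.Event]`, and the same Event instance is forwarded onto the worker-side batch so the worker thread can set it. The `Toy*` names below are illustrative, not sglang classes.

```python
import threading
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyModelWorkerBatch:
    # Overlap event: set by the worker once the forward pass has been launched.
    launch_done: Optional[threading.Event] = None


@dataclass
class ToyScheduleBatch:
    # Events
    launch_done: Optional[threading.Event] = None

    def get_model_worker_batch(self) -> ToyModelWorkerBatch:
        # Forward the *same* Event object, so a worker-side .set() unblocks
        # anyone waiting on the scheduler-side reference.
        return ToyModelWorkerBatch(launch_done=self.launch_done)


if __name__ == "__main__":
    sched_batch = ToyScheduleBatch(launch_done=threading.Event())
    worker_batch = sched_batch.get_model_worker_batch()
    worker_batch.launch_done.set()
    print(sched_batch.launch_done.is_set())  # True: it is the same Event
```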
python/sglang/srt/managers/scheduler.py

@@ -645,6 +645,7 @@ class Scheduler(
             self.cur_batch = batch
 
             if batch:
+                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
 
@@ -656,7 +657,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None)
+                    self.process_batch_result(tmp_batch, None, batch.launch_done)
 
             if self.last_batch:
                 # Process the results of the last batch
@@ -664,7 +665,10 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-                self.process_batch_result(tmp_batch, tmp_result)
+                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
+                self.process_batch_result(
+                    tmp_batch, tmp_result, batch.launch_done if batch else None
+                )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
@@ -1417,14 +1421,15 @@ class Scheduler(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result.bid)
+                self.tp_worker.resolve_last_batch_result(launch_done)
                 if batch.next_batch_sampling_info:
                     batch.next_batch_sampling_info.update_regex_vocab_mask()
                     self.current_stream.synchronize()
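A hedged sketch of the loop ordering these hunks set up (toy code, not the sglang Scheduler): each newly launched batch gets a fresh `launch_done` event, and the previously launched batch's results are processed with the *current* batch's event, so result processing can overlap the new launch yet still wait for it when needed. `launch()` is a hypothetical worker method standing in for the real launch path.

```python
import threading
from collections import deque
from typing import Any, Optional


class ToyOverlapLoop:
    def __init__(self, worker):
        self.worker = worker          # needs launch() and resolve_last_batch_result()
        self.result_queue = deque()
        self.last_batch = None

    def step(self, batch: Optional[Any]) -> None:
        if batch is not None:
            batch.launch_done = threading.Event()      # one event per launched batch
            result = self.worker.launch(batch)         # non-blocking launch
            self.result_queue.append((batch, result))

        if self.last_batch is not None:
            prev_batch, prev_result = self.result_queue.popleft()
            # Use the *current* batch's launch_done, not the last batch's.
            self.process_batch_result(
                prev_batch,
                prev_result,
                batch.launch_done if batch is not None else None,
            )
        self.last_batch = batch

    def process_batch_result(self, batch, result, launch_done):
        # In overlap mode, resolving the previous batch's tokens may wait on
        # launch_done so the new batch is guaranteed to be in flight first.
        self.worker.resolve_last_batch_result(launch_done)
```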
python/sglang/srt/managers/scheduler_output_processor_mixin.py

 from __future__ import annotations
 
+import threading
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
         EmbeddingBatchResult,
         GenerationBatchResult,
         ScheduleBatch,
+        Scheduler,
     )
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
     """
 
     def process_batch_result_prefill(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
 
         skip_stream_req = None
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
            )
 
            if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+                logits_output, next_token_ids = (
+                    self.tp_worker.resolve_last_batch_result(
+                        launch_done,
+                    )
+                )
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
     def process_batch_result_decode(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ):
         logits_output, next_token_ids, bid = (
             result.logits_output,
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
             # spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
                 self.log_decode_stats()
 
     def add_input_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
         assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
 
     def add_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         pt: int,
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
         return num_input_logprobs
 
     def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         """Stream the output to detokenizer."""
         if self.is_generation:
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
             self.stream_output_embedding(reqs)
 
     def stream_output_generation(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
    ):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
                 )
             )
 
-    def stream_output_embedding(self, reqs: List[Req]):
+    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
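Several of the hunks above only change `self,` to `self: Scheduler,`. Annotating `self` in a mixin tells type checkers which host class the mixin expects, so host-owned attributes (here, things like `tp_worker`) resolve without the mixin defining them. A generic sketch of that pattern, assuming a hypothetical `host_module.Host` (not sglang code):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type checking; "host_module" / "Host" are hypothetical.
    from host_module import Host


class OutputMixin:
    def emit(self: Host) -> None:
        # With `self: Host`, a type checker resolves attributes such as
        # self.sink on the host class rather than on the mixin itself.
        self.sink.write("done")
```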
python/sglang/srt/managers/tp_worker.py

@@ -170,13 +170,13 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
     ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        if launch_done:
-            launch_done.set()
+
+        if model_worker_batch.launch_done is not None:
+            model_worker_batch.launch_done.set()
 
         if skip_sample:
             next_token_ids = None
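A hedged sketch of the worker-side change in this hunk: the event now rides on the batch object instead of arriving as a separate argument, and it is set right after the forward pass, before sampling. The names below are illustrative, not the sglang API.

```python
import threading
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class ToyWorkerBatch:
    input_ids: List[int]
    launch_done: Optional[threading.Event] = None


def forward_batch_generation(
    batch: ToyWorkerBatch,
    forward_fn: Callable[[List[int]], List[float]],
    skip_sample: bool = False,
) -> Tuple[List[float], Optional[int]]:
    logits = forward_fn(batch.input_ids)        # launch/run the forward pass
    if batch.launch_done is not None:
        batch.launch_done.set()                 # signal: launch issued, inputs consumed
    if skip_sample:
        return logits, None
    next_token_id = max(range(len(logits)), key=logits.__getitem__)  # greedy "sampling"
    return logits, next_token_id


if __name__ == "__main__":
    ev = threading.Event()
    print(forward_batch_generation(ToyWorkerBatch([1, 2, 3], ev), lambda ids: [0.1, 0.7, 0.2]))
    print("launch_done set:", ev.is_set())
```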
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -132,7 +132,6 @@ class TpModelWorkerClient:
             batch_pt += 1
 
             # Create event
-            self.launch_done = threading.Event()
             copy_done = torch.get_device_module(self.device).Event()
 
             # Resolve future tokens in the input
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
 
             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch, self.launch_done
+                model_worker_batch
             )
 
             # Update the future token ids map
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
             self.output_queue.put((copy_done, logits_output, next_token_ids))
 
-    def resolve_batch_result(self, bid: int):
+    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+        """
+        This function is called to resolve the last batch result and
+        wait for the current batch to be launched. Used in overlap mode.
+        """
         copy_done, logits_output, next_token_ids = self.output_queue.get()
+        if launch_done is not None:
+            launch_done.wait()
         copy_done.synchronize()
-        self.launch_done.wait()
+
         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
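A minimal end-to-end sketch of the handoff that `resolve_last_batch_result` now performs (toy code; a plain `threading.Event` stands in for the device copy event that the real code synchronizes): the forward thread pushes outputs and sets `launch_done`; the scheduler side pops the previous batch's output, waits for the current launch, then waits for the copy.

```python
import queue
import threading
from typing import Optional, Tuple


class ToyWorkerClient:
    def __init__(self) -> None:
        # (copy_done, next_token_ids) pairs, one per launched batch.
        self.output_queue: "queue.Queue[Tuple[threading.Event, list]]" = queue.Queue()

    def forward_thread_step(self, input_ids: list, launch_done: Optional[threading.Event]) -> None:
        copy_done = threading.Event()                  # stand-in for a device copy event
        next_token_ids = [i + 1 for i in input_ids]    # pretend forward + sampling
        if launch_done is not None:
            launch_done.set()                          # the batch has been launched
        copy_done.set()                                # outputs have been "copied" to CPU
        self.output_queue.put((copy_done, next_token_ids))

    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None) -> list:
        copy_done, next_token_ids = self.output_queue.get()   # previous batch's output
        if launch_done is not None:
            launch_done.wait()                         # current batch is in flight
        copy_done.wait()                               # analogous to copy_done.synchronize()
        return next_token_ids


if __name__ == "__main__":
    client = ToyWorkerClient()
    launch_done = threading.Event()
    t = threading.Thread(target=client.forward_thread_step, args=([1, 2, 3], launch_done))
    t.start()
    print(client.resolve_last_batch_result(launch_done))  # [2, 3, 4]
    t.join()
```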