Unverified Commit 711edaf0 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Perf] Optimize spec decoding + async scheduling, 1.5% Throughput improvement (#33612)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
Signed-off-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1d367a73
...@@ -10,6 +10,11 @@ logger = init_logger(__name__) ...@@ -10,6 +10,11 @@ logger = init_logger(__name__)
class AsyncScheduler(Scheduler): class AsyncScheduler(Scheduler):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# reusable read-only placeholder list for speculative decoding.
self._spec_token_placeholders: list[int] = [-1] * self.num_spec_tokens
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
has_structured_output_requests = False has_structured_output_requests = False
...@@ -31,9 +36,9 @@ class AsyncScheduler(Scheduler): ...@@ -31,9 +36,9 @@ class AsyncScheduler(Scheduler):
# The request will generate a new token plus num_spec_tokens # The request will generate a new token plus num_spec_tokens
# in this scheduling step. # in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids. # Add placeholders for the new draft/spec tokens.
# We will update the actual spec token ids in the worker process. # We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens request.spec_token_ids = self._spec_token_placeholders
scheduler_output.has_structured_output_requests = has_structured_output_requests scheduler_output.has_structured_output_requests = has_structured_output_requests
scheduler_output.pending_structured_output_tokens = ( scheduler_output.pending_structured_output_tokens = (
......
...@@ -487,9 +487,11 @@ class Scheduler(SchedulerInterface): ...@@ -487,9 +487,11 @@ class Scheduler(SchedulerInterface):
- request.num_output_placeholders - request.num_output_placeholders
) )
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. spec_token_ids = request.spec_token_ids
del request.spec_token_ids[num_scheduled_spec_tokens:] if len(spec_token_ids) > num_scheduled_spec_tokens:
scheduled_spec_decode_tokens[request_id] = request.spec_token_ids spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
# New spec tokens will be set in `update_draft_token_ids` before the # New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable. # next step when applicable.
request.spec_token_ids = [] request.spec_token_ids = []
...@@ -887,7 +889,8 @@ class Scheduler(SchedulerInterface): ...@@ -887,7 +889,8 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.free(request) self.encoder_cache_manager.free(request)
request.status = RequestStatus.PREEMPTED request.status = RequestStatus.PREEMPTED
request.num_computed_tokens = 0 request.num_computed_tokens = 0
request.spec_token_ids.clear() if request.spec_token_ids:
request.spec_token_ids = []
request.num_preemptions += 1 request.num_preemptions += 1
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.PREEMPTED, timestamp) request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment