Unverified Commit 44822d7f authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Preserve spec decoding uniform decode when scheduling (#29759)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent 342c4f14
...@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance( ...@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
# Expect the acceptance rate to improve. # Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 85% acceptance rate at the end. # Heuristic: expect at least 82.5% acceptance rate at the end.
assert last_accept_rate > 0.85 assert last_accept_rate > 0.825
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
...@@ -33,7 +33,7 @@ class AsyncScheduler(Scheduler): ...@@ -33,7 +33,7 @@ class AsyncScheduler(Scheduler):
# in this scheduling step. # in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids. # Add placeholders for the new tokens in spec_token_ids.
# Wwe will update the actual spec token ids in the worker process. # We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.pending_structured_output_tokens = ( scheduler_output.pending_structured_output_tokens = (
......
...@@ -236,6 +236,22 @@ class Scheduler(SchedulerInterface): ...@@ -236,6 +236,22 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
if (
request.num_output_placeholders > 0
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
# Since output placeholders are also included in the computed tokens
# count, we subtract (num_output_placeholders - 1) to remove any draft
# tokens, so that we can be sure no further steps are needed even if
# they are all rejected.
and request.num_computed_tokens + 2 - request.num_output_placeholders
>= request.num_prompt_tokens + request.max_tokens
):
# Async scheduling: Avoid scheduling an extra step when we are sure that
# the previous step has reached request.max_tokens. We don't schedule
# partial draft tokens since this prevents uniform decode optimizations.
req_index += 1
continue
num_new_tokens = ( num_new_tokens = (
request.num_tokens_with_spec request.num_tokens_with_spec
+ request.num_output_placeholders + request.num_output_placeholders
...@@ -245,18 +261,10 @@ class Scheduler(SchedulerInterface): ...@@ -245,18 +261,10 @@ class Scheduler(SchedulerInterface):
num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
num_spec_placeholders = max(0, request.num_output_placeholders - 1) # Make sure the input position does not exceed the max model len.
max_total_tokens = min( # This is necessary when using spec decoding.
# Avoid scheduling tokens that we're sure won't will be needed based on
# request.max_tokens. For this calculation we assume placeholder
# speculated output tokens are rejected.
request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
self.max_model_len,
)
num_new_tokens = min( num_new_tokens = min(
num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
) )
# Schedule encoder inputs. # Schedule encoder inputs.
...@@ -799,15 +807,15 @@ class Scheduler(SchedulerInterface): ...@@ -799,15 +807,15 @@ class Scheduler(SchedulerInterface):
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
num_tokens = num_scheduled_tokens[req_id] - len(
spec_decode_tokens.get(req_id, ())
)
if self.use_pp: if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner # need to send the sampled tokens back because the model runner
# will cache them. # will cache them.
num_tokens = num_scheduled_tokens[req_id] - len(
spec_decode_tokens.get(req_id, ())
)
token_ids = req.all_token_ids[ token_ids = req.all_token_ids[
req.num_computed_tokens : req.num_computed_tokens + num_tokens req.num_computed_tokens : req.num_computed_tokens + num_tokens
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment