Unverified commit fa4e0fb0, authored by Nick Hill and committed by GitHub
Browse files

[Core] Don't schedule spec tokens with prefill chunks (#33652)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent ce498a6d
...@@ -945,6 +945,100 @@ def test_spec_decoding_stats_empty_output(): ...@@ -945,6 +945,100 @@ def test_spec_decoding_stats_empty_output():
assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None
def test_no_spec_tokens_scheduled_for_prefill_chunks():
    """Draft tokens must not be scheduled together with a prefill chunk.

    Scenario (token budget 50, 80-token prompt, 3 speculative tokens):

    1. The first step schedules a 50-token prefill chunk.
    2. Draft tokens arrive via ``update_draft_token_ids`` while the request
       is still being prefilled.
    3. The second step has enough budget for the remaining 30 prompt tokens
       plus the 3 draft tokens.

    Without the fix the second step would schedule 33 tokens
    (``num_scheduled_spec_tokens = 33 + 50 - 80 = 3``). With the fix the
    draft tokens are ignored for the prefill chunk, so exactly 30 tokens are
    scheduled. Once prefill has finished, draft tokens are accepted again and
    the decode step schedules 1 new token plus the 3 speculative ones.
    """
    num_spec_tokens = 3
    scheduler = create_scheduler(
        num_speculative_tokens=num_spec_tokens,
        max_num_batched_tokens=50,
        enable_chunked_prefill=True,
    )
    req = create_requests(num_requests=1, num_tokens=80)[0]
    req_id = req.request_id
    scheduler.add_request(req)

    index_by_req_id = {req_id: 0}

    def runner_output_with(sampled):
        # Build the model-runner result for the single request in the batch;
        # ``sampled`` is the list of token ids sampled for it in this step.
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index=index_by_req_id,
            sampled_token_ids=[sampled],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )

    # Step 1: the first prefill chunk consumes the whole budget (50 of 80).
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1
    assert output.num_scheduled_tokens[req_id] == 50

    # Nothing is sampled yet because the prompt is only partially processed.
    scheduler.update_from_output(output, runner_output_with([]))

    # Drafts delivered in the middle of prefill are expected to be dropped.
    scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))

    # Step 2: only the remaining 30 prompt tokens may be scheduled.
    output = scheduler.schedule()
    # KEY ASSERTION: 30 tokens, not 33 (30 prompt + 3 spec).
    assert output.num_scheduled_tokens[req_id] == 30, (
        f"Expected 30 tokens (remaining prefill only), "
        f"got {output.num_scheduled_tokens[req.request_id]}. "
        "Spec tokens should not be scheduled with prefill chunks."
    )
    assert req_id not in output.scheduled_spec_decode_tokens, (
        "Spec tokens should not be scheduled with prefill chunks"
    )

    # Prefill is now complete, so this step yields a sampled token.
    scheduler.update_from_output(output, runner_output_with([42]))

    # Drafts delivered after prefill must be stored on the request.
    scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))
    assert req.spec_token_ids == [1, 2, 3], (
        f"spec_token_ids should be set after prefill, got {req.spec_token_ids}"
    )

    # Step 3: decode step, 1 new token + 3 speculative tokens = 4.
    output = scheduler.schedule()
    assert output.num_scheduled_tokens[req_id] == 4
    assert req_id in output.scheduled_spec_decode_tokens
    scheduled_drafts = output.scheduled_spec_decode_tokens[req_id]
    assert len(scheduled_drafts) == num_spec_tokens
def _assert_right_scheduler_output( def _assert_right_scheduler_output(
output: SchedulerOutput, output: SchedulerOutput,
num_requests: int, num_requests: int,
......
...@@ -17,33 +17,22 @@ class AsyncScheduler(Scheduler): ...@@ -17,33 +17,22 @@ class AsyncScheduler(Scheduler):
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
has_structured_output_requests = False
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens: for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id] request = self.requests[req_id]
has_structured_output_requests |= request.use_structured_output if request.is_prefill_chunk:
pending_structured_output_tokens |= ( continue
scheduler_output.pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0 request.use_structured_output and request.num_output_placeholders > 0
) )
# The request will generate a new token plus num_spec_tokens
# in this scheduling step.
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ())) cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if ( request.num_output_placeholders += 1 + cur_num_spec_tokens
request.num_computed_tokens # Add placeholders for the new draft/spec tokens.
== request.num_tokens # We will update the actual spec token ids in the worker process.
+ request.num_output_placeholders request.spec_token_ids = self._spec_token_placeholders
+ cur_num_spec_tokens
):
# The request will generate a new token plus num_spec_tokens
# in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new draft/spec tokens.
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = self._spec_token_placeholders
scheduler_output.has_structured_output_requests = has_structured_output_requests
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output( def _update_request_with_output(
self, request: Request, new_token_ids: list[int] self, request: Request, new_token_ids: list[int]
......
...@@ -912,6 +912,12 @@ class Scheduler(SchedulerInterface): ...@@ -912,6 +912,12 @@ class Scheduler(SchedulerInterface):
for req_id, num_scheduled_token in num_scheduled_tokens.items(): for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id] request = self.requests[req_id]
request.num_computed_tokens += num_scheduled_token request.num_computed_tokens += num_scheduled_token
request.is_prefill_chunk = request.num_computed_tokens < (
request.num_tokens + request.num_output_placeholders
)
scheduler_output.has_structured_output_requests |= (
request.use_structured_output
)
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which # NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# may be updated again in _update_from_output for speculative # may be updated again in _update_from_output for speculative
...@@ -1562,6 +1568,12 @@ class Scheduler(SchedulerInterface): ...@@ -1562,6 +1568,12 @@ class Scheduler(SchedulerInterface):
# The request may have been finished. Skip. # The request may have been finished. Skip.
continue continue
if request.is_prefill_chunk:
# Ignore draft tokens for prefill chunks.
if request.spec_token_ids:
request.spec_token_ids = []
continue
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request metadata = request.structured_output_request
......
...@@ -147,6 +147,9 @@ class Request: ...@@ -147,6 +147,9 @@ class Request:
# The number of tokens with prefix cache hits. # The number of tokens with prefix cache hits.
self.num_cached_tokens = -1 self.num_cached_tokens = -1
# True if this request is scheduled as a non-final prefill chunk.
self.is_prefill_chunk = False
# The number of NaNs in logits. A value greater than 0 # The number of NaNs in logits. A value greater than 0
# indicates that the output is corrupted # indicates that the output is corrupted
self.num_nans_in_logits = 0 self.num_nans_in_logits = 0
......
...@@ -16,21 +16,21 @@ class DraftTokensHandler: ...@@ -16,21 +16,21 @@ class DraftTokensHandler:
self.req_ids: list[str] = [] self.req_ids: list[str] = []
self.draft_tokens_np: np.ndarray | None = None self.draft_tokens_np: np.ndarray | None = None
self.num_draft_tokens: int = 0
def set_draft_tokens( def set_draft_tokens(
self, input_batch: InputBatch, draft_tokens: torch.Tensor self, input_batch: InputBatch, draft_tokens: torch.Tensor
) -> None: ) -> None:
self.req_ids = input_batch.req_ids
self.num_draft_tokens = draft_tokens.shape[1]
if not input_batch.has_structured_output_reqs: if not input_batch.has_structured_output_reqs:
# No draft token validation needs to be performed by # No draft token validation needs to be performed by
# the scheduler for this batch. # the scheduler for this batch.
if self.req_ids:
self.req_ids = []
self.draft_tokens_np = None self.draft_tokens_np = None
return return
# For spec decoding + structured outputs, we must transfer the # For spec decoding + structured outputs, we must transfer the
# draft tokens back to the scheduler for grammar validation. # draft tokens back to the scheduler for grammar validation.
self.req_ids = input_batch.req_ids
current_stream = torch.cuda.current_stream(self.device) current_stream = torch.cuda.current_stream(self.device)
self.copy_stream.wait_stream(current_stream) self.copy_stream.wait_stream(current_stream)
with torch.cuda.stream(self.copy_stream): with torch.cuda.stream(self.copy_stream):
...@@ -38,8 +38,10 @@ class DraftTokensHandler: ...@@ -38,8 +38,10 @@ class DraftTokensHandler:
self.copy_event.record() self.copy_event.record()
def get_draft_tokens(self) -> DraftTokenIds | None: def get_draft_tokens(self) -> DraftTokenIds | None:
if self.draft_tokens_np is None: if self.draft_tokens_np is not None:
return None self.copy_event.synchronize()
draft_token_ids = self.draft_tokens_np.tolist()
self.copy_event.synchronize() else:
return DraftTokenIds(self.req_ids, self.draft_tokens_np.tolist()) # This case only happens when async scheduling is disabled.
draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
return DraftTokenIds(self.req_ids, draft_token_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment