Unverified Commit 287b37cd authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix spec decoding edge case bugs (#31944)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent 791b2fc3
...@@ -775,6 +775,7 @@ class Scheduler(SchedulerInterface): ...@@ -775,6 +775,7 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.free(request) self.encoder_cache_manager.free(request)
request.status = RequestStatus.PREEMPTED request.status = RequestStatus.PREEMPTED
request.num_computed_tokens = 0 request.num_computed_tokens = 0
request.spec_token_ids.clear()
request.num_preemptions += 1 request.num_preemptions += 1
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.PREEMPTED, timestamp) request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
......
...@@ -446,6 +446,32 @@ class InputBatch: ...@@ -446,6 +446,32 @@ class InputBatch:
return req_index return req_index
def update_req_spec_token_ids(
self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
) -> None:
req_id = request.req_id
req_index = self.req_id_to_index[req_id]
cur_spec_token_ids = self.spec_token_ids[req_index]
# When speculative decoding is used with structured output,
# the scheduler can drop draft tokens that do not
# conform to the schema. This can result in
# scheduler_output.scheduled_spec_decode_tokens being empty,
# even when speculative decoding is enabled.
cur_spec_token_ids.clear()
spec_token_ids = scheduled_spec_tokens.get(req_id, ())
num_spec_tokens = len(spec_token_ids)
request.prev_num_draft_len = num_spec_tokens
if not spec_token_ids:
return
# For async scheduling, token_ids_cpu assigned from
# spec_token_ids are placeholders and will be overwritten in
# _prepare_input_ids.
start_index = self.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
cur_spec_token_ids.extend(spec_token_ids)
def remove_request(self, req_id: str) -> int | None: def remove_request(self, req_id: str) -> int | None:
"""This method must always be followed by a call to condense(). """This method must always be followed by a call to condense().
......
...@@ -925,6 +925,7 @@ class GPUModelRunner( ...@@ -925,6 +925,7 @@ class GPUModelRunner(
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs req_data = scheduler_output.scheduled_cached_reqs
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
# Wait until valid_sampled_tokens_count is copied to cpu, # Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request. # then use it to update actual num_computed_tokens of each request.
...@@ -938,20 +939,20 @@ class GPUModelRunner( ...@@ -938,20 +939,20 @@ class GPUModelRunner(
num_output_tokens = req_data.num_output_tokens[i] num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id) req_index = self.input_batch.req_id_to_index.get(req_id)
# prev_num_draft_len is used in async scheduling mode with if req_state.prev_num_draft_len and self.use_async_scheduling:
# spec decode. it indicates if need to update num_computed_tokens # prev_num_draft_len is used in async scheduling mode with
# of the request. for example: # spec decode. it indicates if need to update num_computed_tokens
# fist step: num_computed_tokens = 0, spec_tokens = [], # of the request. for example:
# prev_num_draft_len = 0. # fist step: num_computed_tokens = 0, spec_tokens = [],
# second step: num_computed_tokens = 100(prompt lenth), # prev_num_draft_len = 0.
# spec_tokens = [a,b], prev_num_draft_len = 0. # second step: num_computed_tokens = 100(prompt lenth),
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], # spec_tokens = [a,b], prev_num_draft_len = 0.
# prev_num_draft_len = 2. # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# num_computed_tokens in first step and second step does't contain # prev_num_draft_len = 2.
# the spec tokens length, but in third step it contains the # num_computed_tokens in first step and second step does't contain
# spec tokens length. we only need to update num_computed_tokens # the spec tokens length, but in third step it contains the
# when prev_num_draft_len > 0. # spec tokens length. we only need to update num_computed_tokens
if req_state.prev_num_draft_len: # when prev_num_draft_len > 0.
if req_index is None: if req_index is None:
req_state.prev_num_draft_len = 0 req_state.prev_num_draft_len = 0
else: else:
...@@ -1035,34 +1036,13 @@ class GPUModelRunner( ...@@ -1035,34 +1036,13 @@ class GPUModelRunner(
self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
req_id, []
)
num_spec_tokens = len(spec_token_ids)
# For async scheduling, token_ids_cpu assigned from
# spec_token_ids are placeholders and will be overwritten in
# _prepare_input_ids.
if num_spec_tokens:
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index
] = spec_token_ids
# When speculative decoding is used with structured output,
# the scheduler can drop draft tokens that do not
# conform to the schema. This can result in
# scheduler_output.scheduled_spec_decode_tokens being empty,
# even when speculative decoding is enabled.
self.input_batch.spec_token_ids[req_index].clear()
self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
if self.use_async_scheduling:
req_state.prev_num_draft_len = num_spec_tokens
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
for request in reqs_to_add: for request in reqs_to_add:
self.input_batch.add_request(request) self.input_batch.add_request(request)
self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens)
# Condense the batched states if there are gaps left by removed requests # Condense the batched states if there are gaps left by removed requests
self.input_batch.condense() self.input_batch.condense()
...@@ -1519,7 +1499,6 @@ class GPUModelRunner( ...@@ -1519,7 +1499,6 @@ class GPUModelRunner(
# We will ignore the sampled tokens from the partial requests. # We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs. # TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
num_draft_tokens = None
spec_decode_metadata = None spec_decode_metadata = None
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
else: else:
...@@ -1536,14 +1515,11 @@ class GPUModelRunner( ...@@ -1536,14 +1515,11 @@ class GPUModelRunner(
) in scheduler_output.scheduled_spec_decode_tokens.items(): ) in scheduler_output.scheduled_spec_decode_tokens.items():
req_idx = self.input_batch.req_id_to_index[req_id] req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids) num_draft_tokens[req_idx] = len(draft_token_ids)
num_decode_draft_tokens[req_idx] = ( if (
len(draft_token_ids) self.input_batch.num_computed_tokens_cpu[req_idx]
if ( >= self.input_batch.num_prompt_tokens[req_idx]
self.input_batch.num_computed_tokens_cpu[req_idx] ):
>= self.input_batch.num_prompt_tokens[req_idx] num_decode_draft_tokens[req_idx] = len(draft_token_ids)
)
else -1
)
spec_decode_metadata = self._calc_spec_decode_metadata( spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens num_draft_tokens, cu_num_tokens
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment