Unverified Commit d658f049 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[overlap-spec] fix stop condition and trimming (#11819)

parent 57e25de7
...@@ -142,6 +142,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): ...@@ -142,6 +142,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss: if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
return output return output
assert len(output) > 0 assert len(output) > 0
# NOTE: We can always assume the last token is the matched stop token
return output[:-1] return output[:-1]
return output return output
......
...@@ -486,6 +486,8 @@ class Req: ...@@ -486,6 +486,8 @@ class Req:
# Check finish # Check finish
self.tokenizer = None self.tokenizer = None
self.finished_reason = None self.finished_reason = None
# finished position (in output_ids), used when checking stop conditions with speculative decoding
self.finished_len = None
# Whether this request has finished output # Whether this request has finished output
self.finished_output = None self.finished_output = None
# If we want to abort the request in the middle of the event loop, set this to true # If we want to abort the request in the middle of the event loop, set this to true
...@@ -651,6 +653,13 @@ class Req: ...@@ -651,6 +653,13 @@ class Req:
spec_alg = get_global_server_args().speculative_algorithm spec_alg = get_global_server_args().speculative_algorithm
return self.sampling_params.max_new_tokens == 0 and spec_alg is None return self.sampling_params.max_new_tokens == 0 and spec_alg is None
@property
def output_ids_through_stop(self) -> List[int]:
"""Get the output ids through the stop condition. Stop position is included."""
if self.finished_len is not None:
return self.output_ids[: self.finished_len]
return self.output_ids
def add_latency(self, stage: RequestStage): def add_latency(self, stage: RequestStage):
if self.metrics_collector is None: if self.metrics_collector is None:
return return
...@@ -702,18 +711,20 @@ class Req: ...@@ -702,18 +711,20 @@ class Req:
def init_incremental_detokenize(self): def init_incremental_detokenize(self):
first_iter = self.surr_offset is None or self.read_offset is None first_iter = self.surr_offset is None or self.read_offset is None
output_ids = self.output_ids_through_stop
if first_iter: if first_iter:
self.read_offset = len(self.origin_input_ids_unpadded) self.read_offset = len(self.origin_input_ids_unpadded)
self.surr_offset = max( self.surr_offset = max(
self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
) )
self.surr_and_decode_ids = ( self.surr_and_decode_ids = (
self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
) )
self.cur_decode_ids_len = len(self.output_ids) self.cur_decode_ids_len = len(output_ids)
else: else:
self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :]) self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
self.cur_decode_ids_len = len(self.output_ids) self.cur_decode_ids_len = len(output_ids)
return self.surr_and_decode_ids, self.read_offset - self.surr_offset return self.surr_and_decode_ids, self.read_offset - self.surr_offset
...@@ -760,55 +771,31 @@ class Req: ...@@ -760,55 +771,31 @@ class Req:
return False return False
def check_finished(self): def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
if self.finished(): if self.sampling_params.ignore_eos:
return return False
if self.to_abort:
self.finished_reason = FINISH_ABORT(
message=self.to_abort_message,
)
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished_reason = FINISH_LENGTH(
length=self.sampling_params.max_new_tokens
)
return
if self.grammar is not None:
if self.grammar.is_terminated():
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
return
last_token_id = self.output_ids[-1]
if not self.sampling_params.ignore_eos: # Check stop token ids
matched_eos = False matched_eos = False
# Check stop token ids for i, token_id in enumerate(new_accepted_tokens):
if self.sampling_params.stop_token_ids: if self.sampling_params.stop_token_ids:
matched_eos = last_token_id in self.sampling_params.stop_token_ids matched_eos |= token_id in self.sampling_params.stop_token_ids
if self.eos_token_ids: if self.eos_token_ids:
matched_eos |= last_token_id in self.eos_token_ids matched_eos |= token_id in self.eos_token_ids
if self.tokenizer is not None: if self.tokenizer is not None:
matched_eos |= last_token_id == self.tokenizer.eos_token_id matched_eos |= token_id == self.tokenizer.eos_token_id
if self.tokenizer.additional_stop_token_ids: if self.tokenizer.additional_stop_token_ids:
matched_eos |= ( matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
last_token_id in self.tokenizer.additional_stop_token_ids
)
if matched_eos: if matched_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
return matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
self.finished_len = matched_pos + 1
return True
if last_token_id > self.vocab_size or last_token_id < 0: return False
if self.sampling_params.stop_token_ids:
self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
if self.eos_token_ids:
self.output_ids[-1] = next(iter(self.eos_token_ids))
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
return
def _check_str_based_finish(self):
if ( if (
len(self.sampling_params.stop_strs) > 0 len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0 or len(self.sampling_params.stop_regex_strs) > 0
...@@ -820,7 +807,7 @@ class Req: ...@@ -820,7 +807,7 @@ class Req:
for stop_str in self.sampling_params.stop_strs: for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text: if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return return True
# Check stop regex # Check stop regex
if len(self.sampling_params.stop_regex_strs) > 0: if len(self.sampling_params.stop_regex_strs) > 0:
...@@ -829,6 +816,57 @@ class Req: ...@@ -829,6 +816,57 @@ class Req:
self.finished_reason = FINISHED_MATCHED_REGEX( self.finished_reason = FINISHED_MATCHED_REGEX(
matched=stop_regex_str matched=stop_regex_str
) )
return True
return False
def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
for i, token_id in enumerate(new_accepted_tokens):
if token_id > self.vocab_size or token_id < 0:
offset = len(self.output_ids) - len(new_accepted_tokens) + i
if self.sampling_params.stop_token_ids:
self.output_ids[offset] = next(
iter(self.sampling_params.stop_token_ids)
)
if self.eos_token_ids:
self.output_ids[offset] = next(iter(self.eos_token_ids))
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
self.finished_len = offset + 1
return True
return False
def check_finished(self, new_accepted_len: int = 1):
if self.finished():
return
if self.to_abort:
self.finished_reason = FINISH_ABORT(
message=self.to_abort_message,
)
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished_reason = FINISH_LENGTH(
length=self.sampling_params.max_new_tokens
)
self.finished_len = self.sampling_params.max_new_tokens
return
if self.grammar is not None:
if self.grammar.is_terminated():
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
return
new_accepted_tokens = self.output_ids[-new_accepted_len:]
if self._check_token_based_finish(new_accepted_tokens):
return
if self._check_vocab_boundary_finish(new_accepted_tokens):
return
if self._check_str_based_finish():
return return
def reset_for_retract(self): def reset_for_retract(self):
......
...@@ -286,13 +286,16 @@ class SchedulerOutputProcessorMixin: ...@@ -286,13 +286,16 @@ class SchedulerOutputProcessorMixin:
self.token_to_kv_pool_allocator.free(indices_to_free) self.token_to_kv_pool_allocator.free(indices_to_free)
continue continue
new_accepted_len = 1
if batch.spec_algorithm.is_none(): if batch.spec_algorithm.is_none():
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
elif batch.is_v2_eagle: elif batch.is_v2_eagle:
# Only v2 eagle's output_ids are updated here. # Only v2 eagle's output_ids are updated here.
req.output_ids.extend(next_token_id) req.output_ids.extend(next_token_id)
new_accepted_len = len(next_token_id)
req.check_finished(new_accepted_len)
req.check_finished()
if req.finished(): if req.finished():
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend(): if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
# FIXME(lsyin): fix the messy logic here # FIXME(lsyin): fix the messy logic here
...@@ -734,6 +737,8 @@ class SchedulerOutputProcessorMixin: ...@@ -734,6 +737,8 @@ class SchedulerOutputProcessorMixin:
# because of the one additional delayed token. This "continue" prevented the dummy output. # because of the one additional delayed token. This "continue" prevented the dummy output.
continue continue
req.finished_output = True req.finished_output = True
if req.finished_len is None:
req.finished_len = len(req.output_ids)
should_output = True should_output = True
else: else:
if req.stream: if req.stream:
...@@ -776,17 +781,20 @@ class SchedulerOutputProcessorMixin: ...@@ -776,17 +781,20 @@ class SchedulerOutputProcessorMixin:
else: else:
decode_ids_list.append(decode_ids[req.send_decode_id_offset :]) decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
# Exclude the tokens after stop condition
output_ids_ = req.output_ids_through_stop
req.send_decode_id_offset = len(decode_ids) req.send_decode_id_offset = len(decode_ids)
read_offsets.append(read_offset) read_offsets.append(read_offset)
output_ids.append(req.output_ids[send_token_offset:]) output_ids.append(output_ids_[send_token_offset:])
req.send_token_offset = len(req.output_ids) req.send_token_offset = len(output_ids_)
skip_special_tokens.append(req.sampling_params.skip_special_tokens) skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append( spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens req.sampling_params.spaces_between_special_tokens
) )
no_stop_trim.append(req.sampling_params.no_stop_trim) no_stop_trim.append(req.sampling_params.no_stop_trim)
prompt_tokens.append(len(req.origin_input_ids)) prompt_tokens.append(len(req.origin_input_ids))
completion_tokens.append(len(req.output_ids)) completion_tokens.append(len(output_ids_))
cached_tokens.append(req.cached_tokens) cached_tokens.append(req.cached_tokens)
if not self.spec_algorithm.is_none(): if not self.spec_algorithm.is_none():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment