Unverified Commit d658f049 authored by Liangsheng Yin, committed by GitHub

[overlap-spec] fix stop condition and trimming (#11819)

parent 57e25de7
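With overlap scheduling and speculative decoding, a single step can accept several tokens at once, so the matched stop token may sit in the middle of `output_ids` rather than at the end. This commit threads a `finished_len` through the scheduler, the output processor, and the detokenizer so everything past the stop position is trimmed consistently. A toy end-to-end sketch of that flow (plain Python; the names are illustrative, not sglang APIs):

```python
# 1) the scheduler records where the stop landed, 2) the output processor
# slices outputs through that position, 3) the detokenizer drops the
# matched stop token, which is now guaranteed to be last.
STOP_ID = 2

output_ids = [5, 7, 2, 9]                      # 9 was speculatively accepted after the stop
finished_len = output_ids.index(STOP_ID) + 1   # scheduler: stop position + 1

through_stop = output_ids[:finished_len]       # output processor: trim
assert through_stop == [5, 7, 2]

visible = through_stop[:-1]                    # detokenizer: drop the stop token
assert visible == [5, 7]
```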
@@ -142,6 +142,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
                 return output
             assert len(output) > 0
+            # NOTE: We can always assume the last token is the matched stop token
             return output[:-1]
         return output
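The added NOTE holds because, after this commit, the scheduler trims `output_ids` at `finished_len`, so whenever a stop token matched it is always the final element the detokenizer sees. A minimal sketch of the trimming rule, including the gpt-oss tool-call exception from the hunk above (a hypothetical helper, not the actual DetokenizerManager method signature):

```python
GPT_OSS_CALL_TOKEN = 200012  # kept verbatim so the tool-call parser can see it

def trim_stop_token(output: list, is_tool_call_parser_gpt_oss: bool) -> list:
    """Drop the matched stop token, assumed (per the NOTE) to be last."""
    if output[-1] == GPT_OSS_CALL_TOKEN and is_tool_call_parser_gpt_oss:
        return output
    assert len(output) > 0
    return output[:-1]

assert trim_stop_token([5, 7, 2], is_tool_call_parser_gpt_oss=False) == [5, 7]
assert trim_stop_token([5, 200012], is_tool_call_parser_gpt_oss=True) == [5, 200012]
```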
@@ -486,6 +486,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # finished position (in output_ids), used when checking stop conditions with speculative decoding
+        self.finished_len = None
         # Whether this request has finished output
         self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
@@ -651,6 +653,13 @@ class Req:
         spec_alg = get_global_server_args().speculative_algorithm
         return self.sampling_params.max_new_tokens == 0 and spec_alg is None
 
+    @property
+    def output_ids_through_stop(self) -> List[int]:
+        """Get the output ids through the stop condition. Stop position is included."""
+        if self.finished_len is not None:
+            return self.output_ids[: self.finished_len]
+        return self.output_ids
+
     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
             return
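The property is a pure slice with the stop position included; before any stop has matched (`finished_len is None`) it falls back to the full list. Both branches, using stand-in values instead of a real Req:

```python
def output_ids_through_stop(output_ids, finished_len):
    # Mirrors the property: slice through the stop position, inclusive.
    if finished_len is not None:
        return output_ids[:finished_len]
    return output_ids

ids = [11, 13, 2, 17]            # 2 is the stop token; 17 was over-accepted
assert output_ids_through_stop(ids, finished_len=3) == [11, 13, 2]
assert output_ids_through_stop(ids, finished_len=None) == ids  # still running
```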
@@ -702,18 +711,20 @@ class Req:
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
+        output_ids = self.output_ids_through_stop
         if first_iter:
             self.read_offset = len(self.origin_input_ids_unpadded)
             self.surr_offset = max(
                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
             )
             self.surr_and_decode_ids = (
-                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+                self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
             )
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.cur_decode_ids_len = len(output_ids)
         else:
-            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
+            self.cur_decode_ids_len = len(output_ids)
 
         return self.surr_and_decode_ids, self.read_offset - self.surr_offset
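Routing `init_incremental_detokenize` through the trimmed list means the surrogate-and-decode buffer can never pick up tokens past the stop position, even though the raw `output_ids` may still contain them. A toy model of the `cur_decode_ids_len` bookkeeping (invented class, same extend-by-suffix idea):

```python
class IncrementalBuffer:
    """Toy model of the surr_and_decode_ids bookkeeping (names invented)."""

    def __init__(self, prompt_ids):
        self.ids = list(prompt_ids)
        self.seen = 0  # mirrors cur_decode_ids_len

    def extend(self, output_ids_through_stop):
        # Only the not-yet-buffered suffix of the (possibly trimmed) outputs
        # is appended, so tokens past the stop position never enter the buffer.
        self.ids.extend(output_ids_through_stop[self.seen:])
        self.seen = len(output_ids_through_stop)

buf = IncrementalBuffer([1, 2])
buf.extend([10, 11])          # step 1: two tokens accepted
buf.extend([10, 11, 12])      # step 2: stop found, outputs trimmed to 3 tokens
assert buf.ids == [1, 2, 10, 11, 12] and buf.seen == 3
```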
@@ -760,55 +771,31 @@ class Req:
         return False
 
-    def check_finished(self):
-        if self.finished():
-            return
-
-        if self.to_abort:
-            self.finished_reason = FINISH_ABORT(
-                message=self.to_abort_message,
-            )
-            return
-
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished_reason = FINISH_LENGTH(
-                length=self.sampling_params.max_new_tokens
-            )
-            return
-
-        if self.grammar is not None:
-            if self.grammar.is_terminated():
-                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
-                return
-
-        last_token_id = self.output_ids[-1]
+    def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
+        if self.sampling_params.ignore_eos:
+            return False
 
-        if not self.sampling_params.ignore_eos:
-            # Check stop token ids
+        # Check stop token ids
+        for i, token_id in enumerate(new_accepted_tokens):
             matched_eos = False
             if self.sampling_params.stop_token_ids:
-                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+                matched_eos |= token_id in self.sampling_params.stop_token_ids
             if self.eos_token_ids:
-                matched_eos |= last_token_id in self.eos_token_ids
+                matched_eos |= token_id in self.eos_token_ids
             if self.tokenizer is not None:
-                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                matched_eos |= token_id == self.tokenizer.eos_token_id
                 if self.tokenizer.additional_stop_token_ids:
-                    matched_eos |= (
-                        last_token_id in self.tokenizer.additional_stop_token_ids
-                    )
+                    matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
            if matched_eos:
-                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-                return
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
+                matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
+                self.finished_len = matched_pos + 1
+                return True
 
-        if last_token_id > self.vocab_size or last_token_id < 0:
-            if self.sampling_params.stop_token_ids:
-                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
-            if self.eos_token_ids:
-                self.output_ids[-1] = next(iter(self.eos_token_ids))
-            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
-            return
+        return False
 
+    def _check_str_based_finish(self):
         if (
             len(self.sampling_params.stop_strs) > 0
             or len(self.sampling_params.stop_regex_strs) > 0
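The one subtle line in the token-based check is the index arithmetic: `i` indexes the newly accepted window, so `len(self.output_ids) - len(new_accepted_tokens) + i` converts it to an absolute position in `output_ids`, and `finished_len = matched_pos + 1` keeps the stop token itself. A self-contained re-implementation to sanity-check the mapping (an illustrative function, not the Req method):

```python
def token_based_finish(output_ids, new_accepted_len, stop_ids):
    """Return (matched_token, finished_len) or None, scanning only the
    tokens accepted in the current step."""
    new_tokens = output_ids[-new_accepted_len:]
    for i, token_id in enumerate(new_tokens):
        if token_id in stop_ids:
            matched_pos = len(output_ids) - len(new_tokens) + i
            return token_id, matched_pos + 1
    return None

# Last three tokens accepted in one speculative step; stop id 2 at index 3.
assert token_based_finish([5, 7, 9, 2, 11], 3, {2}) == (2, 4)
# No stop token in the accepted window: the request keeps running.
assert token_based_finish([5, 7, 9, 3, 11], 3, {2}) is None
```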
@@ -820,7 +807,7 @@ class Req:
             for stop_str in self.sampling_params.stop_strs:
                 if stop_str in tail_str or stop_str in self.decoded_text:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
-                    return
+                    return True
 
             # Check stop regex
             if len(self.sampling_params.stop_regex_strs) > 0:
@@ -829,6 +816,57 @@ class Req:
                     self.finished_reason = FINISHED_MATCHED_REGEX(
                         matched=stop_regex_str
                     )
-                    return
+                    return True
+
+        return False
+
+    def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
+        for i, token_id in enumerate(new_accepted_tokens):
+            if token_id > self.vocab_size or token_id < 0:
+                offset = len(self.output_ids) - len(new_accepted_tokens) + i
+                if self.sampling_params.stop_token_ids:
+                    self.output_ids[offset] = next(
+                        iter(self.sampling_params.stop_token_ids)
+                    )
+                if self.eos_token_ids:
+                    self.output_ids[offset] = next(iter(self.eos_token_ids))
+                self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+                self.finished_len = offset + 1
+                return True
+        return False
+
+    def check_finished(self, new_accepted_len: int = 1):
+        if self.finished():
+            return
+
+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT(
+                message=self.to_abort_message,
+            )
+            return
+
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(
+                length=self.sampling_params.max_new_tokens
+            )
+            self.finished_len = self.sampling_params.max_new_tokens
+            return
+
+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
+        new_accepted_tokens = self.output_ids[-new_accepted_len:]
+
+        if self._check_token_based_finish(new_accepted_tokens):
+            return
+
+        if self._check_vocab_boundary_finish(new_accepted_tokens):
+            return
+
+        if self._check_str_based_finish():
+            return
+
     def reset_for_retract(self):
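The rebuilt `check_finished` runs its checks in a fixed order (abort, length, grammar, then the token, vocab-boundary, and string stops) and scans the whole accepted window instead of just the last token. Note that the length finish now also pins `finished_len` to `max_new_tokens`, so speculative tokens accepted beyond the budget are trimmed as well. A small check of that length-trimming behaviour (stand-in values only):

```python
max_new_tokens = 4
output_ids = [5, 7, 9, 2, 11, 13]  # six tokens accepted, budget is four

if len(output_ids) >= max_new_tokens:
    finished_len = max_new_tokens  # mirrors the FINISH_LENGTH branch

assert output_ids[:finished_len] == [5, 7, 9, 2]  # budget respected
```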
@@ -286,13 +286,16 @@ class SchedulerOutputProcessorMixin:
                         self.token_to_kv_pool_allocator.free(indices_to_free)
                     continue
 
+                new_accepted_len = 1
                 if batch.spec_algorithm.is_none():
                     req.output_ids.append(next_token_id)
                 elif batch.is_v2_eagle:
                     # Only v2 eagle's output_ids are updated here.
                     req.output_ids.extend(next_token_id)
+                    new_accepted_len = len(next_token_id)
 
-                req.check_finished()
+                req.check_finished(new_accepted_len)
                 if req.finished():
                     if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
                         # FIXME(lsyin): fix the messy logic here
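On the scheduler side the only plumbing change is `new_accepted_len`: 1 for normal decoding, where one token is appended per step, and `len(next_token_id)` for v2 EAGLE, where `next_token_id` is a list of accepted draft tokens. A condensed sketch of that dispatch (standalone, with a stub request object instead of the real Req):

```python
class StubReq:
    def __init__(self):
        self.output_ids = []
        self.checked_window = None

    def check_finished(self, new_accepted_len: int = 1):
        # Record how many trailing tokens the stop checks would scan.
        self.checked_window = self.output_ids[-new_accepted_len:]

req = StubReq()
next_token_ids = [9, 2, 11]        # e.g. three tokens accepted by v2 EAGLE
req.output_ids.extend(next_token_ids)
req.check_finished(len(next_token_ids))
assert req.checked_window == [9, 2, 11]  # the whole accepted span is checked
```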
@@ -734,6 +737,8 @@ class SchedulerOutputProcessorMixin:
                     # because of the one additional delayed token. This "continue" prevented the dummy output.
                     continue
                 req.finished_output = True
+                if req.finished_len is None:
+                    req.finished_len = len(req.output_ids)
                 should_output = True
             else:
                 if req.stream:
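Not every finish reason sets `finished_len` (abort and grammar termination, for example, do not), so when a request is marked as having finished output the code backfills the default: the full output length, which makes `output_ids_through_stop` a no-op slice. The fallback in isolation:

```python
output_ids = [5, 7, 9]
finished_len = None  # e.g. finished via abort, where no stop position exists

if finished_len is None:
    finished_len = len(output_ids)  # mirrors the backfill in the hunk above

assert output_ids[:finished_len] == output_ids  # trimming becomes a no-op
```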
@@ -776,17 +781,20 @@ class SchedulerOutputProcessorMixin:
                 else:
                     decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
 
+                # Exclude the tokens after stop condition
+                output_ids_ = req.output_ids_through_stop
+
                 req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
-                output_ids.append(req.output_ids[send_token_offset:])
-                req.send_token_offset = len(req.output_ids)
+                output_ids.append(output_ids_[send_token_offset:])
+                req.send_token_offset = len(output_ids_)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
                 )
                 no_stop_trim.append(req.sampling_params.no_stop_trim)
                 prompt_tokens.append(len(req.origin_input_ids))
-                completion_tokens.append(len(req.output_ids))
+                completion_tokens.append(len(output_ids_))
                 cached_tokens.append(req.cached_tokens)
 
         if not self.spec_algorithm.is_none():
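In the streaming path, both the incremental send offset and the completion-token count now run off the trimmed list, so a client never receives tokens past the stop position even if they were speculatively accepted. A toy two-step stream showing the `send_token_offset` bookkeeping (invented names mirroring the hunk above):

```python
def stream_step(output_ids_through_stop, send_token_offset):
    """Return the newly sendable chunk and the advanced offset."""
    chunk = output_ids_through_stop[send_token_offset:]
    return chunk, len(output_ids_through_stop)

offset = 0
chunk1, offset = stream_step([5, 7], offset)           # step 1: still running
chunk2, offset = stream_step([5, 7, 9, 2], offset)     # step 2: stop at len 4
assert (chunk1, chunk2, offset) == ([5, 7], [9, 2], 4)
completion_tokens = offset  # counts only tokens through the stop position
```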