Unverified Commit 5061b8fd authored by ybyang's avatar ybyang Committed by GitHub
Browse files

fix stop when stream (#11462)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
parent c8452551
......@@ -734,6 +734,40 @@ class Req:
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
def tail_str(self) -> str:
    """Decode and return the most recent slice of generated tokens.

    The slice covers the last ``stop_str_max_len + 1`` entries of
    ``self.output_ids``, or every entry when fewer have been produced.

    Returns:
        Whatever ``self.tokenizer.decode`` yields for that slice.
    """
    window = min(
        self.sampling_params.stop_str_max_len + 1,
        len(self.output_ids),
    )
    # With no output yet, ``window`` is 0 and ``[-0:]`` selects the whole
    # (empty) sequence, so the tokenizer simply receives an empty slice.
    recent_ids = self.output_ids[-window:]
    return self.tokenizer.decode(recent_ids)
def check_match_stop_str_prefix(self) -> bool:
    """Report whether the decoded tail touches any configured stop string.

    A stop string "touches" the tail when it either occurs inside the tail
    in full, or some non-empty prefix of it is a suffix of the tail (i.e.
    the tail may be the beginning of a stop string that is still being
    generated). Empty stop strings are ignored.

    Returns:
        True if any stop string in ``self.sampling_params.stop_strs``
        matches in one of the two ways above; False otherwise, including
        when no stop strings are configured or the tail is empty.
    """
    stop_strs = self.sampling_params.stop_strs
    if not stop_strs:
        return False

    tail = self.tail_str()
    if not tail:
        # Nothing decoded yet, so nothing can overlap.
        return False

    for candidate in stop_strs:
        if not candidate:
            continue

        # Cheap containment test before the prefix/suffix scan.
        if candidate in tail:
            return True

        # A partial match can be no longer than the shorter of the two
        # strings; try every prefix length up to that bound.
        longest = min(len(tail), len(candidate))
        if any(tail.endswith(candidate[:k]) for k in range(1, longest + 1)):
            return True

    return False
def check_finished(self):
if self.finished():
return
......@@ -785,9 +819,7 @@ class Req:
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
tail_str = self.tokenizer.decode(
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
)
tail_str = self.tail_str()
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
......
......@@ -680,12 +680,18 @@ class SchedulerOutputProcessorMixin:
stream_interval = (
req.sampling_params.stream_interval or self.stream_interval
)
# original stream_interval logic
should_output = (
len(req.output_ids) % stream_interval == 1
if not self.model_config.is_multimodal_gen
and stream_interval > 1
else len(req.output_ids) % stream_interval == 0
)
if should_output:
# Suppress output when check_match_stop_str_prefix() reports that tail_str's suffix matches a stop_str prefix
should_output &= not req.check_match_stop_str_prefix()
else:
should_output = (
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment