Unverified Commit 5061b8fd authored by ybyang's avatar ybyang Committed by GitHub
Browse files

fix stop when stream (#11462)


Signed-off-by: default avatarybyang <ybyang7@iflytek.com>
Co-authored-by: default avatarLiangsheng Yin <lsyincs@gmail.com>
Co-authored-by: default avatarLiangsheng Yin <hnyls2002@gmail.com>
parent c8452551
...@@ -734,6 +734,40 @@ class Req: ...@@ -734,6 +734,40 @@ class Req:
return self.surr_and_decode_ids, self.read_offset - self.surr_offset return self.surr_and_decode_ids, self.read_offset - self.surr_offset
def tail_str(self) -> str:
tail_len = self.sampling_params.stop_str_max_len + 1
tail_len = min(tail_len, len(self.output_ids))
return self.tokenizer.decode(self.output_ids[-tail_len:])
def check_match_stop_str_prefix(self) -> bool:
"""
Check if the suffix of tail_str overlaps with any stop_str prefix
"""
if not self.sampling_params.stop_strs:
return False
tail_str = self.tail_str()
# Early return if tail_str is empty
if not tail_str:
return False
for stop_str in self.sampling_params.stop_strs:
if not stop_str:
continue
# Check if stop_str is contained in tail_str (fastest check first)
if stop_str in tail_str:
return True
# Check if tail_str suffix matches stop_str prefix
# Only check if stop_str is not empty, it's for stream output
min_len = min(len(tail_str), len(stop_str))
for i in range(1, min_len + 1):
if tail_str[-i:] == stop_str[:i]:
return True
return False
def check_finished(self): def check_finished(self):
if self.finished(): if self.finished():
return return
...@@ -785,9 +819,7 @@ class Req: ...@@ -785,9 +819,7 @@ class Req:
# Check stop strings # Check stop strings
if len(self.sampling_params.stop_strs) > 0: if len(self.sampling_params.stop_strs) > 0:
tail_str = self.tokenizer.decode( tail_str = self.tail_str()
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
)
for stop_str in self.sampling_params.stop_strs: for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text: if stop_str in tail_str or stop_str in self.decoded_text:
......
...@@ -680,12 +680,18 @@ class SchedulerOutputProcessorMixin: ...@@ -680,12 +680,18 @@ class SchedulerOutputProcessorMixin:
stream_interval = ( stream_interval = (
req.sampling_params.stream_interval or self.stream_interval req.sampling_params.stream_interval or self.stream_interval
) )
# origin stream_interval logic
should_output = ( should_output = (
len(req.output_ids) % stream_interval == 1 len(req.output_ids) % stream_interval == 1
if not self.model_config.is_multimodal_gen if not self.model_config.is_multimodal_gen
and stream_interval > 1 and stream_interval > 1
else len(req.output_ids) % stream_interval == 0 else len(req.output_ids) % stream_interval == 0
) )
if should_output:
# check_match_stop_str_prefix if tail_str's suffix match stop_str prefix
should_output &= not req.check_match_stop_str_prefix()
else: else:
should_output = ( should_output = (
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0 len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment