Unverified Commit c06170cc authored by Yunfeng Bai, committed by GitHub

Add a flag to include stop string in output text (#1976)

parent 614856da
vllm/engine/llm_engine.py
@@ -682,9 +682,10 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                # Truncate the output text so that the stop string is
-                # not included in the output.
-                seq.output_text = seq.output_text[:-len(stop_str)]
+                if not sampling_params.include_stop_str_in_output:
+                    # Truncate the output text so that the stop string is
+                    # not included in the output.
+                    seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
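For illustration, here is a minimal standalone sketch of the truncation behavior the hunk above changes. The helper name apply_stop_string is hypothetical (the real logic lives inside LLMEngine's stop check); only the string handling mirrors the diff:

def apply_stop_string(output_text: str, stop_str: str,
                      include_stop_str_in_output: bool) -> str:
    # Hypothetical helper mirroring the diff: when the flag is off, a
    # matched stop string is sliced off the end of the output text.
    if output_text.endswith(stop_str) and not include_stop_str_in_output:
        return output_text[:-len(stop_str)]
    return output_text

assert apply_stop_string("Hello.###", "###", False) == "Hello."
assert apply_stop_string("Hello.###", "###", True) == "Hello.###"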
vllm/sampling_params.py
@@ -2,6 +2,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union

 import torch

 _SAMPLING_EPS = 1e-5
@@ -70,6 +71,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in output
+            text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -103,6 +106,7 @@
         early_stopping: Union[bool, str] = False,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -140,6 +144,7 @@
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
+        self.include_stop_str_in_output = include_stop_str_in_output
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -227,24 +232,26 @@
         return SamplingType.RANDOM

     def __repr__(self) -> str:
-        return (f"SamplingParams(n={self.n}, "
-                f"best_of={self.best_of}, "
-                f"presence_penalty={self.presence_penalty}, "
-                f"frequency_penalty={self.frequency_penalty}, "
-                f"repetition_penalty={self.repetition_penalty}, "
-                f"temperature={self.temperature}, "
-                f"top_p={self.top_p}, "
-                f"top_k={self.top_k}, "
-                f"min_p={self.min_p}, "
-                f"use_beam_search={self.use_beam_search}, "
-                f"length_penalty={self.length_penalty}, "
-                f"early_stopping={self.early_stopping}, "
-                f"stop={self.stop}, "
-                f"stop_token_ids={self.stop_token_ids}, "
-                f"ignore_eos={self.ignore_eos}, "
-                f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}, "
-                f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens}, "
-                "spaces_between_special_tokens="
-                f"{self.spaces_between_special_tokens})")
+        return (
+            f"SamplingParams(n={self.n}, "
+            f"best_of={self.best_of}, "
+            f"presence_penalty={self.presence_penalty}, "
+            f"frequency_penalty={self.frequency_penalty}, "
+            f"repetition_penalty={self.repetition_penalty}, "
+            f"temperature={self.temperature}, "
+            f"top_p={self.top_p}, "
+            f"top_k={self.top_k}, "
+            f"min_p={self.min_p}, "
+            f"use_beam_search={self.use_beam_search}, "
+            f"length_penalty={self.length_penalty}, "
+            f"early_stopping={self.early_stopping}, "
+            f"stop={self.stop}, "
+            f"stop_token_ids={self.stop_token_ids}, "
+            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+            f"ignore_eos={self.ignore_eos}, "
+            f"max_tokens={self.max_tokens}, "
+            f"logprobs={self.logprobs}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
+            f"skip_special_tokens={self.skip_special_tokens}, "
+            "spaces_between_special_tokens="
+            f"{self.spaces_between_special_tokens})")