Unverified Commit 18b296fd authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core] remove beam search from the core (#9105)

parent c8f26bb6
...@@ -142,10 +142,6 @@ class RequestOutput: ...@@ -142,10 +142,6 @@ class RequestOutput:
else: else:
# Get the top-n sequences. # Get the top-n sequences.
n = sampling_params.n n = sampling_params.n
if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob() sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n] top_n_seqs = sorted_seqs[:n]
......
...@@ -10,7 +10,6 @@ import torch ...@@ -10,7 +10,6 @@ import torch
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Annotated from typing_extensions import Annotated
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -23,7 +22,6 @@ class SamplingType(IntEnum): ...@@ -23,7 +22,6 @@ class SamplingType(IntEnum):
GREEDY = 0 GREEDY = 0
RANDOM = 1 RANDOM = 1
RANDOM_SEED = 2 RANDOM_SEED = 2
BEAM = 3
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
...@@ -134,16 +132,6 @@ class SamplingParams( ...@@ -134,16 +132,6 @@ class SamplingParams(
considered, relative to the probability of the most likely token. considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this. Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation. seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated. stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings. The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are stop_token_ids: List of tokens that stop the generation when they are
...@@ -193,9 +181,6 @@ class SamplingParams( ...@@ -193,9 +181,6 @@ class SamplingParams(
top_k: int = -1 top_k: int = -1
min_p: float = 0.0 min_p: float = 0.0
seed: Optional[int] = None seed: Optional[int] = None
use_beam_search: bool = False
length_penalty: float = 1.0
early_stopping: Union[bool, str] = False
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None stop_token_ids: Optional[List[int]] = None
ignore_eos: bool = False ignore_eos: bool = False
...@@ -238,9 +223,6 @@ class SamplingParams( ...@@ -238,9 +223,6 @@ class SamplingParams(
top_k: int = -1, top_k: int = -1,
min_p: float = 0.0, min_p: float = 0.0,
seed: Optional[int] = None, seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False, include_stop_str_in_output: bool = False,
...@@ -280,9 +262,6 @@ class SamplingParams( ...@@ -280,9 +262,6 @@ class SamplingParams(
top_k=top_k, top_k=top_k,
min_p=min_p, min_p=min_p,
seed=seed, seed=seed,
use_beam_search=use_beam_search,
length_penalty=length_penalty,
early_stopping=early_stopping,
stop=stop, stop=stop,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output, include_stop_str_in_output=include_stop_str_in_output,
...@@ -334,14 +313,7 @@ class SamplingParams( ...@@ -334,14 +313,7 @@ class SamplingParams(
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
self._verify_args() self._verify_args()
if self.use_beam_search:
if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH:
raise ValueError(
"Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa
)
self._verify_beam_search()
else:
self._verify_non_beam_search()
if self.temperature < _SAMPLING_EPS: if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling. # Zero temperature means greedy sampling.
self.top_p = 1.0 self.top_p = 1.0
...@@ -417,31 +389,6 @@ class SamplingParams( ...@@ -417,31 +389,6 @@ class SamplingParams(
RequestOutputKind.DELTA): RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA") raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None:
if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam "
f"search. Got {self.best_of}.")
if self.temperature > _SAMPLING_EPS:
raise ValueError("temperature must be 0 when using beam search.")
if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None: def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int) assert isinstance(self.best_of, int)
if self.best_of > 1: if self.best_of > 1:
...@@ -476,8 +423,6 @@ class SamplingParams( ...@@ -476,8 +423,6 @@ class SamplingParams(
@cached_property @cached_property
def sampling_type(self) -> SamplingType: def sampling_type(self) -> SamplingType:
if self.use_beam_search:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS: if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY return SamplingType.GREEDY
if self.seed is not None: if self.seed is not None:
...@@ -514,9 +459,6 @@ class SamplingParams( ...@@ -514,9 +459,6 @@ class SamplingParams(
f"top_k={self.top_k}, " f"top_k={self.top_k}, "
f"min_p={self.min_p}, " f"min_p={self.min_p}, "
f"seed={self.seed}, " f"seed={self.seed}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, " f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, " f"stop_token_ids={self.stop_token_ids}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, "
...@@ -542,3 +484,4 @@ class BeamSearchParams( ...@@ -542,3 +484,4 @@ class BeamSearchParams(
max_tokens: int max_tokens: int
ignore_eos: bool = False ignore_eos: bool = False
temperature: float = 0.0 temperature: float = 0.0
length_penalty: float = 1.0
...@@ -577,25 +577,6 @@ class Sequence: ...@@ -577,25 +577,6 @@ class Sequence:
def get_cumulative_logprob(self) -> float: def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 1.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# NOTE: HF implementation does not count the EOS token
# towards the length, we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
def is_finished(self) -> bool: def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status) return SequenceStatus.is_finished(self.status)
...@@ -809,13 +790,6 @@ class SequenceGroup: ...@@ -809,13 +790,6 @@ class SequenceGroup:
def get_max_num_running_seqs(self) -> int: def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining """The maximum number of sequences running in parallel in the remaining
lifetime of the request.""" lifetime of the request."""
if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
return best_of
else:
if self.sampling_params: if self.sampling_params:
best_of = self.sampling_params.best_of best_of = self.sampling_params.best_of
assert isinstance(best_of, int) assert isinstance(best_of, int)
......
...@@ -1361,3 +1361,22 @@ class AtomicCounter: ...@@ -1361,3 +1361,22 @@ class AtomicCounter:
@property @property
def value(self): def value(self):
return self._value return self._value
def get_beam_search_score(
tokens: List[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
seq_len = len(tokens)
if tokens[-1] == eos_token_id:
seq_len -= 1
return cumulative_logprob / (seq_len**length_penalty)
...@@ -453,9 +453,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -453,9 +453,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU " f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.") "backend.")
best_of.append(sampling_params.best_of) best_of.append(sampling_params.best_of)
if sampling_params.use_beam_search:
raise NotImplementedError(
"Beam search is not supported by the TPU backend.")
if sampling_params.logprobs is not None: if sampling_params.logprobs is not None:
raise NotImplementedError( raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.") "logprobs is not currently supported by the TPU backend.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment