Unverified Commit 18b296fd authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core] remove beam search from the core (#9105)

parent c8f26bb6
......@@ -142,11 +142,7 @@ class RequestOutput:
else:
# Get the top-n sequences.
n = sampling_params.n
if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
......
......@@ -10,7 +10,6 @@ import torch
from pydantic import BaseModel
from typing_extensions import Annotated
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
......@@ -23,7 +22,6 @@ class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
RANDOM_SEED = 2
BEAM = 3
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
......@@ -134,16 +132,6 @@ class SamplingParams(
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are
......@@ -193,9 +181,6 @@ class SamplingParams(
top_k: int = -1
min_p: float = 0.0
seed: Optional[int] = None
use_beam_search: bool = False
length_penalty: float = 1.0
early_stopping: Union[bool, str] = False
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
ignore_eos: bool = False
......@@ -238,9 +223,6 @@ class SamplingParams(
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
......@@ -280,9 +262,6 @@ class SamplingParams(
top_k=top_k,
min_p=min_p,
seed=seed,
use_beam_search=use_beam_search,
length_penalty=length_penalty,
early_stopping=early_stopping,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
......@@ -334,20 +313,13 @@ class SamplingParams(
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
self._verify_args()
if self.use_beam_search:
if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH:
raise ValueError(
"Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa
)
self._verify_beam_search()
else:
self._verify_non_beam_search()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self.top_p = 1.0
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self.top_p = 1.0
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# eos_token_id is added to this by the engine
self._all_stop_token_ids = set(self.stop_token_ids)
......@@ -417,31 +389,6 @@ class SamplingParams(
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None:
if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam "
f"search. Got {self.best_of}.")
if self.temperature > _SAMPLING_EPS:
raise ValueError("temperature must be 0 when using beam search.")
if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int)
if self.best_of > 1:
......@@ -476,8 +423,6 @@ class SamplingParams(
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
if self.seed is not None:
......@@ -514,9 +459,6 @@ class SamplingParams(
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"seed={self.seed}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
......@@ -542,3 +484,4 @@ class BeamSearchParams(
max_tokens: int
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0
......@@ -577,25 +577,6 @@ class Sequence:
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 1.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# NOTE: HF implementation does not count the EOS token
# towards the length, we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
......@@ -809,25 +790,18 @@ class SequenceGroup:
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
return best_of
else:
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
if best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
if best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
def get_seqs(
self,
......
......@@ -1361,3 +1361,22 @@ class AtomicCounter:
@property
def value(self):
return self._value
def get_beam_search_score(
tokens: List[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
seq_len = len(tokens)
if tokens[-1] == eos_token_id:
seq_len -= 1
return cumulative_logprob / (seq_len**length_penalty)
......@@ -453,9 +453,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
best_of.append(sampling_params.best_of)
if sampling_params.use_beam_search:
raise NotImplementedError(
"Beam search is not supported by the TPU backend.")
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment