Unverified commit 002800f0 authored by Zhuohan Li, committed by GitHub

Align vLLM's beam search implementation with HF generate (#857)

parent e15932bb
......@@ -75,10 +75,12 @@ class RequestOutput:
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
assert n <= len(seqs)
-        sorted_seqs = sorted(seqs,
-                             key=lambda seq: seq.get_cumulative_logprob(),
-                             reverse=True)
+        if seq_group.sampling_params.use_beam_search:
+            sorting_key = lambda seq: seq.get_beam_search_score(
+                seq_group.sampling_params.length_penalty)
+        else:
+            sorting_key = lambda seq: seq.get_cumulative_logprob()
+        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
# Create the outputs.
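The hunk above picks the sorting key per request: the length-penalized beam score for beam search, and the raw cumulative log probability otherwise. A minimal standalone sketch of the same idea, using a hypothetical `Candidate` class in place of vLLM's `Sequence`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Candidate:
    # Hypothetical stand-in for vllm.sequence.Sequence, for illustration only.
    token_ids: List[int]
    cumulative_logprob: float


def beam_search_score(cand: Candidate, length_penalty: float) -> float:
    # HF-style length-normalized score, as computed by get_beam_search_score.
    return cand.cumulative_logprob / (len(cand.token_ids) ** length_penalty)


def top_n(cands: List[Candidate], n: int, use_beam_search: bool,
          length_penalty: float = 1.0) -> List[Candidate]:
    if use_beam_search:
        key = lambda c: beam_search_score(c, length_penalty)
    else:
        key = lambda c: c.cumulative_logprob
    return sorted(cands, key=key, reverse=True)[:n]


cands = [Candidate([1, 2, 3, 4], -4.0), Candidate([1, 2], -2.5)]
# Length-normalized: -1.0 beats -1.25, so the longer candidate is returned;
# sorting by raw cumulative logprob would instead prefer the shorter one.
print(top_n(cands, n=1, use_beam_search=True))
```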
......@@ -34,6 +34,15 @@ class SamplingParams:
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
ignore_eos: Whether to ignore the EOS token and continue generating
......@@ -52,6 +61,8 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
......@@ -65,6 +76,8 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.stop = []
elif isinstance(stop, str):
......@@ -78,9 +91,11 @@ class SamplingParams:
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
-        elif self.temperature < _SAMPLING_EPS:
-            # Zero temperature means greedy sampling.
-            self._verify_greedy_sampling()
+        else:
+            self._verify_non_beam_search()
+            if self.temperature < _SAMPLING_EPS:
+                # Zero temperature means greedy sampling.
+                self._verify_greedy_sampling()
def _verify_args(self) -> None:
if self.n < 1:
......@@ -119,6 +134,20 @@ class SamplingParams:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
if self.best_of > 1:
......@@ -138,6 +167,8 @@ class SamplingParams:
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
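Taken together, the new `length_penalty` and `early_stopping` parameters only have an effect under beam search, and `_verify_non_beam_search` rejects non-default values otherwise. A sketch of how a request might set them; the values, model name, and prompt are placeholders, and beam search additionally requires `best_of > 1`, zero temperature, `top_p=1`, and `top_k=-1`:

```python
from vllm import LLM, SamplingParams

# Beam search with HF-style length normalization; stop as soon as
# `best_of` complete candidates exist (early_stopping=True).
beam_params = SamplingParams(
    n=2,                 # return the top-2 beams
    best_of=4,           # keep 4 beam candidates alive
    use_beam_search=True,
    temperature=0.0,     # beam search requires zero temperature
    length_penalty=1.0,
    early_stopping=True,
    max_tokens=64,
)

# Regular sampling must leave the beam-only knobs at their defaults,
# otherwise _verify_non_beam_search() raises a ValueError.
sample_params = SamplingParams(n=1, temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["The capital of France is"], beam_params)
```

Setting `early_stopping="never"` would instead run the canonical algorithm until no better candidate can exist.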
......@@ -69,6 +69,9 @@ class SequenceData:
def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_prompt_len(self) -> int:
return len(self.prompt_token_ids)
def get_output_len(self) -> int:
return len(self.output_token_ids)
......@@ -155,6 +158,9 @@ class Sequence:
def get_len(self) -> int:
return self.data.get_len()
def get_prompt_len(self) -> int:
return self.data.get_prompt_len()
def get_output_len(self) -> int:
return self.data.get_output_len()
......@@ -170,14 +176,32 @@ class Sequence:
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 0.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# Note: the HF implementation does not count the EOS token
# towards the length; we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
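A rough worked example of the score above. The cumulative log probability is negative, so dividing by `seq_len ** length_penalty` means larger penalties favor longer beams; the numbers below are made up:

```python
# Hypothetical beam: 6 tokens in total, cumulative logprob -6.0, and the
# last token is EOS, so the effective length is 5.
cumulative_logprob = -6.0
seq_len = 6 - 1  # EOS excluded, matching the HF convention noted above

for length_penalty in (0.0, 1.0, 2.0):
    score = cumulative_logprob / (seq_len ** length_penalty)
    print(length_penalty, score)
# 0.0 -> -6.0   (no length normalization)
# 1.0 -> -1.2   (average logprob per token)
# 2.0 -> -0.24  (strongly favors longer beams)
```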
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
-    def fork(self, child_seq: "Sequence") -> None:
-        child_seq.logical_token_blocks = copy.deepcopy(
-            self.logical_token_blocks)
-        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
-        child_seq.data = copy.deepcopy(self.data)
+    def fork(self, new_seq_id: int) -> "Sequence":
+        new_seq = copy.deepcopy(self)
+        new_seq.seq_id = new_seq_id
+        return new_seq
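The new `fork` deep-copies the whole sequence and only swaps in a fresh ID, instead of copying selected fields into a pre-built child. A self-contained sketch of the copy semantics, using a hypothetical `ToySequence` rather than the real `Sequence`:

```python
import copy
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToySequence:
    # Hypothetical stand-in for vllm.sequence.Sequence.
    seq_id: int
    token_ids: List[int] = field(default_factory=list)

    def fork(self, new_seq_id: int) -> "ToySequence":
        # Deep-copy everything, then give the copy its own identity,
        # mirroring the new Sequence.fork above.
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq


parent = ToySequence(seq_id=0, token_ids=[1, 2, 3])
child = parent.fork(new_seq_id=7)
child.token_ids.append(4)             # extending the child ...
assert parent.token_ids == [1, 2, 3]  # ... leaves the parent untouched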
def __repr__(self) -> str:
return (f"Sequence(seq_id={self.seq_id}, "
......@@ -203,35 +227,66 @@ class SequenceGroup:
arrival_time: float,
) -> None:
self.request_id = request_id
-        self.seqs = seqs
+        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params
self.arrival_time = arrival_time
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params.use_beam_search:
# For beam search, at most `best_of` beam candidates will be
# running at any point in the future.
return self.sampling_params.best_of
else:
if self.sampling_params.best_of > self.num_seqs():
# At the prompt stage, the sequence group is not yet filled up
# and only has one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
return self.sampling_params.best_of
# At sampling stages, return the number of actual sequences
# running.
return self.num_seqs(status=SequenceStatus.RUNNING)
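This gives the scheduler a worst-case bound on parallel sequences: a beam-search group can always fan back out to `best_of` live beams, while a sampling group only needs `best_of` slots until its sequences start finishing for good. A standalone mirror of the branching (a hypothetical helper, not the real method, which reads these values from the group itself):

```python
def max_running_seqs(use_beam_search: bool, best_of: int,
                     num_total: int, num_running: int) -> int:
    # Simplified mirror of SequenceGroup.get_max_num_running_seqs.
    if use_beam_search:
        # A live beam can always fork back out to `best_of` candidates.
        return best_of
    if best_of > num_total:
        # Prompt stage: only the prompt sequence exists so far, but
        # `best_of` sequences will run during generation.
        return best_of
    # Generation stage: finished sequences never come back.
    return num_running


assert max_running_seqs(True, 4, num_total=4, num_running=1) == 4
assert max_running_seqs(False, 4, num_total=1, num_running=1) == 4
assert max_running_seqs(False, 4, num_total=4, num_running=2) == 2
```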
def get_seqs(
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
        if status is None:
-            return self.seqs
+            return list(self.seqs_dict.values())
        else:
-            return [seq for seq in self.seqs if seq.status == status]
+            return [
+                seq for seq in self.seqs_dict.values() if seq.status == status
+            ]
def get_finished_seqs(self) -> List[Sequence]:
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))
def find(self, seq_id: int) -> Sequence:
-        for seq in self.seqs:
-            if seq.seq_id == seq_id:
-                return seq
-        raise ValueError(f"Sequence {seq_id} not found.")
+        if seq_id not in self.seqs_dict:
+            raise ValueError(f"Sequence {seq_id} not found.")
+        return self.seqs_dict[seq_id]
def add(self, seq: Sequence) -> None:
if seq.seq_id in self.seqs_dict:
raise ValueError(f"Sequence {seq.seq_id} already exists.")
self.seqs_dict[seq.seq_id] = seq
def remove(self, seq_id: int) -> None:
if seq_id not in self.seqs_dict:
raise ValueError(f"Sequence {seq_id} not found.")
del self.seqs_dict[seq_id]
def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.get_seqs())
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "
f"sampling_params={self.sampling_params}, "
f"num_seqs={len(self.seqs)})")
f"num_seqs={len(self.seqs_dict)})")
class SequenceGroupMetadata:
......@@ -266,7 +321,6 @@ class SequenceOutputs:
"""The model output associated with a sequence.
Args:
-        seq_id: The ID of the sequence.
parent_seq_id: The ID of the parent sequence (for forking in beam
search).
output_token: The output token ID.
......@@ -276,26 +330,27 @@ class SequenceOutputs:
def __init__(
self,
-        seq_id: int,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float],
) -> None:
-        self.seq_id = seq_id
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
def __repr__(self) -> str:
return (f"SequenceOutputs(seq_id={self.seq_id}, "
f"parent_seq_id={self.parent_seq_id}, "
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}), "
f"logprobs={self.logprobs}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
-            return NotImplemented
-        return (self.seq_id == other.seq_id
-                and self.parent_seq_id == other.parent_seq_id
+            raise NotImplementedError()
+        return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
# For each sequence group, we generate a list of SequenceOutputs objects,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[List[SequenceOutputs]]
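`SamplerOutput` holds one inner list per sequence group in the batch, with one `SequenceOutputs` per candidate next token; `parent_seq_id` tells the engine which existing sequence a candidate continues, and therefore which sequence to fork. A rough sketch of one step's output for a batch of two groups, with made-up IDs, tokens, and logprob values, assuming the constructor signature shown above:

```python
from vllm.sequence import SamplerOutput, SequenceOutputs

# One decoding step for a batch of two sequence groups.
step_output: SamplerOutput = [
    # Group 0: beam search keeping two candidates; both continue parent
    # sequence 0, so the engine will fork that parent for the second one.
    [
        SequenceOutputs(parent_seq_id=0, output_token=42, logprobs={42: -0.1}),
        SequenceOutputs(parent_seq_id=0, output_token=7, logprobs={7: -1.3}),
    ],
    # Group 1: plain sampling, a single continuation of sequence 5.
    [SequenceOutputs(parent_seq_id=5, output_token=11, logprobs={11: -0.4})],
]
```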
......@@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel)
from vllm.sampling_params import SamplingParams
-from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
+from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory
......@@ -260,7 +260,7 @@ class Worker:
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
# Issue cache operations.
issued_cache_op = False
if blocks_to_swap_in: