sampling_params.py 17.3 KB
Newer Older
1
"""Sampling parameters for text generation."""
2
import copy
3
4
from enum import IntEnum
from functools import cached_property
5
from typing import Any, Callable, Dict, List, Optional, Union
6

7
import torch
8
9
from pydantic import Field
from typing_extensions import Annotated
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)

# Tolerance used when comparing float sampling parameters against their
# special values (temperature ~ 0 for greedy, top_p ~ 1 and
# length_penalty ~ 1 in the beam / non-beam search checks).
_SAMPLING_EPS = 1.0e-5
class SamplingType(IntEnum):
    """Sampling strategy category, as derived by `SamplingParams.sampling_type`.

    Members are plain ints (IntEnum), so they compare and sort as integers.
    """
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2
    BEAM = 3
# Two accepted call signatures:
#   (output_token_ids, logits) -> logits
#   (prompt_token_ids, output_token_ids, logits) -> logits
LogitsProcessor = Union[
    Callable[[List[int], torch.Tensor], torch.Tensor],
    Callable[[List[int], List[int], torch.Tensor], torch.Tensor],
]
"""LogitsProcessor is a function that takes a list of previously generated
tokens, the logits tensor for the next token and, optionally, prompt tokens
as a first argument, and returns a modified tensor of logits to sample from."""
class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be greater than or equal to `n`. This is treated as
            the beam width when `use_beam_search` is True. By default, `best_of`
            is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens. Must be in [-2, 2].
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens. Must be in [-2, 2].
        repetition_penalty: Float that penalizes new tokens based on whether
            they appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1 encourage
            the model to repeat tokens. Must be in (0, 2].
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling. Must be
            non-negative.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        seed: Random seed to use for the generation. A value of -1 is treated
            the same as None (no seed).
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation stops as
            soon as there are `best_of` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very
            unlikely to find better candidates; `"never"`, where the beam search
            procedure only stops when there cannot be better candidates
            (canonical beam search algorithm).
        stop: List of strings that stop the generation when they are generated.
            A single string is also accepted. The returned output will not
            contain the stop strings. Requires `detokenize=True`.
        stop_token_ids: List of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in
            output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated.
        logprobs: Number of log probabilities to return per output token.
            Note that the implementation follows the OpenAI API: The return
            result includes the log probabilities on the `logprobs` most likely
            tokens, as well as the chosen tokens. The API will always return
            the log probability of the sampled token, so there may be up to
            `logprobs+1` elements in the response.
        prompt_logprobs: Number of log probabilities to return per prompt token.
        detokenize: Whether to detokenize the output. Defaults to True.
        skip_special_tokens: Whether to skip special tokens in the output.
            Defaults to True.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output. Defaults to True.
        logits_processors: List of functions that modify logits based on
            previously generated tokens, and optionally prompt tokens as
            a first argument.
        truncate_prompt_tokens: If set to an integer k, will use only the last k
            tokens from the prompt (i.e., left truncation). Defaults to None
            (i.e., no truncation).

    Raises:
        ValueError: If an argument is outside its valid range or the
            combination of arguments is inconsistent (see the `_verify_*`
            methods).
    """

    def __init__(
        self,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
    ) -> None:
        self.n = n
        self.best_of = best_of if best_of is not None else n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        # -1 is accepted as an alias for "no seed".
        if seed == -1:
            self.seed = None
        else:
            self.seed = seed
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        # Normalize `stop` to a fresh list (a bare string becomes [string]).
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)
        # Copy so later in-place updates do not alias the caller's list.
        if stop_token_ids is None:
            self.stop_token_ids = []
        else:
            self.stop_token_ids = list(stop_token_ids)
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self.min_tokens = min_tokens
        self.logprobs = logprobs
        self.prompt_logprobs = prompt_logprobs
        # NOTE: This parameter is only exposed at the engine level for now.
        # It is not exposed in the OpenAI API server, as the OpenAI API does
        # not support returning only a list of token IDs.
        self.detokenize = detokenize
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.logits_processors = logits_processors
        self.include_stop_str_in_output = include_stop_str_in_output
        self.truncate_prompt_tokens = truncate_prompt_tokens
        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
        else:
            self.output_text_buffer_length = 0

        self._verify_args()
        if self.use_beam_search:
            if not envs.VLLM_NO_DEPRECATION_WARNING:
                logger.warning(
                    "[IMPORTANT] We plan to discontinue the support for beam "
                    "search in the next major release. Please refer to "
                    "https://github.com/vllm-project/vllm/issues/6226 for "
                    "more information. Set VLLM_NO_DEPRECATION_WARNING=1 to "
                    "suppress this warning.")
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling: the other sampling
                # knobs are reset to their "disabled" values.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self._verify_greedy_sampling()
        # eos_token_id is added to this by the engine
        self.all_stop_token_ids = set(self.stop_token_ids)

    def _verify_args(self) -> None:
        """Validate individual argument ranges and cross-argument constraints.

        Raises:
            ValueError: On the first violated constraint.
        """
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of < self.n:
            raise ValueError(f"best_of must be greater than or equal to n, "
                             f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if not 0.0 < self.repetition_penalty <= 2.0:
            raise ValueError("repetition_penalty must be in (0, 2], got "
                             f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in [0, 1], got "
                             f"{self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError(f"prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")
        if (self.truncate_prompt_tokens is not None
                and self.truncate_prompt_tokens < 1):
            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                             f"got {self.truncate_prompt_tokens}")
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        # Stop strings are matched against generated text, which only
        # exists when detokenizing.
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")

    def _verify_beam_search(self) -> None:
        """Check the constraints that apply when `use_beam_search` is True.

        Raises:
            ValueError: If the beam width is 1, sampling knobs are enabled, or
                `early_stopping` is not True, False or "never".
        """
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using beam "
                             f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
        if self.early_stopping not in [True, False, "never"]:
            raise ValueError(
                f"early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        """Reject beam-search-only options when beam search is not used.

        Raises:
            ValueError: If `early_stopping` is not False or `length_penalty`
                differs from 1.0 (within `_SAMPLING_EPS`).
        """
        if self.early_stopping is not False:
            raise ValueError("early_stopping is not effective and must be "
                             "False when not using beam search.")
        if (self.length_penalty < 1.0 - _SAMPLING_EPS
                or self.length_penalty > 1.0 + _SAMPLING_EPS):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")

    def _verify_greedy_sampling(self) -> None:
        """Check constraints for greedy (zero-temperature) sampling.

        Raises:
            ValueError: If `best_of` is greater than 1.
        """
        if self.best_of > 1:
            raise ValueError("best_of must be 1 when using greedy sampling. "
                             f"Got {self.best_of}.")

    def update_from_generation_config(
            self,
            generation_config: Dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        """Merge EOS token ids from a generation config into this object.

        Only the `"eos_token_id"` entry of `generation_config` is read.

        Args:
            generation_config: Mapping that may hold an `"eos_token_id"`
                entry, either a single int or an iterable of ints.
            model_eos_token_id: The model's primary EOS token id, if known.
                It is added to `all_stop_token_ids` (to support min_tokens
                processing) but is not added to `stop_token_ids` here.

        Any additional EOS ids are added to `all_stop_token_ids` and, unless
        `ignore_eos` is set, merged into `stop_token_ids`; in that case
        `stop_token_ids` is rebuilt from a set, so its order is not preserved.
        """

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self.all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self.all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    @cached_property
    def sampling_type(self) -> SamplingType:
        """Classify how tokens are chosen for this request.

        Beam search takes precedence, then greedy (temperature below
        `_SAMPLING_EPS`), then seeded vs. unseeded random sampling. The value
        is cached on first access, so later attribute changes are not
        reflected.
        """
        if self.use_beam_search:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    def clone(self) -> "SamplingParams":
        """Deep copy excluding LogitsProcessor objects.

        LogitsProcessor objects are excluded because they may contain an
        arbitrary, nontrivial amount of data.
        See https://github.com/vllm-project/vllm/issues/3087
        """

        # Pre-seeding the deepcopy memo with id(obj) -> obj makes deepcopy
        # reuse each logits processor instead of copying it.
        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"best_of={self.best_of}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"use_beam_search={self.use_beam_search}, "
            f"length_penalty={self.length_penalty}, "
            f"early_stopping={self.early_stopping}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens})")