sampling_params.py 24 KB
Newer Older
1
"""Sampling parameters for text generation."""
2
import copy
3
from dataclasses import dataclass
4
from enum import Enum, IntEnum
5
from functools import cached_property
6
from typing import Any, Callable, Dict, List, Optional, Set, Union
7

8
import msgspec
9
import torch
10
from pydantic import BaseModel
11
from typing_extensions import Annotated
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
import vllm.envs as envs
14
15
16
17
from vllm.logger import init_logger

logger = init_logger(__name__)

# Slack used when comparing float sampling parameters against their neutral
# values: a temperature below this selects greedy sampling, and top_p /
# length_penalty are compared against 1.0 with this tolerance.
_SAMPLING_EPS = 1.0e-5
# Despite its name, this acts as a *lower* bound on positive temperatures:
# SamplingParams.__post_init__ raises any temperature in (0, _MAX_TEMP) up to
# this value (logging a warning about nan/inf in tensors).
_MAX_TEMP = 0.01
Woosuk Kwon's avatar
Woosuk Kwon committed
20

21

22
23
24
class SamplingType(IntEnum):
    """Strategy used to pick the next token for a request.

    The value is derived from a SamplingParams instance by its
    ``sampling_type`` property.
    """
    # temperature below _SAMPLING_EPS: always take the most likely token.
    GREEDY = 0
    # Stochastic sampling without a per-request seed.
    RANDOM = 1
    # Stochastic sampling with an explicit per-request seed.
    RANDOM_SEED = 2
    # (Deprecated) beam search.
    BEAM = 3
27
28


29
30
31
32
33
34
35
36
LogitsProcessor = Union[
    # (generated_token_ids, next_token_logits) -> modified logits
    Callable[[List[int], torch.Tensor], torch.Tensor],
    # (prompt_token_ids, generated_token_ids, next_token_logits)
    #     -> modified logits
    Callable[[List[int], List[int], torch.Tensor], torch.Tensor],
]
"""A function that rewrites the logits for the next token.

It receives the list of previously generated token ids and the logits tensor
for the next token, optionally preceded by the prompt token ids as the first
argument, and returns the modified logits tensor to sample from."""
37
38


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# TODO: consider making this a msgspec.Struct (like SamplingParams).
@dataclass
class GuidedDecodingParams:
    """Options from which a guided-decoding logits processor is built.

    At most one of the guide fields (``json``, ``regex``, ``choice``,
    ``grammar``, ``json_object``) may be non-None; ``__post_init__`` raises
    ``ValueError`` otherwise. ``backend`` and ``whitespace_pattern`` are extra
    options that may accompany any guide.
    """
    # Guide fields: at most one of these may be set.
    json: Optional[Union[str, Dict]] = None
    regex: Optional[str] = None
    choice: Optional[List[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # Additional options that can be set alongside a guide.
    backend: Optional[str] = None
    whitespace_pattern: Optional[str] = None

    @staticmethod
    def from_optional(
        json: Optional[Union[Dict, BaseModel, str]],
        regex: Optional[str] = None,
        choice: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        json_object: Optional[bool] = None,
        backend: Optional[str] = None,
        whitespace_pattern: Optional[str] = None,
    ) -> "GuidedDecodingParams":
        """Build params, turning a pydantic model into its JSON schema.

        If ``json`` is a pydantic model class or instance, its
        ``model_json_schema()`` is used; any other value is passed through
        unchanged. All remaining arguments are forwarded as-is.
        """
        is_pydantic_model = isinstance(json, (BaseModel, type(BaseModel)))
        schema = json.model_json_schema() if is_pydantic_model else json
        return GuidedDecodingParams(
            json=schema,
            regex=regex,
            choice=choice,
            grammar=grammar,
            json_object=json_object,
            backend=backend,
            whitespace_pattern=whitespace_pattern,
        )

    def __post_init__(self):
        """Reject instances that request more than one kind of guide."""
        active_guides = [
            field_name
            for field_name in ("json", "regex", "choice", "grammar",
                               "json_object")
            if getattr(self, field_name) is not None
        ]
        if len(active_guides) > 1:
            raise ValueError(
                "You can only use one kind of guided decoding but multiple are "
                f"specified: {self.__dict__}")


87
88
89
90
91
92
93
94
95
class RequestOutputKind(Enum):
    """What each RequestOutput emitted for a request contains."""
    # Every RequestOutput carries the entire output generated so far.
    CUMULATIVE = 0
    # Every RequestOutput carries only what is new since the previous one.
    DELTA = 1
    # No intermediate RequestOutputs are returned.
    FINAL_ONLY = 2


96
97
98
99
100
class SamplingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
            Must be at least 1.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be greater than or equal to `n`. This is treated as
            the beam width when `use_beam_search` is True. By default, `best_of`
            is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens. Must be in [-2, 2].
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens. Must be in [-2, 2].
        repetition_penalty: Float that penalizes new tokens based on whether
            they appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1 encourage
            the model to repeat tokens. Must be in (0, 2].
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling. Must be
            non-negative; positive values below 0.01 are raised to 0.01 (with
            a warning) to avoid numerical errors.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        seed: Random seed to use for the generation. -1 is treated the same
            as None (no seed).
        use_beam_search: Whether to use beam search instead of sampling.
            Deprecated: a ValueError is raised unless the environment variable
            `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` is set.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search; must be 1.0 when not using beam search.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation stops as
            soon as there are `best_of` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very
            unlikely to find better candidates; `"never"`, where the beam search
            procedure only stops when there cannot be better candidates
            (canonical beam search algorithm). Must be `False` when not using
            beam search.
        stop: A string or list of strings that stop the generation when they
            are generated. The returned output will not contain the stop
            strings unless `include_stop_str_in_output` is True. Empty strings
            are rejected, and stop strings require `detokenize=True`.
        stop_token_ids: List of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in
            output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
            If not None, must be at least 1.
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated. Must be between 0
            and `max_tokens`.
        logprobs: Number of log probabilities to return per output token.
            When set to None, no probability is returned. If set to a non-None
            value, the result includes the log probabilities of the specified
            number of most likely tokens, as well as the chosen tokens.
            Note that the implementation follows the OpenAI API: The API will
            always return the log probability of the sampled token, so there
            may be up to `logprobs+1` elements in the response. `True` is
            treated as 1.
        prompt_logprobs: Number of log probabilities to return per prompt token.
            `True` is treated as 1.
        detokenize: Whether to detokenize the output. Defaults to True.
        skip_special_tokens: Whether to skip special tokens in the output.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output.  Defaults to True.
        logits_processors: List of functions that modify logits based on
            previously generated tokens, and optionally prompt tokens as
            a first argument.
        truncate_prompt_tokens: If set to an integer k, will use only the last k
            tokens from the prompt (i.e., left truncation). Defaults to None
            (i.e., no truncation).
        output_kind: What each intermediate RequestOutput contains: the whole
            output so far (`CUMULATIVE`, the default), only the new delta
            (`DELTA`, which requires `best_of == n`), or no intermediate
            outputs at all (`FINAL_ONLY`).
        guided_decoding: If provided, the engine will construct a guided
            decoding logits processor from these parameters. Defaults to None.
        logit_bias: If provided, the engine will construct a logits processor
            that applies these logit biases. Defaults to None.
        allowed_token_ids: If provided, the engine will construct a logits
            processor which only retains scores for the given token ids.
            Defaults to None.
    """

    # NOTE: field order is part of the positional constructor signature of
    # this msgspec.Struct; do not reorder.
    n: int = 1
    best_of: Optional[int] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    min_p: float = 0.0
    seed: Optional[int] = None
    use_beam_search: bool = False
    length_penalty: float = 1.0
    early_stopping: Union[bool, str] = False
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[List[int]] = None
    ignore_eos: bool = False
    max_tokens: Optional[int] = 16
    min_tokens: int = 0
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    # Optional[List[LogitsProcessor]] type. We use Any here because
    # Optional[List[LogitsProcessor]] type is not supported by msgspec.
    logits_processors: Optional[Any] = None
    include_stop_str_in_output: bool = False
    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The next two fields are derived state, not inputs: they are computed in
    # __post_init__ (and _all_stop_token_ids is extended later by
    # update_from_generation_config).
    output_text_buffer_length: int = 0
    _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)

    # Input fields used to construct logits processors.
    guided_decoding: Optional[GuidedDecodingParams] = None
    logit_bias: Optional[Dict[int, float]] = None
    allowed_token_ids: Optional[List[int]] = None

229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
        best_of: Optional[int] = None,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        repetition_penalty: Optional[float] = 1.0,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int,
                                                   msgspec.Meta(ge=1)]] = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        guided_decoding: Optional[GuidedDecodingParams] = None,
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
        allowed_token_ids: Optional[List[int]] = None,
    ) -> "SamplingParams":
        """Construct SamplingParams from arguments that may be ``None``.

        An explicit ``None`` for ``n``, ``presence_penalty``,
        ``frequency_penalty``, ``repetition_penalty``, ``temperature`` or
        ``top_p`` is replaced by that field's default (1, 0.0, 0.0, 1.0, 1.0
        and 1.0 respectively). ``logit_bias`` keys are passed through
        ``int()`` so string token ids are accepted. Every other argument is
        forwarded unchanged.
        """

        def or_default(value, default):
            # An explicit None means "use the field default".
            return default if value is None else value

        int_keyed_bias = None
        if logit_bias is not None:
            int_keyed_bias = {
                int(token_id): bias_value
                for token_id, bias_value in logit_bias.items()
            }

        return SamplingParams(
            n=or_default(n, 1),
            best_of=best_of,
            presence_penalty=or_default(presence_penalty, 0.0),
            frequency_penalty=or_default(frequency_penalty, 0.0),
            repetition_penalty=or_default(repetition_penalty, 1.0),
            temperature=or_default(temperature, 1.0),
            top_p=or_default(top_p, 1.0),
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            use_beam_search=use_beam_search,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            guided_decoding=guided_decoding,
            logit_bias=int_keyed_bias,
            allowed_token_ids=allowed_token_ids,
        )

305
306
307
    def __post_init__(self) -> None:
        """Normalize user-supplied fields, validate them, and derive state.

        Normalization: ``best_of`` falls back to ``n``; positive temperatures
        below ``_MAX_TEMP`` are raised to it; ``seed == -1`` becomes ``None``;
        ``stop`` / ``stop_token_ids`` become lists; ``logprobs`` /
        ``prompt_logprobs`` of ``True`` become 1. Greedy requests (temperature
        below ``_SAMPLING_EPS``) have top_p/top_k/min_p reset to neutral.

        Raises:
            ValueError: If any parameter is invalid, or beam search is
                requested while ``VLLM_ALLOW_DEPRECATED_BEAM_SEARCH`` is off.
            TypeError: Propagated from ``_verify_args`` for a non-int top_k.
        """
        # A falsy best_of (None or 0) means "same as n".
        self.best_of = self.best_of or self.n
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature, _MAX_TEMP, _MAX_TEMP)
            # Inside this branch temperature < _MAX_TEMP, so the floor wins.
            self.temperature = _MAX_TEMP
        # -1 is accepted as an alias for "no seed".
        if self.seed == -1:
            self.seed = None
        # Normalize stop to a fresh list of strings.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]
        else:
            self.stop = list(self.stop)
        if self.stop_token_ids is None:
            self.stop_token_ids = []
        else:
            self.stop_token_ids = list(self.stop_token_ids)
        # Allow boolean True as shorthand for "1 logprob".
        self.logprobs = 1 if self.logprobs is True else self.logprobs
        self.prompt_logprobs = (1 if self.prompt_logprobs is True else
                                self.prompt_logprobs)

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()
        if self.use_beam_search:
            if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH:
                raise ValueError(
                    "Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ."  # noqa
                )
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self._verify_greedy_sampling()
        # eos_token_id is added to this by the engine
        self._all_stop_token_ids = set(self.stop_token_ids)
353
354

    def _verify_args(self) -> None:
        """Validate individual fields and simple cross-field constraints.

        Called from ``__post_init__`` after normalization, so ``best_of`` is
        already populated and ``stop`` is already a list.

        Raises:
            ValueError: If a field is out of range, has the wrong type (n,
                best_of), or is inconsistent with another field.
            TypeError: If ``top_k`` is not an integer.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of "
                             f"type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not isinstance(self.best_of, int):
            raise ValueError(f'best_of must be an int, but is of '
                             f'type {type(self.best_of)}')
        if self.best_of < self.n:
            raise ValueError(f"best_of must be greater than or equal to n, "
                             f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if not 0.0 < self.repetition_penalty <= 2.0:
            raise ValueError("repetition_penalty must be in (0, 2], got "
                             f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # Check the type before the value (as for n/best_of above): otherwise
        # top_k=0.0 gets a misleading ValueError and top_k="5" fails inside
        # the comparison with an unrelated TypeError.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in [0, 1], got "
                             f"{self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError(f"prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")
        # msgspec.Meta(ge=1) is only enforced when decoding, not when the
        # struct is constructed directly, so check explicitly here.
        if (self.truncate_prompt_tokens is not None
                and self.truncate_prompt_tokens < 1):
            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                             f"got {self.truncate_prompt_tokens}")
        # Invariant established by __post_init__ (also narrows the type).
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")
        if self.best_of != self.n and self.output_kind == (
                RequestOutputKind.DELTA):
            raise ValueError("best_of must equal n to use output_kind=DELTA")
419

420
    def _verify_beam_search(self) -> None:
        """Check that the parameters are consistent with beam search.

        Beam search needs a beam width (``best_of``) above one and does not
        combine with temperature, top-p or top-k sampling.

        Raises:
            ValueError: On the first violated constraint.
        """
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using beam "
                             f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        top_p_floor = 1.0 - _SAMPLING_EPS
        if self.top_p < top_p_floor:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
        accepted_early_stopping = (True, False, "never")
        if self.early_stopping not in accepted_early_stopping:
            raise ValueError(
                f"early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        """Reject beam-search-only options when beam search is disabled.

        Raises:
            ValueError: If ``early_stopping`` is not ``False``, or
                ``length_penalty`` lies outside 1.0 +/- ``_SAMPLING_EPS``.
        """
        if self.early_stopping is not False:
            raise ValueError("early_stopping is not effective and must be "
                             "False when not using beam search.")
        lower_bound = 1.0 - _SAMPLING_EPS
        upper_bound = 1.0 + _SAMPLING_EPS
        if (self.length_penalty < lower_bound
                or self.length_penalty > upper_bound):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")
444
445

    def _verify_greedy_sampling(self) -> None:
        """Check that greedy sampling is not combined with ``best_of > 1``.

        Raises:
            ValueError: If ``best_of`` is greater than 1.
        """
        # best_of has been filled in from n by __post_init__ before this runs.
        assert isinstance(self.best_of, int)
        if self.best_of > 1:
            # The separating space was missing ("...sampling.Got 2.").
            raise ValueError("best_of must be 1 when using greedy sampling. "
                             f"Got {self.best_of}.")
450

451
    def update_from_generation_config(
        self,
        generation_config: Dict[str, Any],
        model_eos_token_id: Optional[int] = None,
    ) -> None:
        """Merge EOS token ids from a model's generation config.

        ``model_eos_token_id`` (if given) is added to ``_all_stop_token_ids``.
        Any additional ids under ``generation_config["eos_token_id"]`` (an int
        or an iterable of ints), excluding ``model_eos_token_id``, are added to
        ``_all_stop_token_ids`` and, unless ``ignore_eos`` is set, merged into
        ``stop_token_ids``.
        """
        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        configured_eos = generation_config.get("eos_token_id")
        if configured_eos is None:
            return

        # The config value can be either a single int or a list of ints.
        if isinstance(configured_eos, int):
            extra_eos = {configured_eos}
        else:
            extra_eos = set(configured_eos)
        if model_eos_token_id is not None:
            # The primary eos_token_id is handled separately for stopping
            # purposes, so it does not need to be in stop_token_ids.
            extra_eos.discard(model_eos_token_id)
        if not extra_eos:
            return

        self._all_stop_token_ids.update(extra_eos)
        if self.ignore_eos:
            return
        extra_eos.update(self.stop_token_ids)
        self.stop_token_ids = list(extra_eos)
476

477
478
479
480
481
482
    @cached_property
    def sampling_type(self) -> SamplingType:
        """Classify how tokens are chosen for this request.

        Precedence: beam search, then greedy (temperature below
        ``_SAMPLING_EPS``), then seeded random, then plain random. The result
        is cached on first access (the struct is declared with ``dict=True``
        for this), so later field mutations are not reflected.
        """
        if self.use_beam_search:
            kind = SamplingType.BEAM
        elif self.temperature < _SAMPLING_EPS:
            kind = SamplingType.GREEDY
        elif self.seed is not None:
            kind = SamplingType.RANDOM_SEED
        else:
            kind = SamplingType.RANDOM
        return kind

487
488
489
490
    @property
    def all_stop_token_ids(self) -> Set[int]:
        """Every token id that ends generation, EOS ids included.

        Seeded from ``stop_token_ids`` in ``__post_init__`` and extended by
        ``update_from_generation_config``. The underlying set itself is
        returned, not a copy.
        """
        stop_ids = self._all_stop_token_ids
        return stop_ids

491
492
493
494
495
496
497
498
499
500
501
502
503
504
    def clone(self) -> "SamplingParams":
        """Deep copy excluding LogitsProcessor objects.

        LogitsProcessor objects are excluded because they may contain an
        arbitrary, nontrivial amount of data.
        See https://github.com/vllm-project/vllm/issues/3087
        """
        # deepcopy's memo maps id(original) -> copy. Pre-seeding each
        # processor with itself makes deepcopy reuse it instead of copying.
        shared_processors = None
        if self.logits_processors is not None:
            shared_processors = {
                id(processor): processor
                for processor in self.logits_processors
            }
        return copy.deepcopy(self, memo=shared_processors)

505
    def __repr__(self) -> str:
        """Return a debug representation of the main sampling fields.

        Only a subset of fields is shown; for example ``detokenize``,
        ``logits_processors``, ``output_kind``, ``logit_bias`` and
        ``allowed_token_ids`` are omitted.
        """
        return (
            f"SamplingParams(n={self.n}, "
            f"best_of={self.best_of}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"use_beam_search={self.use_beam_search}, "
            f"length_penalty={self.length_penalty}, "
            f"early_stopping={self.early_stopping}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            # guided_decoding belongs inside the parentheses; the closing
            # paren used to be emitted before it.
            f"guided_decoding={self.guided_decoding})")