sampling_params.py 21.8 KB
Newer Older
1
"""Sampling parameters for text generation."""
2
import copy
3
from dataclasses import dataclass
4
from enum import Enum, IntEnum
5
from functools import cached_property
6
from typing import Any, Dict, List, Optional, Set, Union
7

8
import msgspec
9
from pydantic import BaseModel
10
from typing_extensions import Annotated
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
from vllm.logger import init_logger
13
from vllm.logits_process import LogitsProcessor
14
15
16

logger = init_logger(__name__)

# Temperatures strictly below this value are treated as zero, i.e. greedy
# sampling (see SamplingParams.sampling_type / SamplingParams.__post_init__).
_SAMPLING_EPS: float = 1e-5
# NOTE: despite its name this is a *lower* bound: positive temperatures
# smaller than this are raised to it in SamplingParams.__post_init__ to avoid
# nan/inf values in tensors.
_MAX_TEMP: float = 1e-2


class SamplingType(IntEnum):
    """Strategy used to pick the next token for a request.

    Being an IntEnum, members compare equal to their integer values.
    """
    # Temperature below _SAMPLING_EPS: pick the most likely token.
    GREEDY = 0
    # Random sampling with no user-provided seed.
    RANDOM = 1
    # Random sampling driven by a user-provided seed.
    RANDOM_SEED = 2


# TODO: consider making this a msgspec.Struct, like SamplingParams.
@dataclass
class GuidedDecodingParams:
    """Parameters for guided (constrained) decoding.

    At most one of the guide fields (``json``, ``regex``, ``choice``,
    ``grammar``, ``json_object``) may be set; it is used to build a logits
    processor. ``backend`` and ``whitespace_pattern`` are auxiliary options.
    """
    # Guides: mutually exclusive, enforced in __post_init__.
    json: Optional[Union[str, Dict]] = None
    regex: Optional[str] = None
    choice: Optional[List[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # Other options that can be set alongside a guide.
    backend: Optional[str] = None
    whitespace_pattern: Optional[str] = None

    @staticmethod
    def from_optional(
        json: Optional[Union[Dict, BaseModel, str]] = None,
        regex: Optional[str] = None,
        choice: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        json_object: Optional[bool] = None,
        backend: Optional[str] = None,
        whitespace_pattern: Optional[str] = None,
    ) -> Optional["GuidedDecodingParams"]:
        """Build params from arguments that may all be unset.

        Returns:
            None when none of the guide arguments is given; otherwise a new
            GuidedDecodingParams. A pydantic model (instance or class) passed
            as ``json`` is replaced by its JSON schema.
        """
        guides = (json, regex, choice, grammar, json_object)
        if not any(guide is not None for guide in guides):
            return None

        # A pydantic model class has the BaseModel metaclass, so the second
        # isinstance target matches classes and the first matches instances.
        if isinstance(json, (BaseModel, type(BaseModel))):
            json = json.model_json_schema()

        return GuidedDecodingParams(
            json=json,
            regex=regex,
            choice=choice,
            grammar=grammar,
            json_object=json_object,
            backend=backend,
            whitespace_pattern=whitespace_pattern,
        )

    def __post_init__(self):
        """Raise ValueError if more than one guide field is set."""
        active_guides = (self.json, self.regex, self.choice, self.grammar,
                         self.json_object)
        num_set = sum(1 for guide in active_guides if guide is not None)
        if num_set > 1:
            raise ValueError(
                "You can only use one kind of guided decoding but multiple are "
                f"specified: {self.__dict__}")


class RequestOutputKind(Enum):
    """What each streamed RequestOutput contains."""
    # Each RequestOutput carries the entire output produced so far.
    CUMULATIVE = 0
    # Each RequestOutput carries only what is new since the previous one.
    DELTA = 1
    # No intermediate RequestOutputs are returned; only the final one.
    FINAL_ONLY = 2


class SamplingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be an integer >= 1 and greater than or equal to
            `n`. By default, `best_of` is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens. Must be in [-2, 2].
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens. Must be in [-2, 2].
        repetition_penalty: Float that penalizes new tokens based on whether
            they appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1 encourage
            the model to repeat tokens. Must be in (0, 2].
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling. Positive values
            below a small minimum (currently 1e-2) are raised to that minimum
            to avoid numerical errors.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens; otherwise must be at least 1.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        seed: Random seed to use for the generation. -1 is treated the same
            as None (no seed).
        stop: List of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings. A single
            string is accepted and wrapped in a list. Requires `detokenize`.
        stop_token_ids: List of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        bad_words: List of words that are not allowed to be generated.
            More precisely, only the last token of a corresponding
            token sequence is not allowed when the next generated token
            can complete the sequence.
        include_stop_str_in_output: Whether to include the stop strings in
            output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated.
        logprobs: Number of log probabilities to return per output token.
            When set to None, no probability is returned. If set to a non-None
            value, the result includes the log probabilities of the specified
            number of most likely tokens, as well as the chosen tokens.
            Note that the implementation follows the OpenAI API: The API will
            always return the log probability of the sampled token, so there
            may be up to `logprobs+1` elements in the response. A value of
            True is treated as 1.
        prompt_logprobs: Number of log probabilities to return per prompt token.
            A value of True is treated as 1.
        detokenize: Whether to detokenize the output. Defaults to True.
        skip_special_tokens: Whether to skip special tokens in the output.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output.  Defaults to True.
        logits_processors: List of functions that modify logits based on
            previously generated tokens, and optionally prompt tokens as
            a first argument.
        truncate_prompt_tokens: If set to an integer k, will use only the last k
            tokens from the prompt (i.e., left truncation). Defaults to None
            (i.e., no truncation).
        output_kind: Which part of the output each RequestOutput carries
            (cumulative output, deltas only, or final output only); see
            `RequestOutputKind`. DELTA requires `best_of` to equal `n`.
            Defaults to CUMULATIVE.
        guided_decoding: If provided, the engine will construct a guided
            decoding logits processor from these parameters. Defaults to None.
        logit_bias: If provided, the engine will construct a logits processor
            that applies these logit biases. Defaults to None.
        allowed_token_ids: If provided, the engine will construct a logits
            processor which only retains scores for the given token ids.
            Defaults to None.
    """

    n: int = 1
    best_of: Optional[int] = None
    # The `n` originally requested by the caller when `best_of` is set
    # (in that case `n` is overwritten with `best_of`); set in __post_init__.
    _real_n: Optional[int] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    min_p: float = 0.0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[List[int]] = None
    bad_words: Optional[List[str]] = None
    ignore_eos: bool = False
    max_tokens: Optional[int] = 16
    min_tokens: int = 0
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    # Optional[List[LogitsProcessor]] type. We use Any here because
    # Optional[List[LogitsProcessor]] type is not supported by msgspec.
    logits_processors: Optional[Any] = None
    include_stop_str_in_output: bool = False
    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    guided_decoding: Optional[GuidedDecodingParams] = None
    logit_bias: Optional[Dict[int, float]] = None
    allowed_token_ids: Optional[List[int]] = None

    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
        best_of: Optional[int] = None,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        repetition_penalty: Optional[float] = 1.0,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        top_k: Optional[int] = -1,
        min_p: Optional[float] = 0.0,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        bad_words: Optional[List[str]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int,
                                                   msgspec.Meta(ge=1)]] = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        guided_decoding: Optional[GuidedDecodingParams] = None,
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
        allowed_token_ids: Optional[List[int]] = None,
    ) -> "SamplingParams":
        """Build a SamplingParams from arguments that may be None.

        Unlike the constructor, None is accepted for the numeric sampling
        knobs (n, penalties, temperature, top_p, top_k, min_p) and replaced
        by that knob's default. `logit_bias` keys may be given as strings;
        they are converted to ints.
        """
        if logit_bias is not None:
            logit_bias = {
                int(token): bias
                for token, bias in logit_bias.items()
            }

        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0
            if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0
            if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            # top_k/min_p previously passed None straight through, which
            # crashed in _verify_args; coalesce them like the other knobs.
            top_k=-1 if top_k is None else top_k,
            min_p=0.0 if min_p is None else min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            guided_decoding=guided_decoding,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
        )

    def __post_init__(self) -> None:
        """Normalize fields, validate them, and derive internal state.

        Raises:
            ValueError: if any parameter is out of range or inconsistent.
            TypeError: if `top_k` is not an integer.
        """
        # How we deal with `best_of`:
        # - if `best_of` is not set, it implicitly equals `n`;
        # - if `best_of` is set, we set `n` to `best_of` and stash the
        #   original `n` in `_real_n`. When results are returned, we check
        #   whether `n` or `_real_n` results are needed.
        if self.best_of is not None:
            if not isinstance(self.best_of, int):
                raise ValueError(f"best_of must be an int, but is of "
                                 f"type {type(self.best_of)}")
            if self.best_of < 1:
                raise ValueError(
                    f"best_of must be at least 1, got {self.best_of}.")
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}.")
            # Only swap once: if __post_init__ runs again on an instance whose
            # `_real_n` is already set, `n` already equals `best_of`.
            if not self._real_n:
                self._real_n = self.n
                self.n = self.best_of

        # Very small positive temperatures are clamped up to _MAX_TEMP
        # (a floor despite its name) to avoid nan/inf in tensors.
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature, _MAX_TEMP, _MAX_TEMP)
            self.temperature = max(self.temperature, _MAX_TEMP)

        # A seed of -1 means "no seed".
        if self.seed == -1:
            self.seed = None

        # Normalize stop strings to a fresh list.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]
        else:
            self.stop = list(self.stop)

        if self.stop_token_ids is None:
            self.stop_token_ids = []
        else:
            self.stop_token_ids = list(self.stop_token_ids)

        if self.bad_words is None:
            self.bad_words = []
        else:
            self.bad_words = list(self.bad_words)

        # `True` is accepted as shorthand for 1.
        self.logprobs = 1 if self.logprobs is True else self.logprobs
        self.prompt_logprobs = (1 if self.prompt_logprobs is True else
                                self.prompt_logprobs)

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = -1
            self.min_p = 0.0
            self._verify_greedy_sampling()
        # eos_token_id is added to this by the engine
        self._all_stop_token_ids = set(self.stop_token_ids)

    def _verify_args(self) -> None:
        """Validate ranges and cross-field constraints; raise on violation."""
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of "
                             f"type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if not 0.0 < self.repetition_penalty <= 2.0:
            raise ValueError("repetition_penalty must be in (0, 2], got "
                             f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # Type check must come before the range comparisons below, otherwise
        # non-numeric values fail with an uninformative comparison error.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in [0, 1], got "
                             f"{self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError(f"prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")
        if (self.truncate_prompt_tokens is not None
                and self.truncate_prompt_tokens < 1):
            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                             f"got {self.truncate_prompt_tokens}")
        # __post_init__ has already normalized `stop` to a list.
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")
        if self.best_of != self._real_n and self.output_kind == (
                RequestOutputKind.DELTA):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_greedy_sampling(self) -> None:
        """Greedy sampling is deterministic, so only one sequence is allowed."""
        if self.n > 1:
            raise ValueError("n must be 1 when using greedy sampling, "
                             f"got {self.n}.")

    def update_from_generation_config(
            self,
            generation_config: Dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        """Merge EOS token ids from the model's generation config.

        Only the ``eos_token_id`` entry (an int or a list of ints) is read
        from ``generation_config``. ``model_eos_token_id`` and any extra EOS
        ids are recorded in ``_all_stop_token_ids``. Unless ``ignore_eos``
        is set, the extra EOS ids are also merged into ``stop_token_ids``.
        """
        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    @cached_property
    def sampling_type(self) -> SamplingType:
        """GREEDY for ~zero temperature, else RANDOM_SEED/RANDOM by seed."""
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def all_stop_token_ids(self) -> Set[int]:
        """Stop token ids plus any EOS ids registered via the engine."""
        return self._all_stop_token_ids

    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087
        """
        # Pre-seeding deepcopy's memo with id(lp) -> replacement makes
        # deepcopy reuse the replacement instead of deep-copying each
        # processor: processors with a `clone` method are cloned, others are
        # shared by reference.
        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp.clone() if hasattr(lp, 'clone') else lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"guided_decoding={self.guided_decoding})")


class BeamSearchParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # dict=True gives instances a __dict__ (required for @cached_property).
    dict=True,  # type: ignore[call-arg]
):
    """Beam search parameters for text generation.

    ``beam_width`` and ``max_tokens`` must be supplied; every other field
    has a default.
    """
    # Required fields (no defaults).
    beam_width: int
    max_tokens: int
    # Optional fields.
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False