sampling_params.py 38.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
import json as json_mod
7
from dataclasses import field
8
from enum import Enum, IntEnum
9
from functools import cached_property
10
from typing import Any
11

12
import msgspec
13
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
import vllm.envs as envs
16
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
17
from vllm.exceptions import VLLMValidationError
18
from vllm.logger import init_logger
19
from vllm.tokenizers import TokenizerLike
20
from vllm.utils.mistral import is_mistral_tokenizer
21
from vllm.v1.serial_utils import PydanticMsgspecMixin
22
23
24

logger = init_logger(__name__)

25
_SAMPLING_EPS = 1e-5
26
_MAX_TEMP = 1e-2
Woosuk Kwon's avatar
Woosuk Kwon committed
27

28

29
30
31
class SamplingType(IntEnum):
    """Sampling strategy derived from a request's sampling parameters."""

    # Temperature is below the sampling epsilon (effectively zero).
    GREEDY = 0
    # Non-greedy sampling without a user-provided seed.
    RANDOM = 1
    # Non-greedy sampling with a user-provided seed.
    RANDOM_SEED = 2
33
34


35
36
# maybe make msgspec?
@dataclass
37
38
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
39
40
41
42
43
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
44
    # These are other options that can be set.
45
46
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
47
48
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
49

50
    _backend: str | None = field(default=None, init=False)
51
52
53
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
54
55
56

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
57
58
59
60
61
62
63
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
64
                self.structural_tag is not None,
65
66
            ]
        )
67
        if count > 1:
68
            raise ValueError(
69
                "You can only use one kind of structured outputs constraint "
70
71
                f"but multiple are specified: {self.__dict__}"
            )
72
73
74
75
76
        if count < 1:
            raise ValueError(
                "You must use one kind of structured outputs constraint "
                f"but none are specified: {self.__dict__}"
            )
77

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

109

110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
@dataclass
class RepetitionDetectionParams:
    """Settings that control detection of repeated N-gram patterns in the
    generated output tokens."""

    max_pattern_size: int = 0
    """Maximum size of N-gram pattern to detect for sequence repetition.
    Set to 0 to disable. Must be used together with min_count."""

    min_pattern_size: int = 0
    """Minimum N-gram pattern size to check for sequence repetition.
    If set to 0, it defaults to 1.
    Must be <= max_pattern_size."""

    min_count: int = 0
    """Minimum number of times an N-gram pattern must repeat to trigger
    detection. Must be >= 2. Example: 3 for detecting a phrase repeated
    3 times. Must be used together with max_pattern_size."""

    def __post_init__(self):
        """Check the pattern-size bounds and the repeat count.

        Raises:
            ValueError: If a size is negative, if min_pattern_size exceeds
                max_pattern_size, or if detection is enabled
                (max_pattern_size > 0) with min_count below 2.
        """
        # Both sizes non-negative and ordered collapses to one chained test.
        sizes_ok = 0 <= self.min_pattern_size <= self.max_pattern_size
        if not sizes_ok:
            raise ValueError(
                "max_pattern_size, min_pattern_size must be >=0, "
                "with min_pattern_size <= max_pattern_size. "
                "Set both to 0 to disable repetitive pattern detection."
            )

        detection_enabled = self.max_pattern_size > 0
        if detection_enabled and self.min_count < 2:
            raise ValueError(
                "min_count must be >= 2 to detect repetitive patterns "
                "in engine output. If you do not wish to detect repetitive "
                "patterns, set max_pattern_size to 0."
            )


147
148
149
150
151
class RequestOutputKind(Enum):
    """Controls what each RequestOutput contains while a request is running."""

    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2


156
157
158
159
def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool:
    """Return whether ``tokenizer`` is a Mistral tokenizer that is not Tekken.

    ``tokenizer.is_tekken`` is only read when ``is_mistral_tokenizer``
    reports a Mistral tokenizer.
    """
    is_mistral = is_mistral_tokenizer(tokenizer)
    return is_mistral and not tokenizer.is_tekken


160
161
162
163
def _get_llg_tokenizer(tokenizer: TokenizerLike) -> Any:
    """Return ``tokenizer.llg_tokenizer`` for Mistral tokenizers, else None."""
    if is_mistral_tokenizer(tokenizer):
        return tokenizer.llg_tokenizer
    return None


164
class SamplingParams(
165
    PydanticMsgspecMixin,
166
167
168
169
170
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
171
172
173
174
175
176
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
177

178
    n: int = 1
179
180
    """Number of outputs to return for the given prompt request.

181
182
183
    The maximum allowed value is controlled by the ``VLLM_MAX_N_SEQUENCES``
    environment variable (default: 16384).

184
185
186
187
188
    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
189
    presence_penalty: float = 0.0
190
191
192
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
193
    frequency_penalty: float = 0.0
194
195
196
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
197
    repetition_penalty: float = 1.0
198
199
200
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
201
    temperature: float = 1.0
202
203
204
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
205
    top_p: float = 1.0
206
207
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
208
    top_k: int = 0
209
210
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
211
    min_p: float = 0.0
212
213
214
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
215
    seed: int | None = None
216
    """Random seed to use for the generation."""
217
    stop: str | list[str] | None = None
218
219
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
220
    stop_token_ids: list[int] | None = None
221
222
223
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
224
    ignore_eos: bool = False
225
226
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
227
    max_tokens: int | None = 16
228
    """Maximum number of tokens to generate per output sequence."""
229
    min_tokens: int = 0
230
231
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
232
    logprobs: int | None = None
233
234
235
236
237
238
239
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
240
    prompt_logprobs: int | None = None
241
242
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
Vedant V Jhaveri's avatar
Vedant V Jhaveri committed
243
244
245
246
247
248
    logprob_token_ids: list[int] | None = None
    """Specific token IDs to return logprobs for. More efficient than
    logprobs=-1 when you only need logprobs for a small set of tokens.
    When set, logprobs for exactly these token IDs will be returned,
    in addition to the sampled token. This is useful for scoring tasks
    where you want to compare probabilities of specific label tokens."""
249
250
251
252
253
254
    flat_logprobs: bool = False
    """Whether to return logprobs in a flattened format (i.e. FlatLogprobs)
    for better performance.
    NOTE: The GC cost of FlatLogprobs is significantly smaller than that of
    list[dict[int, Logprob]]. When enabled, PromptLogprobs and
    SampleLogprobs are populated as FlatLogprobs."""
255
256
257
258
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
259
    """Whether to detokenize the output."""
260
    skip_special_tokens: bool = True
261
    """Whether to skip special tokens in the output."""
262
    spaces_between_special_tokens: bool = True
263
    """Whether to add spaces between special tokens in the output."""
264
    include_stop_str_in_output: bool = False
265
    """Whether to include the stop strings in output text."""
266
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
267
268
269
270
271
272
    skip_clone: bool = False
    """Internal flag indicating that this SamplingParams instance is safe to
    reuse without deep cloning. When True, clone() returns a shallow copy
    instead of performing a deep copy. This should only be set when the
    params object is guaranteed to be dedicated to a single request and
    won't be modified in ways that would affect other uses."""
273
274
275
276

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
277
    _eos_token_id: int | None = None
278
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
279

280
    # Fields used to construct logits processors
281
    structured_outputs: StructuredOutputsParams | None = None
282
    """Parameters for configuring structured outputs."""
283
    logit_bias: dict[int, float] | None = None
284
285
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
286
    allowed_token_ids: list[int] | None = None
287
288
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
289
    extra_args: dict[str, Any] | None = None
290
291
292
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""
293

294
    # Fields used for bad words
295
    bad_words: list[str] | None = None
296
297
298
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
299
    _bad_words_token_ids: list[list[int]] | None = None
300

301
    skip_reading_prefix_cache: bool | None = None
302
303
    thinking_token_budget: int | None = None
    """Maximum number of tokens allowed for thinking operations."""
304

305
306
307
308
309
310
311
312
    repetition_detection: RepetitionDetectionParams | None = None
    """Parameters for detecting repetitive N-gram patterns in output tokens.
    If such repetition is detected, generation will be ended early. LLMs can
    sometimes generate repetitive, unhelpful token patterns, stopping only
    when they hit the maximum output length (e.g. 'abcdabcdabcd...' or
    '\\emoji \\emoji \\emoji ...'). This feature can detect such behavior
    and terminate early, saving time and tokens."""

313
314
    @staticmethod
    def from_optional(
        n: int | None = 1,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        thinking_token_budget: int | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
        skip_clone: bool = False,
        repetition_detection: RepetitionDetectionParams | None = None,
    ) -> "SamplingParams":
        """Build a ``SamplingParams`` from possibly-``None`` arguments.

        ``n``, ``presence_penalty``, ``frequency_penalty``,
        ``repetition_penalty``, ``temperature`` and ``top_p`` fall back to
        their defaults (1, 0.0, 0.0, 1.0, 1.0, 1.0) when passed as ``None``.
        ``logit_bias`` keys are converted to ``int`` and values are clamped
        to [-100, 100]. All other arguments are forwarded unchanged.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        return SamplingParams(
            n=1 if n is None else n,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            thinking_token_budget=thinking_token_budget,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            skip_clone=skip_clone,
            repetition_detection=repetition_detection,
        )

387
388
    def __post_init__(self) -> None:
        """Normalize field values, validate them, and derive internal state.

        Steps, in order: raise a too-small positive temperature to
        ``_MAX_TEMP``; map ``seed == -1`` to ``None``; normalize ``stop``,
        ``stop_token_ids`` and ``bad_words`` to lists; map boolean ``True``
        logprob settings to 1; compute the stop-string hold-back length;
        validate via ``_verify_args``; reset top-p/top-k/min-p for greedy
        sampling; seed ``_all_stop_token_ids``; and default
        ``skip_reading_prefix_cache``.
        """
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = max(self.temperature, _MAX_TEMP)

        if self.seed == -1:
            self.seed = None

        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled,
            # the output of prompt logprobs may less than n_prompt_tokens,
            # we need to skip reading cache at this request.
            self.skip_reading_prefix_cache = self.prompt_logprobs is not None

441
    def _verify_args(self) -> None:
        """Validate the value ranges and types of the sampling parameters.

        Raises:
            ValueError: For out-of-range ``n``, penalties, ``top_k``,
                ``min_p``, ``min_tokens``, or invalid ``stop`` /
                ``stop_token_ids`` settings.
            TypeError: If ``top_k`` is not an integer.
            VLLMValidationError: For invalid ``temperature``, ``top_p``,
                ``max_tokens``, ``logprobs`` or ``prompt_logprobs``.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        max_n = envs.VLLM_MAX_N_SEQUENCES
        if self.n > max_n:
            raise ValueError(
                f"n must be at most {max_n}, got {self.n}. "
                "To increase this limit, set the VLLM_MAX_N_SEQUENCES "
                "environment variable."
            )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise VLLMValidationError(
                f"temperature must be non-negative, got {self.temperature}.",
                parameter="temperature",
                value=self.temperature,
            )
        if not 0.0 < self.top_p <= 1.0:
            raise VLLMValidationError(
                f"top_p must be in (0, 1], got {self.top_p}.",
                parameter="top_p",
                value=self.top_p,
            )
        # Check the type before the numeric comparison so that a non-int
        # top_k always yields this explicit TypeError rather than a failed
        # comparison or a misleading range error.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise VLLMValidationError(
                f"max_tokens must be at least 1, got {self.max_tokens}.",
                parameter="max_tokens",
                value=self.max_tokens,
            )
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise VLLMValidationError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.",
                parameter="logprobs",
                value=self.logprobs,
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise VLLMValidationError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}.",
                parameter="prompt_logprobs",
                value=self.prompt_logprobs,
            )
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
534
535

    def _verify_greedy_sampling(self) -> None:
        """Reject ``n > 1`` when greedy sampling is in effect.

        Raises:
            ValueError: If ``n`` is greater than 1.
        """
        if self.n > 1:
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")
538

539
    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        eos_token_id: int | None = None,
    ) -> None:
        """Update if there are non-default values from generation_config

        Args:
            generation_config: Mapping that may contain an ``eos_token_id``
                entry holding either a single int or an iterable of ints.
            eos_token_id: The primary EOS token id, if any. It is stored as
                ``_eos_token_id`` unless ``ignore_eos`` is set.
        """
        if not self.ignore_eos:
            self._eos_token_id = eos_token_id

        if eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    assert self.stop_token_ids is not None
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)
568

569
    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
        """Tokenize ``bad_words`` into ``_bad_words_token_ids``.

        Each bad word is encoded without special tokens, once as-is (leading
        whitespace stripped) and once with a single leading space. The
        space-prefixed encoding is kept only when it has the same length as
        the plain encoding and starts with a different token. Words whose
        plain encoding is empty (e.g. whitespace-only words) are skipped
        rather than causing an ``IndexError``.

        Raises:
            VLLMValidationError: If any resulting token id is negative or
                greater than ``tokenizer.max_token_id``.
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            stripped_word = bad_word.lstrip()
            plain_ids = tokenizer.encode(text=stripped_word, add_special_tokens=False)
            if not plain_ids:
                # Nothing to ban; indexing the first token below would fail.
                continue
            self._bad_words_token_ids.append(plain_ids)

            prefixed_ids = tokenizer.encode(
                text=" " + stripped_word, add_special_tokens=False
            )
            # Keep the prefixed variant only if the prefix space produces a
            # new word token. Equal lengths guarantee prefixed_ids is
            # non-empty before its first token is read.
            if (
                len(prefixed_ids) == len(plain_ids)
                and prefixed_ids[0] != plain_ids[0]
            ):
                self._bad_words_token_ids.append(prefixed_ids)

        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise VLLMValidationError(
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.",
                parameter="bad_words",
                value=self.bad_words,
            )
609

610
611
612
613
    @cached_property
    def sampling_type(self) -> SamplingType:
        """Sampling strategy implied by ``temperature`` and ``seed``.

        Returns ``GREEDY`` when the temperature is below ``_SAMPLING_EPS``,
        ``RANDOM_SEED`` when a seed is set, and ``RANDOM`` otherwise. The
        result is cached on first access.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

618
619
620
621
    @property
    def eos_token_id(self) -> int | None:
        """Read-only view of the internal ``_eos_token_id`` field."""
        return self._eos_token_id

622
    @property
    def all_stop_token_ids(self) -> set[int]:
        """The internal ``_all_stop_token_ids`` set (returned by reference,
        not copied)."""
        return self._all_stop_token_ids

626
    @property
    def bad_words_token_ids(self) -> list[list[int]] | None:
        """Token-id sequences stored in ``_bad_words_token_ids``, or None if
        they have not been set."""
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

631
    def clone(self) -> "SamplingParams":
        """Return a copy of this instance.

        If skip_clone is True, uses shallow copy instead of deep copy, so
        mutable members are shared with the original.
        """
        if self.skip_clone:
            return copy.copy(self)

        return copy.deepcopy(self)
637

638
639
640
641
642
643
644
645
646
    def verify(
        self,
        model_config: ModelConfig,
        speculative_config: SpeculativeConfig | None,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        """Run all configuration-dependent validations in a fixed order.

        Delegates to the ``_validate_*`` helpers for logprobs, logit bias,
        logits processors, allowed token ids, speculative decoding and
        structured outputs; the first failing check raises.
        """
        self._validate_logprobs(model_config)
        self._validate_logit_bias(model_config)
        self._validate_logits_processors(model_config)
        self._validate_allowed_token_ids(tokenizer)
        self._validate_spec_decode(speculative_config)
        self._validate_structured_outputs(structured_outputs_config, tokenizer)

    def _validate_logprobs(self, model_config: ModelConfig) -> None:
        """Check requested sample/prompt logprob counts against the model limit.

        A limit or a request of -1 is resolved to the model's vocabulary
        size. Requests that are ``None`` or 0 are not checked.

        Raises:
            VLLMValidationError: If a requested count exceeds the limit.
        """
        limit = model_config.max_logprobs
        if limit == -1:
            limit = model_config.get_vocab_size()

        # Validate sample logprobs.
        sample_count = self.logprobs
        if sample_count:
            if sample_count == -1:
                sample_count = model_config.get_vocab_size()
            if sample_count > limit:
                raise VLLMValidationError(
                    f"Requested sample logprobs of {sample_count}, "
                    f"which is greater than max allowed: {limit}",
                    parameter="logprobs",
                    value=sample_count,
                )

        # Validate prompt logprobs.
        prompt_count = self.prompt_logprobs
        if prompt_count:
            if prompt_count == -1:
                prompt_count = model_config.get_vocab_size()
            if prompt_count > limit:
                raise VLLMValidationError(
                    f"Requested prompt logprobs of {prompt_count}, "
                    f"which is greater than max allowed: {limit}",
                    parameter="prompt_logprobs",
                    value=prompt_count,
                )

    def _validate_logit_bias(self, model_config: ModelConfig) -> None:
        """Ensure every ``logit_bias`` key is a valid vocabulary token id.

        Does nothing when ``logit_bias`` is unset or empty.

        Raises:
            VLLMValidationError: If any key is negative or not below the
                model's vocabulary size.
        """
        if not self.logit_bias:
            return

        vocab_size = model_config.get_vocab_size()
        invalid_token_ids = []
        for token_id in self.logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                invalid_token_ids.append(token_id)

        if not invalid_token_ids:
            return

        raise VLLMValidationError(
            f"token_id(s) {invalid_token_ids} in logit_bias contain "
            f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
            parameter="logit_bias",
            value=invalid_token_ids,
        )

701
702
703
704
705
706
707
    def _validate_logits_processors(self, model_config: ModelConfig) -> None:
        """Validate these params against the model's configured logits
        processors by delegating to ``validate_logits_processors_parameters``.
        """
        # NOTE(review): imported locally rather than at module level,
        # presumably to avoid a circular import; confirm before moving it.
        from vllm.v1.sample.logits_processor import (
            validate_logits_processors_parameters,
        )

        validate_logits_processors_parameters(model_config.logits_processors, self)

708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
    def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
        """Validate `allowed_token_ids` when the request supplied it.

        `None` is always accepted. An empty collection is rejected. The range
        check only runs when a tokenizer is given, and uses `len(tokenizer)`
        as the vocabulary size.

        Args:
            tokenizer: Tokenizer whose length bounds the valid token ids, or
                None to skip the range check.

        Raises:
            VLLMValidationError: If `allowed_token_ids` is empty, or if any id
                is negative or not below `len(tokenizer)`.
        """
        token_ids = self.allowed_token_ids
        if token_ids is None:
            return

        if not token_ids:
            raise VLLMValidationError(
                "allowed_token_ids is not None and empty!",
                parameter="allowed_token_ids",
                value=token_ids,
            )

        # Without a tokenizer there is no vocabulary size to compare against.
        if tokenizer is None:
            return

        vocab_size = len(tokenizer)
        out_of_vocab = [tid for tid in token_ids if not 0 <= tid < vocab_size]
        if out_of_vocab:
            raise VLLMValidationError(
                "allowed_token_ids contains out-of-vocab token id!",
                parameter="allowed_token_ids",
                value=out_of_vocab,
            )

    def _validate_spec_decode(
        self,
        speculative_config: SpeculativeConfig | None,
    ) -> None:
        """Reject sampling parameters that cannot be used with spec decoding.

        Does nothing when `speculative_config` is None.

        Args:
            speculative_config: The engine's speculative decoding config, or
                None when speculative decoding is not in use.

        Raises:
            ValueError: If a speculative config is present and either `min_p`
                exceeds `_SAMPLING_EPS` or `logit_bias` is truthy.
        """
        if speculative_config is None:
            return

        # Some sampling parameters are not yet compatible with spec decoding.
        if self.min_p > _SAMPLING_EPS or self.logit_bias:
            raise ValueError(
                "The min_p and logit_bias sampling parameters "
                "are not yet supported with speculative decoding."
            )

    def _validate_structured_outputs(
        self,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        """Validate the structured-output request and settle on its backend.

        Returns immediately when either `structured_outputs_config` or
        `self.structured_outputs` is None.

        Side effects on `self.structured_outputs`:
            * `_backend` is set to the configured backend when it was unset.
              For the "auto" fallback path it is then overwritten with the
              concrete backend that accepted the request ("xgrammar",
              "outlines" or "guidance").
            * `_backend_was_auto` is set to True on the "auto" path.
            * `__post_init__()` is re-run at the end.

        Args:
            structured_outputs_config: Engine-level structured output config;
                its `backend` attribute selects the validation path.
            tokenizer: Tokenizer for the model; required for structured
                outputs.

        Raises:
            ValueError: If `tokenizer` is None, if the request carries a
                `_backend` that conflicts with the configured backend, if
                `choice` is an empty list, if `grammar` is blank, or if the
                tokenizer is rejected for the selected backend. The
                backend-specific validators called here may raise as well.
        """
        if structured_outputs_config is None or self.structured_outputs is None:
            return

        if tokenizer is None:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = structured_outputs_config.backend
        if _backend := self.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and self.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            self.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(self.structured_outputs.choice, list)
            and not self.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{self.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )
        # Reject empty string grammar early to avoid engine-side crashes
        if (
            isinstance(self.structured_outputs.grammar, str)
            and self.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        # Backend validators are imported at call time rather than at module
        # level.
        from vllm.v1.structured_output.backend_guidance import (
            has_guidance_unsupported_json_features,
            validate_guidance_grammar,
        )
        from vllm.v1.structured_output.backend_lm_format_enforcer import (
            validate_structured_output_request_lm_format_enforcer,
        )
        from vllm.v1.structured_output.backend_outlines import (
            validate_structured_output_request_outlines,
        )
        from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(self)
        elif backend.startswith("guidance"):
            if _is_non_tekken_mistral(tokenizer=tokenizer):
                raise ValueError(
                    "Non-tekken Mistral tokenizers are not supported for the 'guidance'"
                    " structured output backend. Please either use a more recent "
                    "Mistral model, the ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(
                self,
                tokenizer=_get_llg_tokenizer(tokenizer),
            )
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(self)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            if is_mistral_tokenizer(tokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_structured_output_request_lm_format_enforcer(self)
        else:
            # NOTE: backend is expected to be "auto" here.
            # NOTE(review): the check of supported backends is not visible in
            # this method -- presumably done when the config is built; confirm.
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(self)
                self.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar.
                skip_guidance = _is_non_tekken_mistral(tokenizer)

                # Check if schema has features unsupported by guidance
                so_params = self.structured_outputs
                if not skip_guidance and so_params.json:
                    if isinstance(so_params.json, str):
                        schema = json_mod.loads(so_params.json)
                    else:
                        schema = so_params.json
                    skip_guidance = has_guidance_unsupported_json_features(schema)

                if skip_guidance:
                    # Fall back to outlines if the tokenizer is non-tekken Mistral or
                    # the schema contains features unsupported by guidance
                    validate_structured_output_request_outlines(self)
                    self.structured_outputs._backend = "outlines"
                else:
                    # Fall back to guidance by default.
                    validate_guidance_grammar(
                        self,
                        tokenizer=_get_llg_tokenizer(tokenizer),
                    )
                    self.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            self.structured_outputs._backend_was_auto = True

        # Run post-init validation. This is also important to ensure subsequent
        # roundtrip serialization/deserialization won't fail.
        self.structured_outputs.__post_init__()

884
    def __repr__(self) -> str:
885
886
887
888
889
890
891
892
893
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
Nick Hill's avatar
Nick Hill committed
894
            f"seed={self.seed}, "
895
896
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
897
            f"bad_words={self.bad_words}, "
898
            f"thinking_token_budget={self.thinking_token_budget}, "
899
900
901
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
902
            f"min_tokens={self.min_tokens}, "
903
904
905
906
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
907
            f"{self.spaces_between_special_tokens}, "
908
            f"structured_outputs={self.structured_outputs}, "
909
910
            f"extra_args={self.extra_args})"
        )
911

912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
    @staticmethod
    def for_sampler_warmup() -> "SamplingParams":
        """Build a `SamplingParams` intended to exercise all sampler logic.

        Temperature, top-p/top-k/min-p, the three penalties, `min_tokens`,
        logit bias, bad-word token ids and both logprob settings are all given
        explicit values.
        """
        warmup_kwargs: dict[str, Any] = {
            "temperature": 0.9,
            "top_p": 0.9,
            "top_k": 50,
            "min_p": 0.1,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.5,
            "repetition_penalty": 1.2,
            "min_tokens": 2,
            "logit_bias": {0: -1.0, 1: 0.5},
            "_bad_words_token_ids": [[0], [1, 2]],
            "logprobs": 5,
            "prompt_logprobs": 1,
        }
        return SamplingParams(**warmup_kwargs)

930
931

class BeamSearchParams(
932
933
934
935
936
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
937
    """Beam search parameters for text generation."""
938

939
940
941
942
    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
943
    length_penalty: float = 1.0
944
    include_stop_str_in_output: bool = False