sampling_params.py 34.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
import json
7
from dataclasses import field
8
from enum import Enum, IntEnum
9
from functools import cached_property
10
from typing import Annotated, Any
11

12
import msgspec
13
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
16
from vllm.exceptions import VLLMValidationError
17
from vllm.logger import init_logger
18
from vllm.tokenizers import TokenizerLike
19
from vllm.v1.serial_utils import PydanticMsgspecMixin
20
21
22

logger = init_logger(__name__)

23
# Temperatures below this threshold select greedy sampling; the same value is
# used as the "effectively zero" threshold when checking min_p.
_SAMPLING_EPS = 1e-5
# Positive temperatures smaller than this are raised to it in
# SamplingParams.__post_init__ to avoid nan/inf in tensors. Despite its name,
# it acts as a lower bound on non-zero temperatures.
_MAX_TEMP = 1e-2
Woosuk Kwon's avatar
Woosuk Kwon committed
25

26

27
28
29
class SamplingType(IntEnum):
    """Strategy used to draw tokens for a request.

    Selected by ``SamplingParams.sampling_type`` from the temperature and seed.
    """

    # Temperature below _SAMPLING_EPS: always take the most likely token.
    GREEDY = 0
    # Random sampling without a per-request seed.
    RANDOM = 1
    # Random sampling with an explicit per-request seed.
    RANDOM_SEED = 2
31
32


33
34
# maybe make msgspec?
@dataclass
35
36
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
37
38
39
40
41
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
42
    # These are other options that can be set.
43
44
45
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
46
47
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
48

49
    _backend: str | None = field(default=None, init=False)
50
51
52
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
53
54
55

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
56
57
58
59
60
61
62
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
63
                self.structural_tag is not None,
64
65
            ]
        )
66
        if count > 1:
67
            raise ValueError(
68
                "You can only use one kind of structured outputs constraint "
69
70
                f"but multiple are specified: {self.__dict__}"
            )
71
72
73
74
75
        if count < 1:
            raise ValueError(
                "You must use one kind of structured outputs constraint "
                f"but none are specified: {self.__dict__}"
            )
76

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

108

109
110
111
112
113
class RequestOutputKind(Enum):
    """How much output each RequestOutput carries while a request runs."""

    # Every RequestOutput holds the entire output generated so far.
    CUMULATIVE = 0
    # Every RequestOutput holds only what is new since the previous one.
    DELTA = 1
    # No intermediate RequestOutput is returned.
    FINAL_ONLY = 2


118
class SamplingParams(
119
    PydanticMsgspecMixin,
120
121
122
123
124
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
125
126
127
128
129
130
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
131

132
    n: int = 1
133
134
135
136
137
138
139
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
140
    presence_penalty: float = 0.0
141
142
143
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
144
    frequency_penalty: float = 0.0
145
146
147
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
148
    repetition_penalty: float = 1.0
149
150
151
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
152
    temperature: float = 1.0
153
154
155
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
156
    top_p: float = 1.0
157
158
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
159
    top_k: int = 0
160
161
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
162
    min_p: float = 0.0
163
164
165
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
166
    seed: int | None = None
167
    """Random seed to use for the generation."""
168
    stop: str | list[str] | None = None
169
170
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
171
    stop_token_ids: list[int] | None = None
172
173
174
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
175
    ignore_eos: bool = False
176
177
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
178
    max_tokens: int | None = 16
179
    """Maximum number of tokens to generate per output sequence."""
180
    min_tokens: int = 0
181
182
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
183
    logprobs: int | None = None
184
185
186
187
188
189
190
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
191
    prompt_logprobs: int | None = None
192
193
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
194
195
196
197
198
199
    flat_logprobs: bool = False
    """Whether to return logprobs in flatten format (i.e. FlatLogprob)
    for better performance.
    NOTE: GC costs of FlatLogprobs is significantly smaller than
    list[dict[int, Logprob]]. After enabled, PromptLogprobs and
    SampleLogprobs would populated as FlatLogprobs."""
200
201
202
203
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
204
    """Whether to detokenize the output."""
205
    skip_special_tokens: bool = True
206
    """Whether to skip special tokens in the output."""
207
    spaces_between_special_tokens: bool = True
208
    """Whether to add spaces between special tokens in the output."""
209
    include_stop_str_in_output: bool = False
210
    """Whether to include the stop strings in output text."""
211
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
212
213
214
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
215
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
216
217
218
219
220
221
    skip_clone: bool = False
    """Internal flag indicating that this SamplingParams instance is safe to
    reuse without cloning. When True, clone() will return self without
    performing a deep copy. This should only be set when the params object
    is guaranteed to be dedicated to a single request and won't be modified
    in ways that would affect other uses."""
222
223
224
225

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
226
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
227

228
    # Fields used to construct logits processors
229
    structured_outputs: StructuredOutputsParams | None = None
230
    """Parameters for configuring structured outputs."""
231
    logit_bias: dict[int, float] | None = None
232
233
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
234
    allowed_token_ids: list[int] | None = None
235
236
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
237
    extra_args: dict[str, Any] | None = None
238
239
240
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""
241

242
    # Fields used for bad words
243
    bad_words: list[str] | None = None
244
245
246
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
247
    _bad_words_token_ids: list[list[int]] | None = None
248

249
    skip_reading_prefix_cache: bool | None = None
250

251
252
    @staticmethod
    def from_optional(
        n: int | None = 1,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int | None = 0,
        min_p: float | None = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
        skip_clone: bool = False,
    ) -> "SamplingParams":
        """Build a SamplingParams, replacing ``None`` with field defaults.

        The numeric sampling knobs ``n``, ``presence_penalty``,
        ``frequency_penalty``, ``repetition_penalty``, ``temperature``,
        ``top_p``, ``top_k`` and ``min_p`` accept ``None``, which is replaced
        by the corresponding field default so callers can forward optional
        request values directly. All other arguments are passed through.

        ``logit_bias`` keys may be ints or numeric strings; they are converted
        to ``int`` and each bias is clamped to [-100, 100] per the OpenAI API
        spec.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        return SamplingParams(
            n=1 if n is None else n,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            # top_k/min_p previously had no None handling, so None crashed
            # later in _verify_args with an opaque comparison TypeError.
            top_k=0 if top_k is None else top_k,
            min_p=0.0 if min_p is None else min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            skip_clone=skip_clone,
        )

323
324
    def __post_init__(self) -> None:
        """Normalize raw field values, validate them, and derive internals.

        Runs once the struct has been constructed. Order matters: fields are
        normalized before ``_verify_args`` runs, and greedy-mode resets and
        derived state are applied only after validation succeeds.
        """
        # Tiny positive temperatures are floored at _MAX_TEMP because they may
        # yield nan/inf in tensors.
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = _MAX_TEMP

        # A seed of -1 means "no seed".
        if self.seed == -1:
            self.seed = None

        # Canonicalize `stop` to a list of strings.
        if isinstance(self.stop, str):
            self.stop = [self.stop]
        elif self.stop is None:
            self.stop = []

        # Optional list fields become empty lists.
        for list_field in ("stop_token_ids", "bad_words"):
            if getattr(self, list_field) is None:
                setattr(self, list_field, [])

        # `True` is shorthand for a single logprob.
        for logprob_field in ("logprobs", "prompt_logprobs"):
            if getattr(self, logprob_field) is True:
                setattr(self, logprob_field, 1)

        # Number of characters to hold back for stop string evaluation
        # until the sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(map(len, self.stop)) - 1

        self._verify_args()

        # Zero temperature means greedy sampling: reset the filtering knobs
        # to their "disabled" values.
        if self.temperature < _SAMPLING_EPS:
            self.top_p, self.top_k, self.min_p = 1.0, 0, 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        # With prefix caching, prompt logprobs may cover fewer than all prompt
        # tokens, so by default requests asking for prompt logprobs skip
        # reading the cache.
        if self.skip_reading_prefix_cache is None:
            self.skip_reading_prefix_cache = self.prompt_logprobs is not None

377
    def _verify_args(self) -> None:
        """Validate field values, raising on the first invalid one.

        Called from ``__post_init__`` after ``stop`` and ``stop_token_ids``
        have been normalized to lists.

        Raises:
            ValueError: for most out-of-range or malformed values.
            TypeError: if ``top_k`` is not an integer.
            VLLMValidationError: for invalid ``temperature``, ``top_p``,
                ``max_tokens``, ``logprobs``, ``prompt_logprobs`` or
                ``truncate_prompt_tokens``.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise VLLMValidationError(
                f"temperature must be non-negative, got {self.temperature}.",
                parameter="temperature",
                value=self.temperature,
            )
        if not 0.0 < self.top_p <= 1.0:
            raise VLLMValidationError(
                f"top_p must be in (0, 1], got {self.top_p}.",
                parameter="top_p",
                value=self.top_p,
            )
        # Type-check top_k before comparing it (as done for `n` above), so a
        # non-integer such as None gets this message instead of an opaque
        # comparison TypeError.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise VLLMValidationError(
                f"max_tokens must be at least 1, got {self.max_tokens}.",
                parameter="max_tokens",
                value=self.max_tokens,
            )
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        # -1 means "all vocab_size logprobs"; any other negative is invalid.
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise VLLMValidationError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.",
                parameter="logprobs",
                value=self.logprobs,
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise VLLMValidationError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}.",
                parameter="prompt_logprobs",
                value=self.prompt_logprobs,
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise VLLMValidationError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}",
                parameter="truncate_prompt_tokens",
                value=self.truncate_prompt_tokens,
            )
        # __post_init__ has already normalized these fields to lists.
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        # Stop strings are matched against detokenized text.
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
472
473

    def _verify_greedy_sampling(self) -> None:
        """Reject ``n > 1`` for greedy sampling.

        Called from ``__post_init__`` after ``_verify_args`` has confirmed
        that ``n`` is an int.

        Raises:
            ValueError: if more than one output is requested.
        """
        if self.n <= 1:
            return
        raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")
476

477
    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        model_eos_token_id: int | None = None,
    ) -> None:
        """Update if there are non-default values from generation_config.

        Args:
            generation_config: Only its ``"eos_token_id"`` entry is read; it
                may be an int or a list of ints.
            model_eos_token_id: The model's primary EOS token id, if known. It
                is added to ``_all_stop_token_ids`` (to support min_tokens
                processing) but never to ``stop_token_ids``, since the primary
                EOS is handled separately for stopping.

        Other EOS ids are always added to ``_all_stop_token_ids``. Unless
        ``ignore_eos`` is set, they are also merged into ``stop_token_ids``.
        """
        if model_eos_token_id is not None:
            self._all_stop_token_ids.add(model_eos_token_id)

        raw_eos = generation_config.get("eos_token_id")
        if raw_eos is None:
            return

        extra_eos = {raw_eos} if isinstance(raw_eos, int) else set(raw_eos)
        if model_eos_token_id is not None:
            extra_eos.discard(model_eos_token_id)
        if not extra_eos:
            return

        self._all_stop_token_ids.update(extra_eos)
        if self.ignore_eos:
            return
        extra_eos.update(self.stop_token_ids)
        self.stop_token_ids = list(extra_eos)
503

504
    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
        """Tokenize ``bad_words`` into ``_bad_words_token_ids``.

        Each bad word (leading whitespace stripped) is encoded as-is and with
        a single leading space, so that it is prohibited both at the beginning
        and in the middle of text (related to the add_prefix_space tokenizer
        parameter). The space-prefixed variant is kept only when its first
        token differs from the bare variant's and both have the same length.

        Raises:
            VLLMValidationError: if a bad word encodes to no tokens (e.g. it
                is empty or whitespace-only), or if any resulting token id is
                outside ``[0, tokenizer.max_token_id]``.
        """
        if not self.bad_words:
            return

        token_id_seqs: list[list[int]] = []
        for bad_word in self.bad_words:
            word = bad_word.lstrip()
            bare_ids = tokenizer.encode(text=word, add_special_tokens=False)
            if not bare_ids:
                # An empty encoding leaves nothing to ban and nothing to
                # compare the space-prefixed variant against (this used to
                # surface as an IndexError).
                raise VLLMValidationError(
                    f"bad_words entry {bad_word!r} encodes to no tokens; "
                    "empty or whitespace-only bad words are not allowed.",
                    parameter="bad_words",
                    value=self.bad_words,
                )
            token_id_seqs.append(bare_ids)

            prefixed_ids = tokenizer.encode(text=" " + word, add_special_tokens=False)
            # Keep the prefixed form only if the space produces a new word
            # token rather than a separate leading-space token.
            if (
                prefixed_ids
                and prefixed_ids[0] != bare_ids[0]
                and len(prefixed_ids) == len(bare_ids)
            ):
                token_id_seqs.append(prefixed_ids)

        self._bad_words_token_ids = token_id_seqs

        max_token_id = tokenizer.max_token_id
        invalid_token_ids = [
            token_id
            for seq in token_id_seqs
            for token_id in seq
            if token_id < 0 or token_id > max_token_id
        ]
        if invalid_token_ids:
            raise VLLMValidationError(
                f"The model vocabulary size is {max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {max_token_id}.",
                parameter="bad_words",
                value=self.bad_words,
            )
544

545
546
547
548
    @cached_property
    def sampling_type(self) -> SamplingType:
        """Classify these params; computed once and cached on first access.

        GREEDY when the temperature is below ``_SAMPLING_EPS``, otherwise
        RANDOM_SEED if a seed is set, else RANDOM.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        return SamplingType.RANDOM if self.seed is None else SamplingType.RANDOM_SEED

553
    @property
    def all_stop_token_ids(self) -> set[int]:
        """Every stop token id, including EOS ids added from the model config.

        This is the internal set itself, not a copy.
        """
        stop_ids = self._all_stop_token_ids
        return stop_ids

557
    @property
    def bad_words_token_ids(self) -> list[list[int]] | None:
        """Token-id sequences derived from ``bad_words``.

        Populated by ``update_from_tokenizer``; ``None`` before that or when
        there are no bad words. For internal use only. Backward compatibility
        not guaranteed.
        """
        token_id_seqs = self._bad_words_token_ids
        return token_id_seqs

562
    def clone(self) -> "SamplingParams":
        """Return a copy of these params.

        A deep copy by default. When ``skip_clone`` is True a shallow copy is
        returned instead, so mutable fields are shared with the original.
        """
        copier = copy.copy if self.skip_clone else copy.deepcopy
        return copier(self)
568

569
570
571
572
573
574
575
576
577
    def verify(
        self,
        model_config: ModelConfig,
        speculative_config: SpeculativeConfig | None,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        """Run the checks that depend on model/engine configuration.

        The checks run in a fixed order and the first failure propagates.
        """
        # Checks that only need the model configuration.
        self._validate_logprobs(model_config)
        self._validate_logit_bias(model_config)
        self._validate_logits_processors(model_config)
        # Checks that need the tokenizer or engine feature configuration.
        self._validate_allowed_token_ids(tokenizer)
        self._validate_spec_decode(speculative_config)
        self._validate_structured_outputs(structured_outputs_config, tokenizer)

    def _validate_logprobs(self, model_config: ModelConfig) -> None:
        """Ensure requested sample/prompt logprobs do not exceed the model cap.

        A cap (``model_config.max_logprobs``) of -1 means the vocabulary size,
        and a requested value of -1 means the whole vocabulary. Requests of 0
        or ``None`` are not checked. Sample logprobs are checked first.

        Raises:
            VLLMValidationError: if a request exceeds the cap.
        """
        limit = model_config.max_logprobs
        if limit == -1:
            limit = model_config.get_vocab_size()

        checks = (
            ("sample", "logprobs", self.logprobs),
            ("prompt", "prompt_logprobs", self.prompt_logprobs),
        )
        for label, param_name, requested in checks:
            if not requested:
                continue
            if requested == -1:
                requested = model_config.get_vocab_size()
            if requested > limit:
                raise VLLMValidationError(
                    f"Requested {label} logprobs of {requested}, "
                    f"which is greater than max allowed: {limit}",
                    parameter=param_name,
                    value=requested,
                )

    def _validate_logit_bias(self, model_config: ModelConfig) -> None:
        """Validate logit_bias token IDs are within vocabulary range.

        Raises:
            VLLMValidationError: listing every token id outside
                ``[0, vocab_size)``.
        """
        bias = self.logit_bias
        if not bias:
            return

        vocab_size = model_config.get_vocab_size()
        out_of_range = [tid for tid in bias if not 0 <= tid < vocab_size]
        if not out_of_range:
            return
        raise VLLMValidationError(
            f"token_id(s) {out_of_range} in logit_bias contain "
            f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
            parameter="logit_bias",
            value=out_of_range,
        )

632
633
634
635
636
637
638
    def _validate_logits_processors(self, model_config: ModelConfig) -> None:
        """Pass these params to the model's configured logits processors for validation."""
        # Imported lazily; presumably to avoid an import cycle or import cost
        # at module load (TODO confirm).
        from vllm.v1.sample.logits_processor import (
            validate_logits_processors_parameters as _validate_lp_params,
        )

        configured_processors = model_config.logits_processors
        _validate_lp_params(configured_processors, self)

639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
    def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
        """Check ``allowed_token_ids`` is non-empty and within the vocabulary.

        The range check (``0 <= id < len(tokenizer)``) only runs when a
        tokenizer is available.

        Raises:
            VLLMValidationError: on an empty list or out-of-vocab ids.
        """
        token_ids = self.allowed_token_ids
        if token_ids is None:
            return

        if len(token_ids) == 0:
            raise VLLMValidationError(
                "allowed_token_ids is not None and empty!",
                parameter="allowed_token_ids",
                value=token_ids,
            )

        if tokenizer is None:
            return

        vocab_size = len(tokenizer)
        out_of_vocab = [tid for tid in token_ids if not 0 <= tid < vocab_size]
        if out_of_vocab:
            raise VLLMValidationError(
                "allowed_token_ids contains out-of-vocab token id!",
                parameter="allowed_token_ids",
                value=out_of_vocab,
            )

    def _validate_spec_decode(
        self,
        speculative_config: SpeculativeConfig | None,
    ) -> None:
        """Reject sampling options not yet supported with speculative decoding.

        Raises:
            ValueError: if speculative decoding is configured and the request
                uses ``min_tokens > 1``, a non-zero ``min_p`` or ``logit_bias``.
        """
        if speculative_config is None:
            return

        unsupported = (
            self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias
        )
        if unsupported:
            raise ValueError(
                "The min_tokens, min_p, and logit_bias sampling parameters "
                "are not yet supported with speculative decoding."
            )

    def _validate_structured_outputs(
        self,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        """Validate this request's structured-output params and pick a backend.

        Returns immediately when ``structured_outputs_config`` is None or the
        request has no ``structured_outputs``. Otherwise it checks the request
        contents, runs the validator for the engine's configured backend, and
        records the backend on the params:

        * ``self.structured_outputs._backend`` is set to the backend used for
          this request. When the engine backend is ``"auto"``, the concrete
          backend chosen here ("xgrammar", "guidance" or "outlines") is
          stored once validation succeeds.
        * ``self.structured_outputs._backend_was_auto`` is set to True when
          the selection went through the automatic path. A later request that
          reuses the same params is then not mistaken for request-level
          backend selection.

        Finally, ``self.structured_outputs.__post_init__()`` is re-run.

        Args:
            structured_outputs_config: Engine-level structured output config.
                Only its ``backend`` attribute is read here.
            tokenizer: Tokenizer for the request. It must not be None.
                Mistral tokenizers are rejected by the "guidance" and
                "lm-format-enforcer" backends. In automatic mode they make
                the fallback use "outlines" instead of "guidance".

        Raises:
            ValueError: In any of these cases:

                * ``tokenizer`` is None.
                * The params already carry a ``_backend`` that differs from
                  the engine backend. This is allowed only when that value
                  was chosen automatically and the engine backend is "auto".
                * ``choice`` is an empty list.
                * ``grammar`` is empty or contains only whitespace.
                * A Mistral tokenizer is used with "guidance" or
                  "lm-format-enforcer".

                The backend validators called here may raise as well. In
                automatic mode, failures from the xgrammar validator are
                caught as ValueError.
        """
        # Nothing to validate unless the engine is configured for structured
        # outputs AND this request actually asks for them.
        if structured_outputs_config is None or self.structured_outputs is None:
            return

        if tokenizer is None:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = structured_outputs_config.backend
        if _backend := self.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and self.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            # First validation of these params: record the engine backend.
            # In "auto" mode this provisional value is replaced further down
            # once a concrete backend has been selected.
            self.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(self.structured_outputs.choice, list)
            and not self.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{self.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )
        # Reject empty string grammar early to avoid engine-side crashes
        if (
            isinstance(self.structured_outputs.grammar, str)
            and self.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        # NOTE(review): these imports are deferred to call time. Presumably
        # this avoids import cycles or loading backend modules at import time
        # (confirm).
        from vllm.tokenizers.mistral import MistralTokenizer
        from vllm.v1.structured_output.backend_guidance import (
            has_guidance_unsupported_json_features,
            validate_guidance_grammar,
        )
        from vllm.v1.structured_output.backend_lm_format_enforcer import (
            validate_structured_output_request_lm_format_enforcer,
        )
        from vllm.v1.structured_output.backend_outlines import (
            validate_structured_output_request_outlines,
        )
        from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

        # Dispatch on the engine-configured backend name. The "xgrammar" and
        # "guidance" checks match by prefix; the others must match exactly.
        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(self)
        elif backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'guidance' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_guidance_grammar(self, tokenizer=None)
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(self)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_structured_output_request_lm_format_enforcer(self)
        else:
            # NOTE: any backend name not matched above lands here. This
            # method does not itself restrict the backend name. Presumably
            # only "auto" can reach this point because the allowed names are
            # checked where StructuredOutputsConfig is validated (confirm).
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(self)
                self.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar.

                # Check if schema has features unsupported by guidance
                so_params = self.structured_outputs
                skip_guidance = False
                if so_params.json:
                    # NOTE(review): a malformed JSON string makes json.loads
                    # raise json.JSONDecodeError (a ValueError subclass).
                    # That error is chained onto the xgrammar failure above.
                    if isinstance(so_params.json, str):
                        schema = json.loads(so_params.json)
                    else:
                        schema = so_params.json
                    skip_guidance = has_guidance_unsupported_json_features(schema)

                if isinstance(tokenizer, MistralTokenizer) or skip_guidance:
                    # Fall back to outlines if the tokenizer is Mistral
                    # or if schema contains features unsupported by guidance
                    validate_structured_output_request_outlines(self)
                    self.structured_outputs._backend = "outlines"
                else:
                    # Fall back to guidance by default.
                    validate_guidance_grammar(self, tokenizer=None)
                    self.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            self.structured_outputs._backend_was_auto = True

        # Run post-init validation. This is also important to ensure subsequent
        # roundtrip serialization/deserialization won't fail.
        self.structured_outputs.__post_init__()

808
    def __repr__(self) -> str:
        """Return a single-line, human-readable summary of these parameters.

        The result has the form ``SamplingParams(n=..., ..., extra_args=...)``.
        It lists a fixed selection of options in a fixed order. Not every
        attribute is shown; for example, ``logit_bias`` is omitted.
        """
        # Adjacent string literals are concatenated at compile time. The
        # parenthesised expression must contain only string fragments.
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})"
        )
835
836
837


class BeamSearchParams(
838
839
840
841
842
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
843
    """Beam search parameters for text generation."""
844

845
846
847
848
    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
849
    length_penalty: float = 1.0
850
    include_stop_str_in_output: bool = False