sampling_params.py 35.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
import json
7
from dataclasses import field
8
from enum import Enum, IntEnum
9
from functools import cached_property
10
from typing import Annotated, Any
11

12
import msgspec
13
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
16
from vllm.exceptions import VLLMValidationError
17
from vllm.logger import init_logger
18
from vllm.logits_process import LogitsProcessor
19
from vllm.tokenizers import TokenizerLike
20
from vllm.v1.serial_utils import PydanticMsgspecMixin
21
22
23

logger = init_logger(__name__)

# Temperatures below this threshold are treated as greedy sampling.
_SAMPLING_EPS = 1e-5
# Despite the name, this acts as a floor: positive temperatures below it are
# raised to this value in SamplingParams.__post_init__.
_MAX_TEMP = 1e-2
Woosuk Kwon's avatar
Woosuk Kwon committed
26

27

28
29
30
class SamplingType(IntEnum):
    """Sampling mode of a request, as reported by `SamplingParams.sampling_type`."""

    # Temperature is below the sampling epsilon.
    GREEDY = 0
    # Non-greedy sampling without a per-request seed.
    RANDOM = 1
    # Non-greedy sampling with a per-request seed.
    RANDOM_SEED = 2
32
33


34
35
# maybe make msgspec?
@dataclass
36
37
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
38
39
40
41
42
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
43
    # These are other options that can be set.
44
45
46
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
47
48
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
49

50
    _backend: str | None = field(default=None, init=False)
51
52
53
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
54
55
56

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
57
58
59
60
61
62
63
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
64
                self.structural_tag is not None,
65
66
            ]
        )
67
        if count > 1:
68
            raise ValueError(
69
                "You can only use one kind of structured outputs constraint "
70
71
                f"but multiple are specified: {self.__dict__}"
            )
72
73
74
75
76
        if count < 1:
            raise ValueError(
                "You must use one kind of structured outputs constraint "
                f"but none are specified: {self.__dict__}"
            )
77

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

109

110
111
112
113
114
class RequestOutputKind(Enum):
    """Controls what each RequestOutput of a request contains."""

    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2


119
class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    flat_logprobs: bool = False
    """Whether to return logprobs in flattened format (i.e. FlatLogprobs)
    for better performance.
    NOTE: The GC cost of FlatLogprobs is significantly smaller than that of
    list[dict[int, Logprob]]. When enabled, PromptLogprobs and
    SampleLogprobs are populated as FlatLogprobs."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
    skip_clone: bool = False
    """Internal flag indicating that this SamplingParams instance is safe to
    reuse without deep cloning. When True, clone() returns a shallow copy
    instead of performing a deep copy. This should only be set when the params
    object is guaranteed to be dedicated to a single request and won't be
    modified in ways that would affect other uses."""

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: list[list[int]] | None = None

    # When left as None, __post_init__ sets this to
    # `prompt_logprobs is not None`.
    skip_reading_prefix_cache: bool | None = None
256

257
258
    @staticmethod
    def from_optional(
        n: int | None = 1,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: list[LogitsProcessor] | None = None,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
        skip_clone: bool = False,
    ) -> "SamplingParams":
        """Build a `SamplingParams`, substituting defaults for `None` inputs.

        `n`, the three penalties, `temperature` and `top_p` fall back to
        their default value when passed as `None`. `logit_bias` keys are
        converted to `int` and its values clamped to [-100, 100]. All other
        arguments are forwarded unchanged.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        return SamplingParams(
            n=1 if n is None else n,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            skip_clone=skip_clone,
        )

331
332
    def __post_init__(self) -> None:
        """Normalize the fields, validate them and derive internal state.

        Normalization happens before `_verify_args()` because the checks
        there expect `stop` and `stop_token_ids` to already be lists.
        """
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = max(self.temperature, _MAX_TEMP)

        # A seed of -1 means "no seed".
        if self.seed == -1:
            self.seed = None

        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # A boolean True is accepted and treated as a request for 1 logprob.
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled,
            # the output of prompt logprobs may less than n_prompt_tokens,
            # we need to skip reading cache at this request.
            self.skip_reading_prefix_cache = self.prompt_logprobs is not None

385
    def _verify_args(self) -> None:
        """Check every user-facing field against its allowed range or type.

        Expects `stop` and `stop_token_ids` to have been normalized to lists
        already (done in `__post_init__`).

        Raises:
            ValueError: For out-of-range values that have no dedicated
                validation error.
            TypeError: If `top_k` is not an integer.
            VLLMValidationError: For `temperature`, `top_p`, `max_tokens`,
                `logprobs`, `prompt_logprobs` and `truncate_prompt_tokens`.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise VLLMValidationError(
                f"temperature must be non-negative, got {self.temperature}.",
                parameter="temperature",
                value=self.temperature,
            )
        if not 0.0 < self.top_p <= 1.0:
            raise VLLMValidationError(
                f"top_p must be in (0, 1], got {self.top_p}.",
                parameter="top_p",
                value=self.top_p,
            )
        # Check the type first so that a non-numeric top_k gets this message
        # rather than an opaque comparison error from the range check below.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise VLLMValidationError(
                f"max_tokens must be at least 1, got {self.max_tokens}.",
                parameter="max_tokens",
                value=self.max_tokens,
            )
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise VLLMValidationError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.",
                parameter="logprobs",
                value=self.logprobs,
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise VLLMValidationError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}.",
                parameter="prompt_logprobs",
                value=self.prompt_logprobs,
            )
        if self.logits_processors:
            # TODO: Remove `logits_processors` attribute
            raise ValueError(
                "vLLM V1 does not support per request user-provided logits processors."
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise VLLMValidationError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}",
                parameter="truncate_prompt_tokens",
                value=self.truncate_prompt_tokens,
            )
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
485
486

    def _verify_greedy_sampling(self) -> None:
487
        if self.n > 1:
488
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")
489

490
    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        model_eos_token_id: int | None = None,
    ) -> None:
        """Update if there are non-default values from generation_config

        Args:
            generation_config: Mapping that may carry an `eos_token_id`
                entry, either a single int or an iterable of ints.
            model_eos_token_id: Primary EOS token id of the model, if known.
        """
        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    # Merge with the user-provided stop token ids; duplicates
                    # collapse because the merge goes through a set.
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)
516

517
    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
        """Tokenize `bad_words` and store the result in `_bad_words_token_ids`.

        Each bad word is encoded without a leading space and, when it yields
        a distinct sequence of the same length, also with a leading space.
        Does nothing when `bad_words` is empty.

        Raises:
            VLLMValidationError: If any resulting token id lies outside
                `[0, tokenizer.max_token_id]`.
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            word = bad_word.lstrip()
            plain_ids = tokenizer.encode(text=word, add_special_tokens=False)
            if len(plain_ids) == 0:
                # A whitespace-only bad word yields no tokens: there is
                # nothing to ban, and indexing into it would fail below.
                continue
            self._bad_words_token_ids.append(plain_ids)

            spaced_ids = tokenizer.encode(text=" " + word, add_special_tokens=False)
            # Keep the prefixed variant only if the prefix space produces a
            # new word token (same length, different first token).
            if len(spaced_ids) == len(plain_ids) and spaced_ids[0] != plain_ids[0]:
                self._bad_words_token_ids.append(spaced_ids)

        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise VLLMValidationError(
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.",
                parameter="bad_words",
                value=self.bad_words,
            )
557

558
559
560
561
    @cached_property
    def sampling_type(self) -> SamplingType:
        """Sampling mode implied by `temperature` and `seed`.

        The value is cached on first access, so later changes to those
        fields are not reflected.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

566
    @property
    def all_stop_token_ids(self) -> set[int]:
        """The internal set of stop token ids (the live set, not a copy)."""
        return self._all_stop_token_ids

570
    @property
    def bad_words_token_ids(self) -> list[list[int]] | None:
        """Token id sequences stored for `bad_words`, or None if unset."""
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

575
    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087

        If skip_clone is True, uses shallow copy instead of deep copy.
        """
        if self.skip_clone:
            return copy.copy(self)

        # Pre-seed the deepcopy memo so each processor maps to its own
        # `clone()` result when it has one, and to itself otherwise; deepcopy
        # then reuses these instead of copying the processors.
        logit_processor_refs = (
            None
            if self.logits_processors is None
            else {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
        )
        return copy.deepcopy(self, memo=logit_processor_refs)

599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
    def verify(
        self,
        model_config: ModelConfig,
        speculative_config: SpeculativeConfig | None,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        """Run the config-dependent validators in a fixed order.

        The first validator that fails raises; later ones are not run.
        """
        checks = (
            (self._validate_logprobs, (model_config,)),
            (self._validate_logit_bias, (model_config,)),
            (self._validate_allowed_token_ids, (tokenizer,)),
            (self._validate_spec_decode, (speculative_config,)),
            (self._validate_structured_outputs, (structured_outputs_config, tokenizer)),
        )
        for check, arguments in checks:
            check(*arguments)

    def _validate_logprobs(self, model_config: ModelConfig) -> None:
        max_logprobs = model_config.max_logprobs
        if max_logprobs == -1:
            max_logprobs = model_config.get_vocab_size()

        # Validate sample logprobs.
        if num_logprobs := self.logprobs:
            if num_logprobs == -1:
                num_logprobs = model_config.get_vocab_size()
            if num_logprobs > max_logprobs:
                raise VLLMValidationError(
                    f"Requested sample logprobs of {num_logprobs}, "
                    f"which is greater than max allowed: {max_logprobs}",
                    parameter="logprobs",
                    value=num_logprobs,
                )

        # Validate prompt logprobs.
        if num_prompt_logprobs := self.prompt_logprobs:
            if num_prompt_logprobs == -1:
                num_prompt_logprobs = model_config.get_vocab_size()
            if num_prompt_logprobs > max_logprobs:
                raise VLLMValidationError(
                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
                    f"which is greater than max allowed: {max_logprobs}",
                    parameter="prompt_logprobs",
                    value=num_prompt_logprobs,
                )

    def _validate_logit_bias(self, model_config: ModelConfig) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not self.logit_bias:
            return

        vocab_size = model_config.get_vocab_size()
        invalid_token_ids = [
            token_id
            for token_id in self.logit_bias
            if token_id < 0 or token_id >= vocab_size
        ]

        if invalid_token_ids:
            raise VLLMValidationError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
                parameter="logit_bias",
                value=invalid_token_ids,
            )

    def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
        allowed_token_ids = self.allowed_token_ids
        if allowed_token_ids is None:
            return

        if len(allowed_token_ids) == 0:
            raise VLLMValidationError(
                "allowed_token_ids is not None and empty!",
                parameter="allowed_token_ids",
                value=allowed_token_ids,
            )

        if tokenizer is not None:
            vocab_size = len(tokenizer)
            invalid_token_ids = [
                token_id
                for token_id in allowed_token_ids
                if token_id < 0 or token_id >= vocab_size
            ]
            if invalid_token_ids:
                raise VLLMValidationError(
                    "allowed_token_ids contains out-of-vocab token id!",
                    parameter="allowed_token_ids",
                    value=invalid_token_ids,
                )

    def _validate_spec_decode(
        self,
        speculative_config: SpeculativeConfig | None,
    ) -> None:
        """Reject sampling options that speculative decoding cannot honor yet.

        This is a no-op when ``speculative_config`` is ``None``.

        Args:
            speculative_config: The speculative decoding configuration, or
                ``None`` when speculative decoding is not in use.

        Raises:
            ValueError: If ``min_tokens`` is greater than 1, ``min_p`` is
                greater than ``_SAMPLING_EPS``, or ``logit_bias`` is truthy.
        """
        if speculative_config is not None:
            # Short-circuit evaluation order is kept: min_tokens, min_p, then
            # logit_bias.
            uses_unsupported_param = (
                self.min_tokens > 1
                or self.min_p > _SAMPLING_EPS
                or self.logit_bias
            )
            if uses_unsupported_param:
                raise ValueError(
                    "The min_tokens, min_p, and logit_bias sampling parameters "
                    "are not yet supported with speculative decoding."
                )

    def _validate_structured_outputs(
        self,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        """Validate the structured-output request for the configured backend.

        Returns immediately when either ``structured_outputs_config`` or
        ``self.structured_outputs`` is ``None``. Otherwise the method:

        1. Requires a tokenizer.
        2. Reconciles the configured backend with any backend already stored
           on ``self.structured_outputs._backend``, storing the configured
           one when none is set.
        3. Rejects an empty ``choice`` list and a blank ``grammar`` string.
        4. Runs the backend-specific validator. For any backend value not
           matched by the explicit branches (expected to be ``"auto"``), it
           tries xgrammar first and, on ``ValueError``, falls back to
           outlines or guidance, recording the selected backend in
           ``_backend`` and setting ``_backend_was_auto`` to ``True``.
        5. Re-runs ``self.structured_outputs.__post_init__()``.

        Args:
            structured_outputs_config: Engine-level structured output
                configuration; its ``backend`` attribute is read here.
            tokenizer: Tokenizer in use. Only its presence and whether it is
                a ``MistralTokenizer`` are inspected in this method.

        Raises:
            ValueError: If the tokenizer is missing, the stored backend
                conflicts with the configured one, ``choice`` is an empty
                list, ``grammar`` is blank, a Mistral tokenizer is combined
                with the guidance or lm-format-enforcer backend, or a
                backend validator raises it.
        """
        if structured_outputs_config is None or self.structured_outputs is None:
            return

        if tokenizer is None:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = structured_outputs_config.backend
        if _backend := self.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and self.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            # No backend recorded on the params yet: adopt the configured one.
            self.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(self.structured_outputs.choice, list)
            and not self.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{self.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )
        # Reject empty string grammar early to avoid engine-side crashes
        # (whitespace-only strings are rejected too, because of `.strip()`).
        if (
            isinstance(self.structured_outputs.grammar, str)
            and self.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        # NOTE(review): these imports are function-local, presumably to avoid
        # import cycles or to defer loading the backends; confirm before
        # moving them to module level.
        from vllm.tokenizers.mistral import MistralTokenizer
        from vllm.v1.structured_output.backend_guidance import (
            has_guidance_unsupported_json_features,
            validate_guidance_grammar,
        )
        from vllm.v1.structured_output.backend_lm_format_enforcer import (
            validate_structured_output_request_lm_format_enforcer,
        )
        from vllm.v1.structured_output.backend_outlines import (
            validate_structured_output_request_outlines,
        )
        from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

        # Dispatch on the *configured* backend (not the value stored on the
        # params). Prefix matching is used for xgrammar and guidance; outlines
        # and lm-format-enforcer must match exactly.
        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(self)
        elif backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'guidance' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_guidance_grammar(self, tokenizer=None)
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(self)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_structured_output_request_lm_format_enforcer(self)
        else:
            # NOTE: backend is expected to be "auto" here. This relies on the
            # backend value having been checked against the supported
            # backends elsewhere; no such check is performed in this method.
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(self)
                self.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar.

                # Check if schema has features unsupported by guidance
                so_params = self.structured_outputs
                skip_guidance = False
                if so_params.json:
                    # `json` may be a JSON string or an already-parsed object;
                    # the feature check below is given the parsed form.
                    if isinstance(so_params.json, str):
                        schema = json.loads(so_params.json)
                    else:
                        schema = so_params.json
                    skip_guidance = has_guidance_unsupported_json_features(schema)

                if isinstance(tokenizer, MistralTokenizer) or skip_guidance:
                    # Fall back to outlines if the tokenizer is Mistral
                    # or if schema contains features unsupported by guidance
                    validate_structured_output_request_outlines(self)
                    self.structured_outputs._backend = "outlines"
                else:
                    # Fall back to guidance by default.
                    validate_guidance_grammar(self, tokenizer=None)
                    self.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            self.structured_outputs._backend_was_auto = True

        # Run post-init validation. This is also important to ensure subsequent
        # roundtrip serialization/deserialization won't fail.
        self.structured_outputs.__post_init__()

830
    def __repr__(self) -> str:
        """Return a debug string listing this object's sampling parameters.

        The result has the form ``SamplingParams(name=value, ...)`` and
        covers a fixed set of attributes, read directly from ``self`` and
        rendered through f-string formatting.
        """
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})"
        )
857
858
859


class BeamSearchParams(
860
861
862
863
864
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
865
    """Beam search parameters for text generation."""
866

867
868
869
870
    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
871
    length_penalty: float = 1.0
872
    include_stop_str_in_output: bool = False