sampling_params.py 25.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
from dataclasses import field
7
from enum import Enum, IntEnum
8
from functools import cached_property
9
from typing import Annotated, Any
10

11
import msgspec
12
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
from vllm.exceptions import VLLMValidationError
15
from vllm.logger import init_logger
16
from vllm.logits_process import LogitsProcessor
17
from vllm.tokenizers import TokenizerLike
18
from vllm.v1.serial_utils import PydanticMsgspecMixin
19
20
21

# Module-level logger; used below to warn when a tiny temperature is raised.
logger = init_logger(__name__)

# Temperatures strictly below this threshold are treated as greedy sampling
# (see SamplingParams.__post_init__ and SamplingParams.sampling_type).
_SAMPLING_EPS = 1e-5
# Despite the name, this acts as a floor: a temperature in the open interval
# (0, _MAX_TEMP) is raised to _MAX_TEMP in SamplingParams.__post_init__.
_MAX_TEMP = 1e-2
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25

26
27
28
class SamplingType(IntEnum):
    """How the next token is selected for a request.

    `SamplingParams.sampling_type` maps a request onto one of these members.
    """

    # Temperature is (effectively) zero.
    GREEDY = 0
    # Non-zero temperature and no seed given.
    RANDOM = 1
    # Non-zero temperature with an explicit seed.
    RANDOM_SEED = 2
30
31


32
33
# maybe make msgspec?
@dataclass
34
35
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
36
37
38
39
40
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
41
    # These are other options that can be set.
42
43
44
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
45
46
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
47

48
    _backend: str | None = field(default=None, init=False)
49
50
51
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
52
53
54

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
55
56
57
58
59
60
61
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
62
                self.structural_tag is not None,
63
64
            ]
        )
65
        if count > 1:
66
            raise ValueError(
67
                "You can only use one kind of structured outputs constraint "
68
69
                f"but multiple are specified: {self.__dict__}"
            )
70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

102

103
104
105
106
107
class RequestOutputKind(Enum):
    """Selects what each RequestOutput of a request contains."""

    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2


112
class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    flat_logprobs: bool = False
    """Whether to return logprobs in flattened format (i.e. FlatLogprobs)
    for better performance.
    NOTE: GC costs of FlatLogprobs are significantly smaller than those of
    list[dict[int, Logprob]]. When enabled, PromptLogprobs and
    SampleLogprobs are populated as FlatLogprobs."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
    skip_clone: bool = False
    """Internal flag indicating that this SamplingParams instance is safe to
    reuse without deep copying. When True, clone() returns a shallow copy
    instead of performing a deep copy. This should only be set when the params
    object is guaranteed to be dedicated to a single request and won't be
    modified in ways that would affect other uses."""

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: list[list[int]] | None = None

    skip_reading_prefix_cache: bool | None = None
249

250
251
    @staticmethod
    def from_optional(
        n: int | None = 1,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: list[LogitsProcessor] | None = None,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
        skip_clone: bool = False,
    ) -> "SamplingParams":
        """Build a `SamplingParams`, tolerating `None` for some arguments.

        `n`, the three penalties, `temperature` and `top_p` fall back to their
        default values when passed as `None`. `logit_bias` keys are converted
        to `int` and its values are clamped to [-100, 100]. All other
        arguments are forwarded unchanged.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        return SamplingParams(
            n=n if n is not None else 1,
            presence_penalty=(
                presence_penalty if presence_penalty is not None else 0.0
            ),
            frequency_penalty=(
                frequency_penalty if frequency_penalty is not None else 0.0
            ),
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else 1.0
            ),
            temperature=temperature if temperature is not None else 1.0,
            top_p=top_p if top_p is not None else 1.0,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            skip_clone=skip_clone,
        )

324
325
    def __post_init__(self) -> None:
        """Normalize the fields, validate them, and derive internal state.

        In order: raises a tiny positive temperature to `_MAX_TEMP`, maps
        `seed == -1` to `None`, turns `stop`, `stop_token_ids` and `bad_words`
        into lists, maps `True` logprob values to 1, computes
        `output_text_buffer_length`, runs `_verify_args`, resets
        `top_p`/`top_k`/`min_p` for greedy sampling, seeds
        `_all_stop_token_ids`, and defaults `skip_reading_prefix_cache`.
        """
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            # Inside this branch temperature < _MAX_TEMP, so this is the
            # same as max(self.temperature, _MAX_TEMP).
            self.temperature = _MAX_TEMP

        if self.seed == -1:
            self.seed = None

        # Normalize `stop` to a list of strings.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # A literal True is accepted and means "one logprob".
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            longest_stop = max(len(s) for s in self.stop)
            self.output_text_buffer_length = longest_stop - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled, the number of prompt logprobs
            # returned may be less than n_prompt_tokens, so we need to skip
            # reading the cache for this request.
            self.skip_reading_prefix_cache = self.prompt_logprobs is not None

378
    def _verify_args(self) -> None:
        """Check that every user-facing field holds an acceptable value.

        Expects `stop` and `stop_token_ids` to already be lists, as arranged
        by `__post_init__`.

        Raises:
            ValueError: For an out-of-range `n`, penalty, `top_k`, `min_p`,
                `min_tokens`, non-integer `stop_token_ids`, an empty stop
                string, or stop strings combined with `detokenize=False`.
            TypeError: If `top_k` is not an integer.
            VLLMValidationError: For an invalid `temperature`, `top_p`,
                `max_tokens`, `logprobs`, `prompt_logprobs` or
                `truncate_prompt_tokens`.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise VLLMValidationError(
                f"temperature must be non-negative, got {self.temperature}.",
                parameter="temperature",
                value=self.temperature,
            )
        if not 0.0 < self.top_p <= 1.0:
            raise VLLMValidationError(
                f"top_p must be in (0, 1], got {self.top_p}.",
                parameter="top_p",
                value=self.top_p,
            )
        # Check the type before the range, so that a non-integer value gets
        # the dedicated TypeError instead of failing inside the comparison.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise VLLMValidationError(
                f"max_tokens must be at least 1, got {self.max_tokens}.",
                parameter="max_tokens",
                value=self.max_tokens,
            )
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        # -1 is the only negative value allowed for the two logprob counts.
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise VLLMValidationError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.",
                parameter="logprobs",
                value=self.logprobs,
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise VLLMValidationError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}.",
                parameter="prompt_logprobs",
                value=self.prompt_logprobs,
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise VLLMValidationError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}",
                parameter="truncate_prompt_tokens",
                value=self.truncate_prompt_tokens,
            )
        # The asserts narrow the declared union types to `list`.
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
473
474

    def _verify_greedy_sampling(self) -> None:
        """Reject `n > 1` when greedy sampling is in effect.

        Raises:
            ValueError: If `self.n` is greater than 1.
        """
        if self.n > 1:
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")
477

478
    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        model_eos_token_id: int | None = None,
    ) -> None:
        """Update if there are non-default values from generation_config

        Args:
            generation_config: Mapping that may contain an `"eos_token_id"`
                entry holding either a single int or an iterable of ints.
            model_eos_token_id: Primary EOS token id of the model, if any.
        """
        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        configured = generation_config.get("eos_token_id")
        if configured is None:
            return

        # it can be either int or list of int
        extra_eos_ids = {configured} if isinstance(configured, int) else set(configured)
        if model_eos_token_id is not None:
            # We don't need to include the primary eos_token_id in
            # stop_token_ids since it's handled separately for stopping
            # purposes.
            extra_eos_ids.discard(model_eos_token_id)
        if not extra_eos_ids:
            return

        self._all_stop_token_ids.update(extra_eos_ids)
        if not self.ignore_eos:
            # Merge the additional EOS ids with the existing stop token ids.
            extra_eos_ids.update(self.stop_token_ids)
            self.stop_token_ids = list(extra_eos_ids)
504

505
    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
        """Tokenize `bad_words` and store the result in `_bad_words_token_ids`.

        Each bad word is encoded without a leading space, and additionally
        with one when that yields a different first token and the same number
        of tokens. Bad words that encode to no tokens are skipped. Does
        nothing when `bad_words` is empty.

        Raises:
            VLLMValidationError: If an encoded token id lies outside
                `[0, tokenizer.max_token_id]`.
        """
        if not self.bad_words:
            return

        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            stripped = bad_word.lstrip()
            base_ids = tokenizer.encode(text=stripped, add_special_tokens=False)
            if not base_ids:
                # Nothing to ban (empty or whitespace-only word); indexing
                # into an empty encoding below would raise IndexError.
                continue
            self._bad_words_token_ids.append(base_ids)

            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            spaced_ids = tokenizer.encode(
                text=" " + stripped, add_special_tokens=False
            )
            # Keep the prefixed form only if prefix space produces a new
            # word token.
            if (
                spaced_ids
                and spaced_ids[0] != base_ids[0]
                and len(spaced_ids) == len(base_ids)
            ):
                self._bad_words_token_ids.append(spaced_ids)

        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if invalid_token_ids:
            raise VLLMValidationError(
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.",
                parameter="bad_words",
                value=self.bad_words,
            )
545

546
547
548
549
    @cached_property
    def sampling_type(self) -> SamplingType:
        """Sampling strategy implied by `temperature` and `seed`.

        Computed once per instance and then cached.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is None:
            return SamplingType.RANDOM
        return SamplingType.RANDOM_SEED

554
    @property
    def all_stop_token_ids(self) -> set[int]:
        """The internal `_all_stop_token_ids` set (returned as-is, not copied)."""
        return self._all_stop_token_ids

558
    @property
    def bad_words_token_ids(self) -> list[list[int]] | None:
        """Token-id sequences for `bad_words`, or None if not yet computed."""
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

563
    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087

        If skip_clone is True, uses shallow copy instead of deep copy.
        """
        if self.skip_clone:
            return copy.copy(self)

        if self.logits_processors is None:
            memo = None
        else:
            # Pre-populate the deepcopy memo so that each processor is
            # replaced by its own clone() result when it has one, and is
            # otherwise shared with the copy rather than deep-copied.
            memo = {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
        return copy.deepcopy(self, memo=memo)

587
    def __repr__(self) -> str:
        """Return `SamplingParams(name=value, ...)` for a fixed set of fields.

        Only the fields listed below are shown, in this order.
        """
        shown_fields = (
            "n",
            "presence_penalty",
            "frequency_penalty",
            "repetition_penalty",
            "temperature",
            "top_p",
            "top_k",
            "min_p",
            "seed",
            "stop",
            "stop_token_ids",
            "bad_words",
            "include_stop_str_in_output",
            "ignore_eos",
            "max_tokens",
            "min_tokens",
            "logprobs",
            "prompt_logprobs",
            "skip_special_tokens",
            "spaces_between_special_tokens",
            "truncate_prompt_tokens",
            "structured_outputs",
            "extra_args",
        )
        rendered = ", ".join(
            f"{name}={getattr(self, name)}" for name in shown_fields
        )
        return f"SamplingParams({rendered})"
614
615
616


class BeamSearchParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""

    # These two fields have no default, so they must always be supplied.
    beam_width: int
    max_tokens: int
    # Optional settings with their defaults.
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False