sampling_params.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
from dataclasses import field
7
from enum import Enum, IntEnum
8
from functools import cached_property
9
from typing import Annotated, Any
10

11
import msgspec
12
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
from vllm.exceptions import VLLMValidationError
15
from vllm.logger import init_logger
16
from vllm.logits_process import LogitsProcessor
17
from vllm.tokenizers import TokenizerLike
18
from vllm.v1.serial_utils import PydanticMsgspecMixin
19
20
21

logger = init_logger(__name__)

22
_SAMPLING_EPS = 1e-5
23
_MAX_TEMP = 1e-2
Woosuk Kwon's avatar
Woosuk Kwon committed
24

25

26
27
28
class SamplingType(IntEnum):
    """Sampling strategy categories, stored as explicit integer codes."""

    # Members are IntEnum values, so they compare equal to plain ints.
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2
30
31


32
33
# maybe make msgspec?
@dataclass
34
35
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
36
37
38
39
40
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
41
    # These are other options that can be set.
42
43
44
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
45
46
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
47

48
    _backend: str | None = field(default=None, init=False)
49
50
51
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
52
53
54

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
55
56
57
58
59
60
61
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
62
                self.structural_tag is not None,
63
64
            ]
        )
65
        if count > 1:
66
            raise ValueError(
67
                "You can only use one kind of structured outputs constraint "
68
69
                f"but multiple are specified: {self.__dict__}"
            )
70
71
72
73
74
        if count < 1:
            raise ValueError(
                "You must use one kind of structured outputs constraint "
                f"but none are specified: {self.__dict__}"
            )
75

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

107

108
109
110
111
112
class RequestOutputKind(Enum):
    """Selects how much output each RequestOutput carries."""

    CUMULATIVE = 0  # Return entire output so far in every RequestOutput
    DELTA = 1  # Return only deltas in each RequestOutput
    FINAL_ONLY = 2  # Do not return intermediate RequestOutput


117
class SamplingParams(
118
    PydanticMsgspecMixin,
119
120
121
122
123
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
124
125
126
127
128
129
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
130

131
    n: int = 1
132
133
134
135
136
137
138
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
139
    presence_penalty: float = 0.0
140
141
142
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
143
    frequency_penalty: float = 0.0
144
145
146
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
147
    repetition_penalty: float = 1.0
148
149
150
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
151
    temperature: float = 1.0
152
153
154
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
155
    top_p: float = 1.0
156
157
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
158
    top_k: int = 0
159
160
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
161
    min_p: float = 0.0
162
163
164
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
165
    seed: int | None = None
166
    """Random seed to use for the generation."""
167
    stop: str | list[str] | None = None
168
169
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
170
    stop_token_ids: list[int] | None = None
171
172
173
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
174
    ignore_eos: bool = False
175
176
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
177
    max_tokens: int | None = 16
178
    """Maximum number of tokens to generate per output sequence."""
179
    min_tokens: int = 0
180
181
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
182
    logprobs: int | None = None
183
184
185
186
187
188
189
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
190
    prompt_logprobs: int | None = None
191
192
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
193
194
195
196
197
198
    flat_logprobs: bool = False
    """Whether to return logprobs in flatten format (i.e. FlatLogprob)
    for better performance.
    NOTE: GC costs of FlatLogprobs is significantly smaller than
    list[dict[int, Logprob]]. After enabled, PromptLogprobs and
    SampleLogprobs would populated as FlatLogprobs."""
199
200
201
202
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
203
    """Whether to detokenize the output."""
204
    skip_special_tokens: bool = True
205
    """Whether to skip special tokens in the output."""
206
    spaces_between_special_tokens: bool = True
207
    """Whether to add spaces between special tokens in the output."""
208
209
210
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
211
212
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
213
    include_stop_str_in_output: bool = False
214
    """Whether to include the stop strings in output text."""
215
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
216
217
218
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
219
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
220
221
222
223
224
225
    skip_clone: bool = False
    """Internal flag indicating that this SamplingParams instance is safe to
    reuse without cloning. When True, clone() will return self without
    performing a deep copy. This should only be set when the params object
    is guaranteed to be dedicated to a single request and won't be modified
    in ways that would affect other uses."""
226
227
228
229

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
230
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
231

232
    # Fields used to construct logits processors
233
    structured_outputs: StructuredOutputsParams | None = None
234
    """Parameters for configuring structured outputs."""
235
    logit_bias: dict[int, float] | None = None
236
237
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
238
    allowed_token_ids: list[int] | None = None
239
240
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
241
    extra_args: dict[str, Any] | None = None
242
243
244
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""
245

246
    # Fields used for bad words
247
    bad_words: list[str] | None = None
248
249
250
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
251
    _bad_words_token_ids: list[list[int]] | None = None
252

253
    skip_reading_prefix_cache: bool | None = None
254

255
256
    @staticmethod
    def from_optional(
        n: int | None = 1,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: list[LogitsProcessor] | None = None,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
        skip_clone: bool = False,
    ) -> "SamplingParams":
        """Build a `SamplingParams`, tolerating `None` for selected arguments.

        `n`, `presence_penalty`, `frequency_penalty`, `repetition_penalty`,
        `temperature` and `top_p` fall back to 1, 0.0, 0.0, 1.0, 1.0 and 1.0
        respectively when passed as `None`. `logit_bias` keys are converted
        to `int` and its values are clamped to [-100, 100]. Every other
        argument is forwarded unchanged.
        """
        normalized_bias = None
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            normalized_bias = {}
            for token, bias in logit_bias.items():
                normalized_bias[int(token)] = min(100.0, max(-100.0, bias))

        # Substitute defaults for the arguments that accept None.
        if n is None:
            n = 1
        if presence_penalty is None:
            presence_penalty = 0.0
        if frequency_penalty is None:
            frequency_penalty = 0.0
        if repetition_penalty is None:
            repetition_penalty = 1.0
        if temperature is None:
            temperature = 1.0
        if top_p is None:
            top_p = 1.0

        return SamplingParams(
            n=n,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=normalized_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            skip_clone=skip_clone,
        )

329
330
    def __post_init__(self) -> None:
        """Normalize field values, then validate them.

        Steps, in order: raise very small positive temperatures to
        `_MAX_TEMP`, map `seed == -1` to `None`, turn `stop`,
        `stop_token_ids` and `bad_words` into lists, map boolean `True`
        logprob settings to 1, compute the stop-string hold-back length,
        run `_verify_args`, apply the greedy-sampling overrides, seed
        `_all_stop_token_ids`, and default `skip_reading_prefix_cache`.
        """
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            # Inside this branch temperature < _MAX_TEMP, so the max of the
            # two is always _MAX_TEMP.
            self.temperature = _MAX_TEMP

        if self.seed == -1:
            self.seed = None

        given_stop = self.stop
        if given_stop is None:
            self.stop = []
        elif isinstance(given_stop, str):
            self.stop = [given_stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(map(len, self.stop)) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p, self.top_k, self.min_p = 1.0, 0, 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled, the output of prompt logprobs
            # may be shorter than n_prompt_tokens, so reading the cache is
            # skipped for requests that ask for prompt logprobs.
            wants_prompt_logprobs = self.prompt_logprobs is not None
            self.skip_reading_prefix_cache = wants_prompt_logprobs

383
    def _verify_args(self) -> None:
384
        if not isinstance(self.n, int):
385
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
386
387
388
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
389
390
391
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
392
        if not -2.0 <= self.frequency_penalty <= 2.0:
393
394
395
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
396
397
398
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
399
400
                f"{self.repetition_penalty}."
            )
401
        if self.temperature < 0.0:
402
403
404
405
            raise VLLMValidationError(
                f"temperature must be non-negative, got {self.temperature}.",
                parameter="temperature",
                value=self.temperature,
406
            )
407
        if not 0.0 < self.top_p <= 1.0:
408
409
410
411
412
            raise VLLMValidationError(
                f"top_p must be in (0, 1], got {self.top_p}.",
                parameter="top_p",
                value=self.top_p,
            )
413
414
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
415
416
417
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
418
419
        if not isinstance(self.top_k, int):
            raise TypeError(
420
421
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
Roy's avatar
Roy committed
422
        if not 0.0 <= self.min_p <= 1.0:
423
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
424
        if self.max_tokens is not None and self.max_tokens < 1:
425
426
427
428
429
            raise VLLMValidationError(
                f"max_tokens must be at least 1, got {self.max_tokens}.",
                parameter="max_tokens",
                value=self.max_tokens,
            )
430
        if self.min_tokens < 0:
431
432
433
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
434
435
436
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
437
438
439
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
440
441
442
443
            raise VLLMValidationError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.",
                parameter="logprobs",
                value=self.logprobs,
444
445
446
447
448
449
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
450
            raise VLLMValidationError(
451
                f"prompt_logprobs must be non-negative or -1, got "
452
453
454
                f"{self.prompt_logprobs}.",
                parameter="prompt_logprobs",
                value=self.prompt_logprobs,
455
456
457
458
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
459
            raise VLLMValidationError(
460
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
461
462
463
                f"got {self.truncate_prompt_tokens}",
                parameter="truncate_prompt_tokens",
                value=self.truncate_prompt_tokens,
464
            )
465
466
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
467
468
469
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
470
        assert isinstance(self.stop, list)
471
472
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
473
474
475
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
476
477
                "Set detokenize=True to use stop."
            )
478
479

    def _verify_greedy_sampling(self) -> None:
480
        if self.n > 1:
481
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")
482

483
    def update_from_generation_config(
484
485
        self,
        generation_config: dict[str, Any],
486
        model_eos_token_id: int | None = None,
487
    ) -> None:
488
        """Update if there are non-default values from generation_config"""
489
490
491
492

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
493
            self._all_stop_token_ids.add(model_eos_token_id)
494

495
        # Update eos_token_id for generation
496
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
497
            # it can be either int or list of int
498
499
500
501
502
503
504
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
505
                self._all_stop_token_ids.update(eos_ids)
506
507
508
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)
509

510
    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
511
        if not self.bad_words:
512
            return
513
        self._bad_words_token_ids = []
514
515
516
517
518
519
520
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()
521
522
523
                prompt_token_ids = tokenizer.encode(
                    text=prompt, add_special_tokens=False
                )
524
525
526
527

                # If no space at the beginning
                # or if prefix space produces a new word token
                if (not add_prefix_space) or (
528
529
530
531
                    add_prefix_space
                    and prompt_token_ids[0] != self._bad_words_token_ids[-1][0]
                    and len(prompt_token_ids) == len(self._bad_words_token_ids[-1])
                ):
532
533
534
                    self._bad_words_token_ids.append(prompt_token_ids)

        invalid_token_ids = [
535
536
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
537
538
539
540
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
541
            raise VLLMValidationError(
542
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
543
544
545
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
546
547
548
                f" 0 <= token_id <= {tokenizer.max_token_id}.",
                parameter="bad_words",
                value=self.bad_words,
549
            )
550

551
552
553
554
    @cached_property
    def sampling_type(self) -> SamplingType:
        """Classify this request's sampling mode (computed once, then cached).

        Returns `GREEDY` when the temperature is below `_SAMPLING_EPS`,
        otherwise `RANDOM_SEED` if a seed is set, otherwise `RANDOM`.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        return SamplingType.RANDOM if self.seed is None else SamplingType.RANDOM_SEED

559
    @property
560
    def all_stop_token_ids(self) -> set[int]:
561
562
        return self._all_stop_token_ids

563
    @property
564
    def bad_words_token_ids(self) -> list[list[int]] | None:
565
566
567
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

568
    def clone(self) -> "SamplingParams":
569
        """Deep copy, but maybe not the LogitsProcessor objects.
570

571
572
573
        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
574
        See https://github.com/vllm-project/vllm/issues/3087
575
576

        If skip_clone is True, uses shallow copy instead of deep copy.
577
578
        """

579
580
581
        if self.skip_clone:
            return copy.copy(self)

582
583
584
585
586
587
588
589
        logit_processor_refs = (
            None
            if self.logits_processors is None
            else {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
        )
590
591
        return copy.deepcopy(self, memo=logit_processor_refs)

592
    def __repr__(self) -> str:
593
594
595
596
597
598
599
600
601
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
Nick Hill's avatar
Nick Hill committed
602
            f"seed={self.seed}, "
603
604
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
605
            f"bad_words={self.bad_words}, "
606
607
608
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
609
            f"min_tokens={self.min_tokens}, "
610
611
612
613
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
614
            f"{self.spaces_between_special_tokens}, "
615
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
616
            f"structured_outputs={self.structured_outputs}, "
617
618
            f"extra_args={self.extra_args})"
        )
619
620
621


class BeamSearchParams(
622
623
624
625
626
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
627
    """Beam search parameters for text generation."""
628

629
630
631
632
    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
633
    length_penalty: float = 1.0
634
    include_stop_str_in_output: bool = False