sampling_params.py 24.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
from dataclasses import field
7
from enum import Enum, IntEnum
8
from functools import cached_property
9
from typing import Annotated, Any
10

11
import msgspec
12
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
from vllm.logger import init_logger
15
from vllm.logits_process import LogitsProcessor
16
from vllm.tokenizers import TokenizerLike
17
from vllm.v1.serial_utils import PydanticMsgspecMixin
18
19
20

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Temperatures below this threshold are treated as greedy sampling.
_SAMPLING_EPS = 1e-5
# Positive temperatures below this value are raised to it in
# SamplingParams.__post_init__ to avoid numerical errors (nan/inf).
_MAX_TEMP = 1e-2
23

24

25
26
27
class SamplingType(IntEnum):
    """Sampling strategy, as reported by `SamplingParams.sampling_type`."""

    # Temperature is (effectively) zero.
    GREEDY = 0
    # Non-zero temperature without a seed.
    RANDOM = 1
    # Non-zero temperature with an explicit seed.
    RANDOM_SEED = 2
29
30


31
32
# maybe make msgspec?
@dataclass
33
34
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
35
36
37
38
39
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
40
    # These are other options that can be set.
41
42
43
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
44
45
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
46

47
    _backend: str | None = field(default=None, init=False)
48
49
50
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
51
52
53

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
54
55
56
57
58
59
60
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
61
                self.structural_tag is not None,
62
63
            ]
        )
64
        if count > 1:
65
            raise ValueError(
66
                "You can only use one kind of structured outputs constraint "
67
68
                f"but multiple are specified: {self.__dict__}"
            )
69

70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

101

102
103
104
105
106
class RequestOutputKind(Enum):
    """What each RequestOutput returned for a request contains."""

    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2


111
class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    flat_logprobs: bool = False
    """Whether to return logprobs in flattened format (i.e. FlatLogprobs)
    for better performance.
    NOTE: The GC cost of FlatLogprobs is significantly smaller than that of
    list[dict[int, Logprob]]. When enabled, PromptLogprobs and
    SampleLogprobs are populated as FlatLogprobs."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
    """Selects what each RequestOutput contains; see `RequestOutputKind`."""
    skip_clone: bool = False
    """Internal flag indicating that this SamplingParams instance is safe to
    reuse without deep-copying. When True, clone() returns a shallow copy
    instead of performing a deep copy. This should only be set when the params
    object is guaranteed to be dedicated to a single request and won't be
    modified in ways that would affect other uses."""

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: list[list[int]] | None = None

    skip_reading_prefix_cache: bool | None = None
    """Whether to skip reading the prefix cache for this request. When left as
    `None`, `__post_init__` sets it to True if `prompt_logprobs` is requested
    and to False otherwise."""
248

249
250
    @staticmethod
    def from_optional(
        n: int | None = 1,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: list[LogitsProcessor] | None = None,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
        skip_clone: bool = False,
    ) -> "SamplingParams":
        """Build a `SamplingParams`, tolerating `None` for selected arguments.

        `n`, `presence_penalty`, `frequency_penalty`, `repetition_penalty`,
        `temperature` and `top_p` fall back to their field defaults when
        passed as `None`. `logit_bias` keys are converted to `int` and the
        bias values are clamped to [-100, 100]. All other arguments are
        forwarded unchanged.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        return SamplingParams(
            n=1 if n is None else n,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            skip_clone=skip_clone,
        )

323
324
    def __post_init__(self) -> None:
        """Normalize field values, validate them, and derive internal state.

        Raises whatever `_verify_args` / `_verify_greedy_sampling` raise for
        invalid parameter combinations.
        """
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = max(self.temperature, _MAX_TEMP)

        # A seed of -1 is treated as "no seed".
        if self.seed == -1:
            self.seed = None

        # Normalize `stop` to a list of strings.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # A literal `True` is normalized to 1 for both logprobs settings.
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled, the output of prompt logprobs may
            # be shorter than n_prompt_tokens, so we need to skip reading the
            # cache for this request.
            self.skip_reading_prefix_cache = self.prompt_logprobs is not None

377
    def _verify_args(self) -> None:
        """Check that the individual parameters are within their valid ranges.

        Expects `stop` and `stop_token_ids` to have already been normalized
        to lists (as done in `__post_init__`).

        Raises:
            ValueError: If any parameter is out of range or inconsistent.
            TypeError: If `top_k` is not an integer.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # Check the type before the value so a non-integer top_k produces
        # this explicit error rather than an opaque comparison failure.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative or -1, got {self.logprobs}."
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise ValueError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}."
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise ValueError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}"
            )
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
456
457

    def _verify_greedy_sampling(self) -> None:
        """Reject `n > 1` when greedy sampling is in effect.

        Raises:
            ValueError: If `n` is greater than 1.
        """
        if self.n > 1:
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")
460

461
    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        model_eos_token_id: int | None = None,
    ) -> None:
        """Update if there are non-default values from generation_config

        Args:
            generation_config: Mapping that may contain an `eos_token_id`
                entry (an int or an iterable of ints).
            model_eos_token_id: The model's primary EOS token id, if any.
        """

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    # Merge the extra EOS ids with the user-supplied stop ids.
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)
487

488
    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
        """Tokenize `bad_words` and store the result in `_bad_words_token_ids`.

        Does nothing when `bad_words` is empty.

        Raises:
            ValueError: If any resulting token id is outside
                [0, tokenizer.max_token_id].
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()
                prompt_token_ids = tokenizer.encode(
                    text=prompt, add_special_tokens=False
                )

                # If no space at the beginning
                # or if prefix space produces a new word token.
                # The last stored entry is the un-prefixed encoding of the
                # same word, appended on the previous iteration.
                if not add_prefix_space or (
                    prompt_token_ids[0] != self._bad_words_token_ids[-1][0]
                    and len(prompt_token_ids) == len(self._bad_words_token_ids[-1])
                ):
                    self._bad_words_token_ids.append(prompt_token_ids)

        max_token_id = tokenizer.max_token_id
        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > max_token_id
        ]
        if invalid_token_ids:
            raise ValueError(
                f"The model vocabulary size is {max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {max_token_id}."
            )
526

527
528
529
530
    @cached_property
    def sampling_type(self) -> SamplingType:
        """Sampling strategy implied by `temperature` and `seed`.

        Cached after the first access.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

535
    @property
    def all_stop_token_ids(self) -> set[int]:
        """The internal set of stop token ids (returned by reference).

        Populated from `stop_token_ids` in `__post_init__` and extended with
        EOS ids in `update_from_generation_config`.
        """
        return self._all_stop_token_ids

539
    @property
    def bad_words_token_ids(self) -> list[list[int]] | None:
        """Token-id sequences for `bad_words`, or None if not yet computed.

        Set by `update_from_tokenizer`.
        """
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

544
    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences.
        See https://github.com/vllm-project/vllm/issues/3087

        If skip_clone is True, uses shallow copy instead of deep copy.
        """

        if self.skip_clone:
            return copy.copy(self)

        # Pre-seed the deepcopy memo so that each logits processor is replaced
        # by its own `clone()` result when it provides one, and is otherwise
        # reused as-is instead of being deep-copied.
        logit_processor_refs = (
            None
            if self.logits_processors is None
            else {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
        )
        return copy.deepcopy(self, memo=logit_processor_refs)

568
    def __repr__(self) -> str:
        """Return a readable summary of the main sampling settings.

        Only the fields listed below are included; the remaining fields are
        omitted.
        """
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})"
        )
595
596
597


class BeamSearchParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""

    # Required fields (no defaults).
    beam_width: int
    max_tokens: int
    # Optional fields.
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False