sampling_params.py 25.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
import warnings
7
from dataclasses import field
8
from enum import Enum, IntEnum
9
from functools import cached_property
10
from typing import Annotated, Any, Optional, Union
11

12
import msgspec
13
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
from vllm.logger import init_logger
16
from vllm.logits_process import LogitsProcessor
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
19
20

logger = init_logger(__name__)

# Temperatures strictly below this threshold are treated as zero, which
# selects greedy sampling.
_SAMPLING_EPS = 1e-5

# Despite its name, this acts as a *floor*: a positive temperature below it
# is raised to this value (with a warning) to avoid nan/inf in tensors.
_MAX_TEMP = 1e-2


class SamplingType(IntEnum):
    """Strategy used to pick the next token for a request.

    Members are ``IntEnum`` values, so they compare equal to plain ints.
    """

    # Chosen when the request's temperature is (near) zero.
    GREEDY = 0
    # Random sampling without a request-specific seed.
    RANDOM = 1
    # Random sampling driven by the request's ``seed``.
    RANDOM_SEED = 2


# maybe make msgspec?
@dataclass
class StructuredOutputsParams:
    """Parameters for a structured-output (constrained decoding) request.

    At most one of ``json``, ``regex``, ``choice``, ``grammar`` and
    ``json_object`` may be set (i.e. be non-``None``); the remaining fields
    are additional options.
    """

    # One of these fields will be used to build a logit processor.
    json: Optional[Union[str, dict]] = None
    regex: Optional[str] = None
    choice: Optional[list[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # These are other options that can be set.
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
    whitespace_pattern: Optional[str] = None
    structural_tag: Optional[str] = None

    _backend: Optional[str] = field(default=None, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""

    def __post_init__(self):
        """Validate that some fields are mutually exclusive.

        Raises:
            ValueError: if more than one constraint field is not ``None``.
                Note that ``json_object=False`` still counts as set.
        """
        constraint_values = (
            self.json,
            self.regex,
            self.choice,
            self.grammar,
            self.json_object,
        )
        num_set = len([value for value in constraint_values if value is not None])
        if num_set <= 1:
            return
        raise ValueError(
            "You can only use one kind of structured outputs constraint "
            f"but multiple are specified: {self.__dict__}"
        )


@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated alias of ``StructuredOutputsParams``.

    Construction emits a ``DeprecationWarning`` and then runs the parent's
    mutual-exclusion validation.
    """

    def __post_init__(self):
        deprecation_message = (
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead."
        )
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        return super().__post_init__()


class RequestOutputKind(Enum):
    """How much generated output each ``RequestOutput`` carries."""

    CUMULATIVE = 0  # entire output so far, in every RequestOutput
    DELTA = 1  # only the newly generated part, in each RequestOutput
    FINAL_ONLY = 2  # no intermediate RequestOutput is returned


class SamplingParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    best_of: Optional[int] = None
    """Number of output sequences that are generated from the prompt. From
    these `best_of` sequences, the top `n` sequences are returned. `best_of`
    must be greater than or equal to `n`. By default, `best_of` is set to `n`.
    Warning, this is only supported in V0."""
    _real_n: Optional[int] = None
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: Optional[int] = None
    """Random seed to use for the generation."""
    stop: Optional[Union[str, list[str]]] = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: Optional[list[int]] = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: Optional[int] = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated."""
    logprobs: Optional[int] = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: Optional[int] = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # Optional[list[LogitsProcessor]] type. We use Any here because
    # Optional[list[LogitsProcessor]] type is not supported by msgspec.
    logits_processors: Optional[Any] = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: Optional[StructuredOutputsParams] = None
    """Parameters for configuring structured outputs."""
    guided_decoding: Optional[GuidedDecodingParams] = None
    """Deprecated alias for structured_outputs."""
    logit_bias: Optional[dict[int, float]] = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: Optional[list[int]] = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: Optional[dict[str, Any]] = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: Optional[list[str]] = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: Optional[list[list[int]]] = None

    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
        best_of: Optional[int] = None,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        repetition_penalty: Optional[float] = 1.0,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        stop: Optional[Union[str, list[str]]] = None,
        stop_token_ids: Optional[list[int]] = None,
        bad_words: Optional[list[str]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[list[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: Optional[StructuredOutputsParams] = None,
        guided_decoding: Optional[GuidedDecodingParams] = None,
        logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
        allowed_token_ids: Optional[list[int]] = None,
        extra_args: Optional[dict[str, Any]] = None,
    ) -> "SamplingParams":
        """Build a ``SamplingParams``, replacing ``None`` with defaults.

        ``None`` for ``n``, the three penalties, ``temperature`` or ``top_p``
        falls back to that field's default. ``logit_bias`` keys are converted
        to ``int`` and values clamped to [-100, 100]. A ``guided_decoding``
        argument is accepted for backward compatibility, triggers a
        ``DeprecationWarning``, and is forwarded as ``structured_outputs``.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        if guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            structured_outputs = guided_decoding
            guided_decoding = None

        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
        )

    def __post_init__(self) -> None:
        """Normalize fields, then validate them.

        Raises:
            ValueError: if any field fails validation (see ``_verify_args``
                and ``_verify_greedy_sampling``).
        """
        # how we deal with `best_of`:
        # if `best_of` is not set, we default to `n`;
        # if `best_of` is set, we set `n` to `best_of`,
        # and set `_real_n` to the original `n`.
        # when we return the result, we will check
        # if we need to return `n` or `_real_n` results
        if self.best_of:
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}."
                )
            if not self._real_n:
                self._real_n = self.n
                self.n = self.best_of

        # Very small positive temperatures are raised to _MAX_TEMP (a floor).
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = max(self.temperature, _MAX_TEMP)

        # -1 is accepted as "no seed".
        if self.seed == -1:
            self.seed = None

        # Normalize list-like fields so later code can assume lists.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # Boolean `True` for logprobs means "just the top one".
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.structured_outputs = self.guided_decoding
            self.guided_decoding = None

    def _verify_args(self) -> None:
        """Range- and type-check the user-facing fields.

        Raises:
            ValueError: for out-of-range or mistyped values.
            TypeError: if ``top_k`` is not an ``int``.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of is not None:
            if not isinstance(self.best_of, int):
                raise ValueError(
                    f"best_of must be an integer, got {type(self.best_of)}"
                )
            if self.best_of < 1:
                raise ValueError(f"best_of must be at least 1, got {self.best_of}")
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}."
                )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative or -1, got {self.logprobs}."
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise ValueError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}."
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise ValueError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}"
            )
        # __post_init__ has already normalized these to lists.
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
        if self.best_of != self._real_n and self.output_kind == (
            RequestOutputKind.DELTA
        ):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_greedy_sampling(self) -> None:
        """Reject ``n > 1`` when sampling greedily."""
        if self.n > 1:
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")

    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        model_eos_token_id: Optional[int] = None,
    ) -> None:
        """Update if there are non-default values from generation_config"""

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Tokenize ``bad_words`` into ``_bad_words_token_ids``.

        Each bad word (leading whitespace stripped) is encoded as-is, and
        again with a single leading space, to prohibit it both at the
        beginning and in the middle of text (related to the tokenizer's
        add_prefix_space behaviour). The space-prefixed encoding is kept only
        when it has the same length as the plain one but starts with a
        different token.

        Words whose stripped form encodes to no tokens (e.g. empty or
        whitespace-only strings) are skipped. Previously such a word produced
        an empty entry and then an ``IndexError`` on the prefixed comparison.

        Raises:
            ValueError: if any resulting token id is outside
                ``[0, tokenizer.max_token_id]``.
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []

        for bad_word in self.bad_words:
            word = bad_word.lstrip()
            plain_ids = tokenizer.encode(text=word, add_special_tokens=False)
            if not plain_ids:
                # Nothing to prohibit, and plain_ids[0] below would fail.
                continue
            self._bad_words_token_ids.append(plain_ids)

            prefixed_ids = tokenizer.encode(text=" " + word, add_special_tokens=False)
            # Keep the prefixed variant only if the prefix space produces a
            # new word token. Checking the length first guarantees that
            # prefixed_ids is non-empty before it is indexed.
            if (
                len(prefixed_ids) == len(plain_ids)
                and prefixed_ids[0] != plain_ids[0]
            ):
                self._bad_words_token_ids.append(prefixed_ids)

        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}."
            )

    @cached_property
    def sampling_type(self) -> SamplingType:
        """GREEDY for (near) zero temperature, RANDOM_SEED if seeded, else RANDOM."""
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def all_stop_token_ids(self) -> set[int]:
        return self._all_stop_token_ids

    @property
    def bad_words_token_ids(self) -> Optional[list[list[int]]]:
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087
        """

        # Pre-seeding the deepcopy memo makes each processor map to its own
        # clone (if it defines one) or to itself, instead of being deep-copied.
        logit_processor_refs = (
            None
            if self.logits_processors is None
            else {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
        )
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})"
        )


class BeamSearchParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Beam search parameters for text generation.

    ``beam_width`` and ``max_tokens`` have no defaults and must be provided;
    every other field is optional.
    """

    # Required fields (no defaults).
    beam_width: int
    max_tokens: int
    # Optional fields.
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False