sampling_params.py 26.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Sampling parameters for text generation."""
4

5
import copy
6
import warnings
7
from dataclasses import field
8
from enum import Enum, IntEnum
9
from functools import cached_property
10
from typing import Annotated, Any
11

12
import msgspec
13
from pydantic.dataclasses import dataclass
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
from vllm.logger import init_logger
16
from vllm.logits_process import LogitsProcessor
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
19
20

logger = init_logger(__name__)

# Temperatures strictly below this value are treated as zero, i.e. greedy
# sampling (see `SamplingParams.sampling_type`).
_SAMPLING_EPS: float = 0.00001
# Despite its name this is a *lower* bound: a positive temperature smaller
# than this is raised to it in `SamplingParams.__post_init__` to avoid
# NaN/inf values in the sampling tensors.
_MAX_TEMP: float = 0.01
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24

25
26
27
class SamplingType(IntEnum):
    """Sampling strategy selected by `SamplingParams.sampling_type`."""

    # Chosen when the temperature is below `_SAMPLING_EPS`.
    GREEDY = 0
    # Chosen for non-greedy sampling without a seed.
    RANDOM = 1
    # Chosen for non-greedy sampling when a seed is set.
    RANDOM_SEED = 2
29
30


31
32
# maybe make msgspec?
@dataclass
33
34
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
35
36
37
38
39
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
40
    # These are other options that can be set.
41
42
43
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
44
45
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
46

47
    _backend: str | None = field(default=None, init=False)
48
49
50
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
51
52
53

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
54
55
56
57
58
59
60
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
61
                self.structural_tag is not None,
62
63
            ]
        )
64
        if count > 1:
65
            raise ValueError(
66
                "You can only use one kind of structured outputs constraint "
67
68
                f"but multiple are specified: {self.__dict__}"
            )
69

70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

101

102
103
104
105
106
107
108
109
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated alias of `StructuredOutputsParams`.

    Has the same fields and validation as the parent class; constructing an
    instance additionally emits a `DeprecationWarning`.
    """

    def __post_init__(self):
        deprecation_message = (
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead."
        )
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        # Delegate to the parent's mutual-exclusivity validation.
        return super().__post_init__()


115
116
117
118
119
class RequestOutputKind(Enum):
    """Controls what each `RequestOutput` emitted for a request contains."""

    CUMULATIVE = 0  # every RequestOutput carries the entire output so far
    DELTA = 1  # every RequestOutput carries only the new part (the delta)
    FINAL_ONLY = 2  # no intermediate RequestOutput is returned


124
class SamplingParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    best_of: int | None = None
    """Number of output sequences that are generated from the prompt. From
    these `best_of` sequences, the top `n` sequences are returned. `best_of`
    must be greater than or equal to `n`. By default, `best_of` is set to `n`.
    Warning, this is only supported in V0."""
    # Internal: the user's original `n` when `best_of` is set (see
    # `__post_init__`, which then overwrites `n` with `best_of`).
    _real_n: int | None = None
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    guided_decoding: GuidedDecodingParams | None = None
    """Deprecated alias for structured_outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    # Internal: token-id sequences for `bad_words`, filled in by
    # `update_from_tokenizer`.
    _bad_words_token_ids: list[list[int]] | None = None

    @staticmethod
    def from_optional(
        n: int | None = 1,
        best_of: int | None = None,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int | None = 0,
        min_p: float | None = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: list[LogitsProcessor] | None = None,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        guided_decoding: GuidedDecodingParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
    ) -> "SamplingParams":
        """Build a `SamplingParams`, treating `None` as "use the default".

        The numeric sampling knobs `n`, `presence_penalty`,
        `frequency_penalty`, `repetition_penalty`, `temperature`, `top_p`,
        `top_k` and `min_p` accept `None`, which is replaced by the field's
        default value. In addition:

        - `logit_bias` keys are converted to `int` and the biases are clamped
          to [-100, 100].
        - `guided_decoding` is a deprecated alias: if given, it emits a
          `DeprecationWarning` and replaces `structured_outputs`.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        if guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            structured_outputs = guided_decoding
            guided_decoding = None

        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            # Treat None like the other knobs above instead of letting it
            # reach `_verify_args` as an invalid value.
            top_k=0 if top_k is None else top_k,
            min_p=0.0 if min_p is None else min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
        )

    def __post_init__(self) -> None:
        """Normalize the user-supplied fields, then validate them.

        Raises:
            ValueError: if any field is invalid (see `_verify_args` and
                `_verify_greedy_sampling`).
            TypeError: if `top_k` is not an integer.
        """
        # how we deal with `best_of`:
        # if `best_of` is not set, we default to `n`;
        # if `best_of` is set, we set `n` to `best_of`,
        # and set `_real_n` to the original `n`.
        # when we return the result, we will check
        # if we need to return `n` or `_real_n` results
        if self.best_of:
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}."
                )
            # Only rewrite `n` once: if `_real_n` is already set, `n` already
            # holds `best_of` and must not be copied into `_real_n` again.
            if not self._real_n:
                self._real_n = self.n
                self.n = self.best_of

        # Very small positive temperatures are raised to `_MAX_TEMP`.
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = max(self.temperature, _MAX_TEMP)

        # -1 is accepted as "no seed".
        if self.seed == -1:
            self.seed = None

        # Normalize `stop`, `stop_token_ids` and `bad_words` to lists.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # `True` is accepted as shorthand for 1.
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.structured_outputs = self.guided_decoding
            self.guided_decoding = None

    def _verify_args(self) -> None:
        """Validate field types and ranges.

        Expects `stop` and `stop_token_ids` to already be normalized to lists
        by `__post_init__`.

        Raises:
            ValueError: if a field is out of range or inconsistent.
            TypeError: if `top_k` is not an integer.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of is not None:
            if not isinstance(self.best_of, int):
                raise ValueError(
                    f"best_of must be an integer, got {type(self.best_of)}"
                )
            if self.best_of < 1:
                raise ValueError(f"best_of must be at least 1, got {self.best_of}")
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}."
                )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # Check the type before the range: comparing e.g. None with -1 would
        # otherwise raise an unhelpful TypeError from the `<` operator.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative or -1, got {self.logprobs}."
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise ValueError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}."
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise ValueError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}"
            )
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
        if self.best_of != self._real_n and self.output_kind == (
            RequestOutputKind.DELTA
        ):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_greedy_sampling(self) -> None:
        """Reject settings that are incompatible with greedy sampling."""
        if self.n > 1:
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")

    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        model_eos_token_id: int | None = None,
    ) -> None:
        """Update if there are non-default values from generation_config

        Args:
            generation_config: mapping whose optional `"eos_token_id"` entry
                (an int or a list of ints) supplies extra EOS token ids.
            model_eos_token_id: the model's primary EOS token id, if known.
                It is added to `all_stop_token_ids` but not to
                `stop_token_ids`.
        """
        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                # With ignore_eos the extra EOS ids must not stop generation,
                # so they are only recorded in `_all_stop_token_ids` above.
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Encode `bad_words` into `_bad_words_token_ids` with `tokenizer`.

        Each bad word (with leading whitespace stripped) is encoded once as-is
        and once with a single leading space, so it is banned both at the
        beginning and in the middle of the text. The space-prefixed encoding
        is only kept when its first token differs from the plain encoding's
        and it has the same number of tokens. Words that encode to no tokens
        are skipped: an empty sequence cannot be generated, so there is
        nothing to ban.

        Does nothing if `bad_words` is empty.

        Raises:
            ValueError: if any resulting token id is outside
                `[0, tokenizer.max_token_id]`.
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []

        for bad_word in self.bad_words:
            word = bad_word.lstrip()
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            plain_ids = tokenizer.encode(text=word, add_special_tokens=False)
            if not plain_ids:
                # e.g. "" or whitespace-only input; indexing [0] below would
                # otherwise raise IndexError.
                continue
            self._bad_words_token_ids.append(plain_ids)

            prefixed_ids = tokenizer.encode(text=" " + word, add_special_tokens=False)
            # Keep the prefixed variant only if the prefix space produces a
            # new word token without changing the sequence length.
            if (
                prefixed_ids
                and prefixed_ids[0] != plain_ids[0]
                and len(prefixed_ids) == len(plain_ids)
            ):
                self._bad_words_token_ids.append(prefixed_ids)

        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}."
            )

    @cached_property
    def sampling_type(self) -> SamplingType:
        """Sampling strategy implied by `temperature` and `seed`.

        Cached on first access, so later changes to `temperature` or `seed`
        are not reflected.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def all_stop_token_ids(self) -> set[int]:
        """`stop_token_ids` plus any EOS ids added by the engine."""
        return self._all_stop_token_ids

    @property
    def bad_words_token_ids(self) -> list[list[int]] | None:
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087
        """
        # Pre-seeding the deepcopy memo makes deepcopy reuse these objects:
        # processors with a `clone` method are cloned, others are shared.
        logit_processor_refs = (
            None
            if self.logits_processors is None
            else {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
        )
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})"
        )
644
645
646


class BeamSearchParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""

    # Required (no default). Presumably the number of beams kept at each
    # decoding step — confirm against the beam search implementation.
    beam_width: int
    # Required (no default). Maximum number of tokens to generate; presumably
    # per beam — confirm.
    max_tokens: int
    # Presumably mirrors `SamplingParams.ignore_eos` (keep generating after
    # EOS) — verify against the caller.
    ignore_eos: bool = False
    # Defaults to 0.0 here, unlike `SamplingParams.temperature` (1.0).
    # NOTE(review): confirm how the beam search path interprets 0.0.
    temperature: float = 0.0
    # NOTE(review): presumably a length-normalization exponent used when
    # ranking finished beams — confirm.
    length_penalty: float = 1.0
    # Presumably mirrors `SamplingParams.include_stop_str_in_output`.
    include_stop_str_in_output: bool = False