# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation."""

import copy
import warnings
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
from typing import Annotated, Any

import msgspec
from pydantic.dataclasses import dataclass

from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.serial_utils import PydanticMsgspecMixin

logger = init_logger(__name__)

# Temperatures below this threshold are treated as greedy sampling.
_SAMPLING_EPS = 1e-5
# Positive temperatures below this value are raised to it in
# SamplingParams.__post_init__ (a warning is logged when that happens).
_MAX_TEMP = 1e-2


class SamplingType(IntEnum):
    """Kind of sampling selected by a set of sampling parameters."""

    # Temperature is (effectively) zero.
    GREEDY = 0
    # Non-greedy sampling without an explicit seed.
    RANDOM = 1
    # Non-greedy sampling with an explicit seed.
    RANDOM_SEED = 2


# maybe make msgspec?
@dataclass
34
35
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
36
37
38
39
40
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
41
    # These are other options that can be set.
42
43
44
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
45
46
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
47

48
    _backend: str | None = field(default=None, init=False)
49
50
51
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
52
53
54

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
55
56
57
58
59
60
61
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
62
                self.structural_tag is not None,
63
64
            ]
        )
65
        if count > 1:
66
            raise ValueError(
67
                "You can only use one kind of structured outputs constraint "
68
69
                f"but multiple are specified: {self.__dict__}"
            )
70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )

102

103
104
105
106
107
108
109
110
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated alias of `StructuredOutputsParams`.

    Behaves like the parent class but emits a `DeprecationWarning` every
    time an instance is created.
    """

    def __post_init__(self):
        """Warn about the deprecation, then run the parent validation."""
        warnings.warn(
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__post_init__()


class RequestOutputKind(Enum):
    """How intermediate results are reported in each RequestOutput."""

    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2


class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    best_of: int | None = None
    """Number of output sequences that are generated from the prompt. From
    these `best_of` sequences, the top `n` sequences are returned. `best_of`
    must be greater than or equal to `n`. By default, `best_of` is set to `n`.
    Warning, this is only supported in V0."""
    _real_n: int | None = None
    """Original value of `n`; filled in by `__post_init__` when `best_of` is
    given (in which case `n` itself is overwritten with `best_of`)."""
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    flat_logprobs: bool = False
    """Whether to return logprobs in flattened format (i.e. FlatLogprob)
    for better performance.
    NOTE: GC costs of FlatLogprobs is significantly smaller than
    list[dict[int, Logprob]]. Once enabled, PromptLogprobs and
    SampleLogprobs will be populated as FlatLogprobs."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # `list[LogitsProcessor] | None` type. We use Any here because
    # `list[LogitsProcessor] | None` type is not supported by msgspec.
    logits_processors: Any | None = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    guided_decoding: GuidedDecodingParams | None = None
    """Deprecated alias for structured_outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: list[list[int]] | None = None

    # Annotated as optional because the default is None; `__post_init__`
    # replaces None with a concrete bool.
    skip_reading_prefix_cache: bool | None = None
    """Whether to skip reading the prefix cache for this request. When left as
    `None`, `__post_init__` sets it to `True` if `prompt_logprobs` is not
    `None` and to `False` otherwise."""

    @staticmethod
    def from_optional(
        n: int | None = 1,
        best_of: int | None = None,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: list[LogitsProcessor] | None = None,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        guided_decoding: GuidedDecodingParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
    ) -> "SamplingParams":
        """Build a `SamplingParams`, tolerating `None` for defaulted fields.

        `n`, `presence_penalty`, `frequency_penalty`, `repetition_penalty`,
        `temperature` and `top_p` fall back to their field defaults when
        passed as `None`. `logit_bias` keys are converted to `int` and its
        values clamped to [-100, 100]. A non-None `guided_decoding` triggers
        a `DeprecationWarning` and is used as `structured_outputs`.
        All remaining arguments are forwarded unchanged.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }
        if guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            # The deprecated argument takes over the new field.
            structured_outputs = guided_decoding
            guided_decoding = None

        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
        )

    def __post_init__(self) -> None:
        """Normalize field values and validate them after construction.

        Resolves `best_of`/`n`, raises too-small positive temperatures to
        `_MAX_TEMP`, maps `seed == -1` to `None`, turns `stop`,
        `stop_token_ids` and `bad_words` into lists, computes
        `output_text_buffer_length`, runs `_verify_args`, resets
        `top_p`/`top_k`/`min_p` for greedy sampling, seeds
        `_all_stop_token_ids`, migrates the deprecated `guided_decoding`
        field and fills in `skip_reading_prefix_cache`.

        Raises:
            ValueError: If `best_of < n`, or from the validation helpers.
        """
        # how we deal with `best_of`:
        # if `best_of` is not set, we default to `n`;
        # if `best_of` is set, we set `n` to `best_of`,
        # and set `_real_n` to the original `n`.
        # when we return the result, we will check
        # if we need to return `n` or `_real_n` results
        if self.best_of:
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}."
                )
            if not self._real_n:
                self._real_n = self.n
                self.n = self.best_of

        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = max(self.temperature, _MAX_TEMP)

        # A seed of -1 is treated as "no seed".
        if self.seed == -1:
            self.seed = None

        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # A literal `True` is accepted and means "one logprob".
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.structured_outputs = self.guided_decoding
            self.guided_decoding = None

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled,
            # the output of prompt logprobs may less than n_prompt_tokens,
            # we need to skip reading cache at this request.
            self.skip_reading_prefix_cache = self.prompt_logprobs is not None

    def _verify_args(self) -> None:
        """Check that the field values are within their allowed ranges.

        Expects `stop` and `stop_token_ids` to already be lists (as arranged
        by `__post_init__`).

        Raises:
            ValueError: If any value is out of range or inconsistent with
                another field.
            TypeError: If `top_k` is not an integer.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of is not None:
            if not isinstance(self.best_of, int):
                raise ValueError(
                    f"best_of must be an integer, got {type(self.best_of)}"
                )
            if self.best_of < 1:
                raise ValueError(f"best_of must be at least 1, got {self.best_of}")
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}."
                )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative or -1, got {self.logprobs}."
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise ValueError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}."
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise ValueError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}"
            )
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
        # `_real_n` holds the caller's original `n` once `best_of` has been
        # applied, so this compares `best_of` with the requested `n`.
        if self.best_of != self._real_n and self.output_kind == (
            RequestOutputKind.DELTA
        ):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_greedy_sampling(self) -> None:
        """Check constraints that apply only to greedy sampling.

        Raises:
            ValueError: If `n` is greater than 1.
        """
        if self.n > 1:
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")

    def update_from_generation_config(
532
533
        self,
        generation_config: dict[str, Any],
534
        model_eos_token_id: int | None = None,
535
    ) -> None:
536
        """Update if there are non-default values from generation_config"""
537
538
539
540

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
541
            self._all_stop_token_ids.add(model_eos_token_id)
542

543
        # Update eos_token_id for generation
544
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
545
            # it can be either int or list of int
546
547
548
549
550
551
552
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
553
                self._all_stop_token_ids.update(eos_ids)
554
555
556
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)
557

558
    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Tokenize `bad_words` into `_bad_words_token_ids`.

        Does nothing when `bad_words` is empty. Each word is encoded without
        a leading space and, when that yields a different first token with
        the same token count, also with a leading space.

        Words that encode to no tokens at all (for example an empty or
        whitespace-only string) are skipped; previously they caused an
        `IndexError` when the prefixed variant was compared against them.

        Raises:
            ValueError: If any resulting token id lies outside
                `[0, tokenizer.max_token_id]`.
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            word = bad_word.lstrip()
            plain_ids = tokenizer.encode(text=word, add_special_tokens=False)
            if not plain_ids:
                # Nothing to ban, and no first token to compare against.
                continue
            self._bad_words_token_ids.append(plain_ids)

            spaced_ids = tokenizer.encode(
                text=" " + word, add_special_tokens=False
            )
            # Keep the prefixed variant only if the prefix space produces a
            # new word token of the same length.
            if (
                spaced_ids
                and spaced_ids[0] != plain_ids[0]
                and len(spaced_ids) == len(plain_ids)
            ):
                self._bad_words_token_ids.append(spaced_ids)

        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}."
            )

    @cached_property
    def sampling_type(self) -> SamplingType:
        """Sampling type implied by `temperature` and `seed`.

        Greedy when the temperature is below `_SAMPLING_EPS`; otherwise
        seeded random when `seed` is set, plain random when it is not.
        The result is cached on first access.
        """
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def all_stop_token_ids(self) -> set[int]:
        """The internal `_all_stop_token_ids` set (returned by reference)."""
        return self._all_stop_token_ids

    @property
610
    def bad_words_token_ids(self) -> list[list[int]] | None:
611
612
613
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

614
    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087

        A processor that exposes a `clone` method is replaced by the result
        of calling it; any other processor is shared with the copy.
        """
        # Pre-populating deepcopy's memo with id(obj) -> replacement makes
        # deepcopy use the replacement instead of recursing into the object.
        logit_processor_refs = (
            None
            if self.logits_processors is None
            else {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
        )
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        """Return a string listing a subset of the sampling fields."""
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})"
        )


class BeamSearchParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""

    # Required fields (no defaults).
    beam_width: int
    max_tokens: int
    # Optional fields.
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False