"tests/kernels/test_attention.py" did not exist on "a1b3de86cd6f27aeb299d45296a7409b8d2b7c0c"
sampling_params.py 20.3 KB
Newer Older
1
"""Sampling parameters for text generation."""
2
import copy
3
from enum import Enum, IntEnum
4
from functools import cached_property
5
from typing import Any, Callable, Dict, List, Optional, Set, Union
6

7
import msgspec
8
import torch
9
from typing_extensions import Annotated
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
from vllm.logger import init_logger

# Module-level logger; SamplingParams.__post_init__ uses it to warn when a
# very small positive temperature is raised to _MAX_TEMP.
logger = init_logger(__name__)

15
# Tolerance used when comparing sampling floats. In particular, a temperature
# below this value is treated as zero, i.e. greedy sampling.
_SAMPLING_EPS = 1e-5
16
# Floor applied to strictly positive temperatures in
# SamplingParams.__post_init__; smaller values may cause nan/inf in tensors.
_MAX_TEMP = 1e-2
Woosuk Kwon's avatar
Woosuk Kwon committed
17

18

19
20
21
class SamplingType(IntEnum):
    """Sampling strategy derived from a request's sampling parameters.

    The member is chosen by `SamplingParams.sampling_type`.
    """

    # Temperature is (effectively) zero.
    GREEDY = 0
    # Stochastic sampling without a per-request seed.
    RANDOM = 1
    # Stochastic sampling with a per-request seed.
    RANDOM_SEED = 2
    # Beam search is enabled.
    BEAM = 3
24
25


26
27
28
29
30
31
32
33
# Two callable shapes are accepted:
#   (generated_token_ids, logits) -> logits
#   (prompt_token_ids, generated_token_ids, logits) -> logits
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                        Callable[[List[int], List[int], torch.Tensor],
                                 torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""
34
35


36
37
38
39
40
41
42
43
44
class RequestOutputKind(Enum):
    """How generated output is delivered across successive RequestOutputs."""

    CUMULATIVE = 0  # every RequestOutput carries the entire output so far
    DELTA = 1  # each RequestOutput carries only the new delta
    FINAL_ONLY = 2  # no intermediate RequestOutputs are returned


45
46
47
48
49
class SamplingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be greater than or equal to `n`. This is treated as
            the beam width when `use_beam_search` is True. By default, `best_of`
            is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens.
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens.
        repetition_penalty: Float that penalizes new tokens based on whether
            they appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1 encourage
            the model to repeat tokens.
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling. Strictly
            positive values below 1e-2 are raised to 1e-2 (with a warning).
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        seed: Random seed to use for the generation. A value of -1 is treated
            the same as None (no seed).
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation stops as
            soon as there are `best_of` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very
            unlikely to find better candidates; `"never"`, where the beam search
            procedure only stops when there cannot be better candidates
            (canonical beam search algorithm).
        stop: List of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings.
        stop_token_ids: List of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in
            output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated
        logprobs: Number of log probabilities to return per output token.
            When set to None, no probability is returned. If set to a non-None
            value, the result includes the log probabilities of the specified
            number of most likely tokens, as well as the chosen tokens.
            Note that the implementation follows the OpenAI API: The API will
            always return the log probability of the sampled token, so there
            may be up to `logprobs+1` elements in the response.
        prompt_logprobs: Number of log probabilities to return per prompt token.
        detokenize: Whether to detokenize the output. Defaults to True.
        skip_special_tokens: Whether to skip special tokens in the output.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output.  Defaults to True.
        logits_processors: List of functions that modify logits based on
            previously generated tokens, and optionally prompt tokens as
            a first argument.
        truncate_prompt_tokens: If set to an integer k, will use only the last k
            tokens from the prompt (i.e., left truncation). Defaults to None
            (i.e., no truncation).
        output_kind: A `RequestOutputKind` selecting whether each
            RequestOutput carries the entire output so far (`CUMULATIVE`, the
            default), only deltas (`DELTA`), or whether intermediate outputs
            are not returned at all (`FINAL_ONLY`). `DELTA` requires
            `best_of` to equal `n`.
    """

    n: int = 1
    best_of: Optional[int] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    min_p: float = 0.0
    seed: Optional[int] = None
    use_beam_search: bool = False
    length_penalty: float = 1.0
    early_stopping: Union[bool, str] = False
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[List[int]] = None
    ignore_eos: bool = False
    max_tokens: Optional[int] = 16
    min_tokens: int = 0
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    # Optional[List[LogitsProcessor]] type. We use Any here because
    # Optional[List[LogitsProcessor]] type is not supported by msgspec.
    logits_processors: Optional[Any] = None
    include_stop_str_in_output: bool = False
    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)

    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
        best_of: Optional[int] = None,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        repetition_penalty: Optional[float] = 1.0,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int,
                                                   msgspec.Meta(ge=1)]] = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
    ) -> "SamplingParams":
        """Build a SamplingParams, tolerating None for some numeric fields.

        `n`, `presence_penalty`, `frequency_penalty`, `repetition_penalty`,
        `temperature` and `top_p` may be passed as None, in which case the
        field's default value is used. All other arguments are forwarded
        unchanged.
        """
        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0
            if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0
            if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            use_beam_search=use_beam_search,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
        )

    def __post_init__(self) -> None:
        """Normalize field values, then validate them.

        Raises:
            ValueError: If any parameter is out of range or inconsistent with
                the selected decoding mode.
            TypeError: If `top_k` is not an integer.
        """
        # A falsy best_of (None or 0) falls back to n.
        self.best_of = self.best_of or self.n
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature, _MAX_TEMP, _MAX_TEMP)
            # The guard above already established temperature < _MAX_TEMP.
            self.temperature = _MAX_TEMP
        if self.seed == -1:
            # -1 is accepted as an alias for "no seed".
            self.seed = None
        # Normalize `stop` to a (possibly empty) list of strings.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]
        else:
            self.stop = list(self.stop)
        # Normalize `stop_token_ids` to a (possibly empty) list.
        if self.stop_token_ids is None:
            self.stop_token_ids = []
        else:
            self.stop_token_ids = list(self.stop_token_ids)
        # A literal True is accepted and mapped to 1.
        self.logprobs = 1 if self.logprobs is True else self.logprobs
        self.prompt_logprobs = (1 if self.prompt_logprobs is True else
                                self.prompt_logprobs)

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self._verify_greedy_sampling()
        # eos_token_id is added to this by the engine
        self._all_stop_token_ids = set(self.stop_token_ids)

    def _verify_args(self) -> None:
        """Validate parameters that apply to every decoding mode.

        Expects `__post_init__` to have already normalized `best_of` to an
        int and `stop` to a list.

        Raises:
            ValueError: If a parameter is out of its allowed range.
            TypeError: If `top_k` is not an integer.
        """
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        assert isinstance(self.best_of, int)
        if self.best_of < self.n:
            raise ValueError(f"best_of must be greater than or equal to n, "
                             f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if not 0.0 < self.repetition_penalty <= 2.0:
            raise ValueError("repetition_penalty must be in (0, 2], got "
                             f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in [0, 1], got "
                             f"{self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError(f"prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")
        if (self.truncate_prompt_tokens is not None
                and self.truncate_prompt_tokens < 1):
            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                             f"got {self.truncate_prompt_tokens}")
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")
        if self.best_of != self.n and self.output_kind == (
                RequestOutputKind.DELTA):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_beam_search(self) -> None:
        """Validate the constraints that apply when beam search is enabled.

        Raises:
            ValueError: If `best_of` is 1, sampling knobs (`temperature`,
                `top_p`, `top_k`) are not at their disabled values, or
                `early_stopping` is not True, False or "never".
        """
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using beam "
                             f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
        if self.early_stopping not in [True, False, "never"]:
            raise ValueError(
                f"early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        """Reject beam-search-only options when beam search is disabled.

        Raises:
            ValueError: If `early_stopping` is not False or `length_penalty`
                differs from 1.0 by more than the sampling tolerance.
        """
        if self.early_stopping is not False:
            raise ValueError("early_stopping is not effective and must be "
                             "False when not using beam search.")
        if (self.length_penalty < 1.0 - _SAMPLING_EPS
                or self.length_penalty > 1.0 + _SAMPLING_EPS):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")

    def _verify_greedy_sampling(self) -> None:
        """Validate the constraints that apply to greedy sampling.

        Raises:
            ValueError: If `best_of` is greater than 1.
        """
        assert isinstance(self.best_of, int)
        if self.best_of > 1:
            # Note the space separating the two sentences of the message.
            raise ValueError("best_of must be 1 when using greedy sampling. "
                             f"Got {self.best_of}.")

    def update_from_generation_config(
            self,
            generation_config: Dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        """Update if there are non-default values from generation_config

        Args:
            generation_config: Mapping that may contain an "eos_token_id"
                entry holding either a single int or an iterable of ints.
            model_eos_token_id: The model's primary EOS token id, if known.
        """

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    # Merge the extra EOS ids into the user-facing stop list.
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    @cached_property
    def sampling_type(self) -> SamplingType:
        """Return the SamplingType implied by these parameters.

        The value is computed on first access and cached on the instance.
        """
        if self.use_beam_search:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def all_stop_token_ids(self) -> Set[int]:
        """Return the internal set of all stop token ids."""
        return self._all_stop_token_ids

    def clone(self) -> "SamplingParams":
        """Deep copy excluding LogitsProcessor objects.

        LogitsProcessor objects are excluded because they may contain an
        arbitrary, nontrivial amount of data.
        See https://github.com/vllm-project/vllm/issues/3087
        """

        # Pre-seeding the deepcopy memo with each processor makes deepcopy
        # reuse the existing object instead of copying it.
        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        """Return a readable summary of the user-facing sampling fields."""
        return (
            f"SamplingParams(n={self.n}, "
            f"best_of={self.best_of}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"use_beam_search={self.use_beam_search}, "
            f"length_penalty={self.length_penalty}, "
            f"early_stopping={self.early_stopping}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens})")