params.py 16.8 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
4
from typing import TYPE_CHECKING, Any, Literal, TypeVar
5
6
7
8

from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
9
from vllm.multimodal.media.connector import merge_media_io_kwargs
10
11
12
13
14
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
15
16

    from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
17
18
19
else:
    torch = LazyLoader("torch", globals(), "torch")

20
21
    ChatTemplateContentFormatOption = object

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
logger = init_logger(__name__)


_S = TypeVar("_S", list[int], "torch.Tensor")


def merge_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
    /,
    *,
    unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
    if defaults is None:
        defaults = {}
    if overrides is None:
        overrides = {}

    return defaults | {k: v for k, v in overrides.items() if v not in unset_values}


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def recursively_merge_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
    /,
    *,
    unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
    if defaults is None:
        defaults = {}
    if overrides is None:
        overrides = {}

    merged = dict(defaults)

    for k, v in overrides.items():
        if v in unset_values:
            continue

        if k in merged and isinstance(merged[k], dict) and isinstance(v, dict):
            merged[k] = recursively_merge_kwargs(
                merged[k], v, unset_values=unset_values
            )
        else:
            merged[k] = v

    return merged


71
72
73
74
75
76
77
@dataclass(frozen=True)
class ChatParams:
    """Configuration to control how to parse chat messages."""

    chat_template: str | None = None
    """The chat template to apply."""

78
    chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
79
80
81
82
83
    """The format of the chat template."""

    chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
    """The kwargs to pass to the chat template."""

84
85
86
    media_io_kwargs: dict[str, dict[str, Any]] | None = None
    """Per-modality kwargs for media I/O (loading/decoding images, videos, etc.)."""

87
88
89
    mm_processor_kwargs: dict[str, Any] | None = None
    """The kwargs to pass to the multi-modal processor."""

90
91
92
93
    def with_defaults(
        self,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        default_media_io_kwargs: dict[str, dict[str, Any]] | None = None,
94
        default_mm_processor_kwargs: dict[str, Any] | None = None,
95
    ):
96
97
98
99
100
        if (
            not default_chat_template_kwargs
            and not default_media_io_kwargs
            and not default_mm_processor_kwargs
        ):
101
102
103
104
105
106
107
108
109
            return self

        return ChatParams(
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                default_chat_template_kwargs,
                self.chat_template_kwargs,
            ),
110
111
112
113
            media_io_kwargs=merge_media_io_kwargs(
                default_media_io_kwargs,
                self.media_io_kwargs,
            ),
114
115
116
117
            mm_processor_kwargs=recursively_merge_kwargs(
                default_mm_processor_kwargs,
                self.mm_processor_kwargs,
            ),
118
119
120
121
122
123
        )

    def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.apply_chat_template`."""
        return merge_kwargs(
            self.chat_template_kwargs,
124
            dict(chat_template=self.chat_template, return_dict=False),
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
        )


@dataclass(frozen=True)
class TokenizeParams:
    """
    Configuration to control how prompts are tokenized.

    Besides holding the settings, this class provides the validation hooks
    that run before tokenization (`apply_pre_tokenization`) and after it
    (`apply_post_tokenization`).
    """

    max_total_tokens: int | None
    """
    Maximum allowed number of input + output tokens.

    Usually, this refers to the model's context length.
    """

    max_output_tokens: int = 0
    """Maximum requested number of output tokens."""

    pad_prompt_tokens: int | None = None
    """
    Number of tokens to pad to:
    - `None` means no padding.
    - `-1` maps to `max_input_tokens`.
    """

    truncate_prompt_tokens: int | None = None
    """
    Number of tokens to keep:
    - `None` means no truncation.
    - `-1` maps to `max_input_tokens`.
    """

    truncation_side: Literal["left", "right"] | None = None
    """
    Which side to truncate from when ``truncate_prompt_tokens`` is active:
    - ``"right"`` keeps the first N tokens (truncate from the end).
    - ``"left"``  keeps the last  N tokens (truncate from the start).
    - ``None``    falls back to the tokenizer default.
    """

    do_lower_case: bool = False
    """Whether to normalize text to lower case before tokenization."""

    add_special_tokens: bool = True
    """Whether to add special tokens."""

    needs_detokenization: bool = False
    """
    Whether the tokenized prompt needs to contain the original text.

    Not to be confused with `SamplingParams.detokenize` which deals
    with the output generated by the model.
    """

    max_total_tokens_param: str = "max_total_tokens"
    """Override this to edit the message for validation errors."""

    max_output_tokens_param: str = "max_output_tokens"
    """Override this to edit the message for validation errors."""

    truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
    """Override this to edit the message for validation errors."""

    @property
    def max_input_tokens(self) -> int | None:
        """Maximum allowed number of input tokens (`None` if unbounded)."""
        if self.max_total_tokens is None:
            return None

        return self.max_total_tokens - self.max_output_tokens

    def __post_init__(self) -> None:
        """
        Validate that the configured limits are mutually consistent.

        Raises:
            VLLMValidationError: If `max_output_tokens` exceeds
                `max_total_tokens`, or if `truncate_prompt_tokens` exceeds
                `max_input_tokens`.
        """
        max_total_tokens = self.max_total_tokens
        max_output_tokens = self.max_output_tokens
        max_input_tokens = self.max_input_tokens
        truncate_prompt_tokens = self.truncate_prompt_tokens

        if (
            max_output_tokens is not None
            and max_total_tokens is not None
            and max_output_tokens > max_total_tokens
        ):
            raise VLLMValidationError(
                f"{self.max_output_tokens_param}={max_output_tokens} "
                f"cannot be greater than "
                f"{self.max_total_tokens_param}={max_total_tokens}. "
                f"Please request fewer output tokens.",
                parameter=self.max_output_tokens_param,
                value=max_output_tokens,
            )

        if (
            max_input_tokens is not None
            and truncate_prompt_tokens is not None
            and truncate_prompt_tokens > max_input_tokens
        ):
            raise VLLMValidationError(
                f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
                f"cannot be greater than {self.max_total_tokens_param} - "
                f"{self.max_output_tokens_param} = {max_input_tokens}. "
                f"Please request a smaller truncation size.",
                parameter=self.truncate_prompt_tokens_param,
                value=truncate_prompt_tokens,
            )

    def with_kwargs(self, **tokenization_kwargs: Any) -> "TokenizeParams":
        """
        Return a copy of these params updated from tokenizer-style kwargs.

        Recognized keys are `max_length`, `pad_prompt_tokens`,
        `truncate_prompt_tokens`, `do_lower_case`, `add_special_tokens`,
        `needs_detokenization`, and the HF-style `padding` / `truncation`
        options:

        - `padding="max_length"` pads to `max_length`;
          `padding=False` / `"do_not_pad"` disables padding.
        - `truncation=True` / `"longest_first"` truncates to `max_length`;
          `truncation=False` / `"do_not_truncate"` disables truncation.

        Any other key (including unsupported `padding` / `truncation`
        values) is reported in a warning and ignored.
        """
        max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
        pad_prompt_tokens = tokenization_kwargs.pop(
            "pad_prompt_tokens", self.pad_prompt_tokens
        )
        truncate_prompt_tokens = tokenization_kwargs.pop(
            "truncate_prompt_tokens", self.truncate_prompt_tokens
        )
        do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
        add_special_tokens = tokenization_kwargs.pop(
            "add_special_tokens", self.add_special_tokens
        )
        needs_detokenization = tokenization_kwargs.pop(
            "needs_detokenization", self.needs_detokenization
        )

        # https://huggingface.co/docs/transformers/en/pad_truncation
        # Test against `None` rather than truthiness so that an explicit
        # `padding=False` / `truncation=False` reaches its branch below.
        padding = tokenization_kwargs.pop("padding", None)
        if padding is not None:
            if padding == "max_length":
                pad_prompt_tokens = max_length
            elif padding in (False, "do_not_pad"):
                pad_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["padding"] = padding

        truncation = tokenization_kwargs.pop("truncation", None)
        if truncation is not None:
            if truncation in (True, "longest_first"):
                truncate_prompt_tokens = max_length
            elif truncation in (False, "do_not_truncate"):
                truncate_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["truncation"] = truncation

        if tokenization_kwargs:
            logger.warning(
                "The following tokenization arguments are not supported "
                "by vLLM Renderer and will be ignored: %s",
                tokenization_kwargs,
            )

        max_total_tokens = self.max_total_tokens

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            # Express `max_length` (an input budget) as the matching
            # output budget so that `max_input_tokens == max_length`.
            max_output_tokens=(
                0
                if max_total_tokens is None or max_length is None
                else max_total_tokens - max_length
            ),
            pad_prompt_tokens=pad_prompt_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            truncation_side=self.truncation_side,
            do_lower_case=do_lower_case,
            add_special_tokens=add_special_tokens,
            needs_detokenization=needs_detokenization,
            # Carry over the customized names so validation errors raised
            # by the derived params keep referring to the same parameters.
            max_total_tokens_param=self.max_total_tokens_param,
            max_output_tokens_param=self.max_output_tokens_param,
            truncate_prompt_tokens_param=self.truncate_prompt_tokens_param,
        )

    def get_encode_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.encode`."""
        # Left-side truncation requires the full token sequence so we can
        # slice from the end in _token_truncation.  Disable HF-level
        # truncation (which would incorrectly truncate from the right for
        # pooling models) and let _token_truncation handle it.
        if self.truncation_side == "left":
            return dict(
                truncation=False,
                add_special_tokens=self.add_special_tokens,
            )

        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            # Negative values map to the full input budget.
            max_length = self.max_input_tokens
        elif max_length is None and self.max_input_tokens is not None:
            # This prevents tokenization from taking up more resources than necessary
            # while still failing `self._token_len_check` as expected by users
            max_length = self.max_input_tokens + 1

        return dict(
            truncation=max_length is not None,
            max_length=max_length,
            add_special_tokens=self.add_special_tokens,
        )

    def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """
        Apply length checks to prompt text if necessary.

        The check only runs when there is an input budget, truncation is
        not requested, and a tokenizer is available.

        Raises:
            VLLMValidationError: If the text has more characters than
                `max_input_tokens * tokenizer.max_chars_per_token`.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return text

        if self.truncate_prompt_tokens is None and tokenizer is not None:
            max_input_chars = max_input_tokens * tokenizer.max_chars_per_token

            if len(text) > max_input_chars:
                # To save resources, fail the request outright without even
                # attempting tokenization
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{self.max_total_tokens} tokens. However, you requested "
                    f"{self.max_output_tokens} output tokens and your prompt "
                    f"contains {len(text)} characters (more than "
                    f"{max_input_chars} characters, which is the upper bound "
                    f"for {max_input_tokens} input tokens). "
                    f"Please reduce the length of the input prompt or the "
                    f"number of requested output tokens.",
                    parameter="input_text",
                    value=len(text),
                )

        return text

    def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply lowercase to prompt text if necessary."""
        return text.lower() if self.do_lower_case else text

    def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply all validators to prompt text."""
        # The length check runs first, on the text as it was received.
        for validator in (
            self._text_len_check,
            self._text_lowercase,
        ):
            text = validator(tokenizer, text)

        return text

    def apply_pre_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TextPrompt,
    ) -> TextPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run before tokenization occurs. The prompt is
        updated in place and also returned.
        """
        prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])

        return prompt

    def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply padding to prompt tokens if necessary.

        Raises:
            ValueError: If padding is needed but there is no tokenizer, or
                the input is not a list of token IDs.
        """
        pad_length = self.pad_prompt_tokens
        if pad_length is not None and pad_length < 0:
            pad_length = self.max_input_tokens

        if pad_length is None or pad_length <= len(tokens):
            return tokens

        if tokenizer is None:
            raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
        if not isinstance(tokens, list):
            raise ValueError("Cannot pad tokens for embedding inputs")

        return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))

    def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply truncation to prompt tokens if necessary."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens

        if max_length is None or max_length >= len(tokens):
            return tokens
        if max_length == 0:
            # `tokens[-0:]` would return everything, so handle zero
            # before the side-dependent slicing below.
            return tokens[:0]

        side = self.truncation_side or (
            tokenizer.truncation_side if tokenizer is not None else None
        )
        if side == "left":
            return tokens[-max_length:]

        return tokens[:max_length]

    def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply length checks to prompt tokens if necessary.

        Raises:
            VLLMValidationError: If the number of tokens exceeds
                `max_input_tokens`.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return tokens

        if len(tokens) > max_input_tokens:
            token_count = len(tokens)
            # The tokenizer may have truncated the prompt to
            # max_input_tokens + 1 (see get_encode_kwargs), so the
            # actual prompt length could be larger.
            qualifier = "at least " if token_count == max_input_tokens + 1 else ""
            total = token_count + self.max_output_tokens
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{self.max_total_tokens} tokens. However, you requested "
                f"{self.max_output_tokens} output tokens and your prompt "
                f"contains {qualifier}{token_count} input tokens, "
                f"for a total of {qualifier}{total} tokens. "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="input_tokens",
                value=token_count,
            )

        return tokens

    def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply all validators to a token sequence."""
        # Order matters: pad, then truncate, then check the final length.
        for validator in (
            self._token_padding,
            self._token_truncation,
            self._token_len_check,
        ):
            tokens = validator(tokenizer, tokens)

        return tokens

    def apply_post_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TokensPrompt | EmbedsPrompt,
    ) -> TokensPrompt | EmbedsPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run after tokenization occurs. The prompt is
        updated in place and also returned.
        """
        if "prompt_token_ids" in prompt:
            prompt["prompt_token_ids"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_token_ids"],  # type: ignore[typeddict-item]
            )
        if "prompt_embeds" in prompt:
            prompt["prompt_embeds"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_embeds"],  # type: ignore[typeddict-item]
            )

        return prompt