params.py 15.8 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Any, TypeVar

from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal.media.connector import merge_media_io_kwargs
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch

    from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
else:
    torch = LazyLoader("torch", globals(), "torch")

    ChatTemplateContentFormatOption = object

# Module logger; used e.g. to warn about unsupported tokenization kwargs.
logger = init_logger(__name__)


# Constrained type variable for a prompt's token sequence: either a list of
# token IDs (`prompt_token_ids`) or an embeddings tensor (`prompt_embeds`).
# The token validators in `TokenizeParams` return the same kind they receive.
_S = TypeVar("_S", list[int], "torch.Tensor")


def merge_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
    /,
    *,
    unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
    if defaults is None:
        defaults = {}
    if overrides is None:
        overrides = {}

    return defaults | {k: v for k, v in overrides.items() if v not in unset_values}


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def recursively_merge_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
    /,
    *,
    unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
    if defaults is None:
        defaults = {}
    if overrides is None:
        overrides = {}

    merged = dict(defaults)

    for k, v in overrides.items():
        if v in unset_values:
            continue

        if k in merged and isinstance(merged[k], dict) and isinstance(v, dict):
            merged[k] = recursively_merge_kwargs(
                merged[k], v, unset_values=unset_values
            )
        else:
            merged[k] = v

    return merged


71
72
73
74
75
76
77
@dataclass(frozen=True)
class ChatParams:
    """Configuration to control how to parse chat messages."""

    chat_template: str | None = None
    """The chat template to apply."""

78
    chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
79
80
81
82
83
    """The format of the chat template."""

    chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
    """The kwargs to pass to the chat template."""

84
85
86
    media_io_kwargs: dict[str, dict[str, Any]] | None = None
    """Per-modality kwargs for media I/O (loading/decoding images, videos, etc.)."""

87
88
89
    mm_processor_kwargs: dict[str, Any] | None = None
    """The kwargs to pass to the multi-modal processor."""

90
91
92
93
    def with_defaults(
        self,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        default_media_io_kwargs: dict[str, dict[str, Any]] | None = None,
94
        default_mm_processor_kwargs: dict[str, Any] | None = None,
95
    ):
96
97
98
99
100
        if (
            not default_chat_template_kwargs
            and not default_media_io_kwargs
            and not default_mm_processor_kwargs
        ):
101
102
103
104
105
106
107
108
109
            return self

        return ChatParams(
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                default_chat_template_kwargs,
                self.chat_template_kwargs,
            ),
110
111
112
113
            media_io_kwargs=merge_media_io_kwargs(
                default_media_io_kwargs,
                self.media_io_kwargs,
            ),
114
115
116
117
            mm_processor_kwargs=recursively_merge_kwargs(
                default_mm_processor_kwargs,
                self.mm_processor_kwargs,
            ),
118
119
120
121
122
123
        )

    def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.apply_chat_template`."""
        return merge_kwargs(
            self.chat_template_kwargs,
124
            dict(chat_template=self.chat_template, return_dict=False),
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
        )


@dataclass(frozen=True)
class TokenizeParams:
    """Configuration to control how prompts are tokenized."""

    max_total_tokens: int | None
    """
    Maximum allowed number of input + output tokens.

    Usually, this refers to the model's context length.
    """

    max_output_tokens: int = 0
    """Maximum requested number of output tokens."""

    pad_prompt_tokens: int | None = None
    """
    Number of tokens to pad to:
    - `None` means no padding.
    - `-1` maps to `max_input_tokens`.
    """

    truncate_prompt_tokens: int | None = None
    """
    Number of tokens to keep:
    - `None` means no truncation.
    - `-1` maps to `max_input_tokens`.
    """

    do_lower_case: bool = False
    """Whether to normalize text to lower case before tokenization."""

    add_special_tokens: bool = True
    """Whether to add special tokens."""

    needs_detokenization: bool = False
    """
    Whether the tokenized prompt needs to contain the original text.

    Not to be confused with `SamplingParams.detokenize` which deals
    with the output generated by the model.
    """

    max_total_tokens_param: str = "max_total_tokens"
    """Override this to edit the message for validation errors."""

    max_output_tokens_param: str = "max_output_tokens"
    """Override this to edit the message for validation errors."""

    truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
    """Override this to edit the message for validation errors."""

    @property
    def max_input_tokens(self) -> int | None:
        """
        Maximum allowed number of input tokens.

        This is `max_total_tokens - max_output_tokens`, or `None` when
        `max_total_tokens` is `None`.
        """
        if self.max_total_tokens is None:
            return None

        return self.max_total_tokens - self.max_output_tokens

    def __post_init__(self) -> None:
        """
        Validate the token budget.

        Raises:
            VLLMValidationError: If `max_output_tokens` exceeds
                `max_total_tokens`, or if a non-negative
                `truncate_prompt_tokens` exceeds `max_input_tokens`.
        """
        max_total_tokens = self.max_total_tokens
        max_output_tokens = self.max_output_tokens
        max_input_tokens = self.max_input_tokens
        truncate_prompt_tokens = self.truncate_prompt_tokens

        if (
            max_output_tokens is not None
            and max_total_tokens is not None
            and max_output_tokens > max_total_tokens
        ):
            raise VLLMValidationError(
                f"{self.max_output_tokens_param}={max_output_tokens} "
                f"cannot be greater than "
                f"{self.max_total_tokens_param}={max_total_tokens}. "
                f"Please request fewer output tokens.",
                parameter=self.max_output_tokens_param,
                value=max_output_tokens,
            )

        if (
            max_input_tokens is not None
            and truncate_prompt_tokens is not None
            and truncate_prompt_tokens > max_input_tokens
        ):
            raise VLLMValidationError(
                f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
                f"cannot be greater than {self.max_total_tokens_param} - "
                f"{self.max_output_tokens_param} = {max_input_tokens}. "
                f"Please request a smaller truncation size.",
                parameter=self.truncate_prompt_tokens_param,
                value=truncate_prompt_tokens,
            )

    def with_kwargs(self, **tokenization_kwargs: Any) -> "TokenizeParams":
        """
        Return a copy of these params updated from HF-style tokenization kwargs.

        Recognized keys:
        - `max_length`, `pad_prompt_tokens`, `truncate_prompt_tokens`,
          `do_lower_case`, `add_special_tokens`, `needs_detokenization`.
        - `padding`: `"max_length"`, or `False`/`"do_not_pad"`.
        - `truncation`: `True`/`"longest_first"`, or
          `False`/`"do_not_truncate"`.

        Any other key, or an unrecognized `padding`/`truncation` value, is
        logged as a warning and ignored.

        `max_length` (default: the current `max_input_tokens`) becomes the new
        input budget. `max_total_tokens` is kept and `max_output_tokens` is set
        to `max_total_tokens - max_length`. All other fields, including the
        `*_param` names used in error messages, are carried over.
        """
        max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
        pad_prompt_tokens = tokenization_kwargs.pop(
            "pad_prompt_tokens", self.pad_prompt_tokens
        )
        truncate_prompt_tokens = tokenization_kwargs.pop(
            "truncate_prompt_tokens", self.truncate_prompt_tokens
        )
        do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
        add_special_tokens = tokenization_kwargs.pop(
            "add_special_tokens", self.add_special_tokens
        )
        needs_detokenization = tokenization_kwargs.pop(
            "needs_detokenization", self.needs_detokenization
        )

        # https://huggingface.co/docs/transformers/en/pad_truncation
        # Compare against `None` rather than testing truthiness, so that an
        # explicit `padding=False` / `truncation=False` actually disables the
        # feature instead of being silently skipped.
        padding = tokenization_kwargs.pop("padding", None)
        if padding is not None:
            if padding == "max_length":
                pad_prompt_tokens = max_length
            elif padding in (False, "do_not_pad"):
                pad_prompt_tokens = None
            else:
                # Re-insert so that the warning below reports it
                tokenization_kwargs["padding"] = padding

        truncation = tokenization_kwargs.pop("truncation", None)
        if truncation is not None:
            if truncation in (True, "longest_first"):
                truncate_prompt_tokens = max_length
            elif truncation in (False, "do_not_truncate"):
                truncate_prompt_tokens = None
            else:
                # Re-insert so that the warning below reports it
                tokenization_kwargs["truncation"] = truncation

        if tokenization_kwargs:
            logger.warning(
                "The following tokenization arguments are not supported "
                "by vLLM Renderer and will be ignored: %s",
                tokenization_kwargs,
            )

        max_total_tokens = self.max_total_tokens

        # NOTE(review): a `max_length` larger than `max_total_tokens` yields a
        # negative `max_output_tokens`, which `__post_init__` does not reject;
        # the resulting `max_input_tokens` then equals `max_length`.
        return replace(
            self,
            max_output_tokens=(
                0
                if max_total_tokens is None or max_length is None
                else max_total_tokens - max_length
            ),
            pad_prompt_tokens=pad_prompt_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            do_lower_case=do_lower_case,
            add_special_tokens=add_special_tokens,
            needs_detokenization=needs_detokenization,
        )

    def get_encode_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.encode`."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens
        elif max_length is None and self.max_input_tokens is not None:
            # This prevents tokenization from taking up more resources than necessary
            # while still failing `self._token_len_check` as expected by users
            max_length = self.max_input_tokens + 1

        return dict(
            truncation=max_length is not None,
            max_length=max_length,
            add_special_tokens=self.add_special_tokens,
        )

    def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """
        Apply length checks to prompt text if necessary.

        When no truncation is configured and a tokenizer is available, a text
        longer than `max_input_tokens * tokenizer.max_chars_per_token`
        characters is rejected before tokenization.

        Raises:
            VLLMValidationError: If the text exceeds that character bound.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return text

        if self.truncate_prompt_tokens is None and tokenizer is not None:
            max_input_chars = max_input_tokens * tokenizer.max_chars_per_token

            if len(text) > max_input_chars:
                # To save resources, fail the request outright without even
                # attempting tokenization
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{self.max_total_tokens} tokens. However, you requested "
                    f"{self.max_output_tokens} output tokens and your prompt "
                    f"contains {len(text)} characters (more than "
                    f"{max_input_chars} characters, which is the upper bound "
                    f"for {max_input_tokens} input tokens). "
                    f"Please reduce the length of the input prompt or the "
                    f"number of requested output tokens.",
                    parameter="input_text",
                    value=len(text),
                )

        return text

    def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply lowercase to prompt text if necessary."""
        return text.lower() if self.do_lower_case else text

    def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply all validators to prompt text, in order."""
        for validator in (
            self._text_len_check,
            self._text_lowercase,
        ):
            text = validator(tokenizer, text)

        return text

    def apply_pre_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TextPrompt,
    ) -> TextPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run before tokenization occurs. The prompt is updated
        in place and also returned.
        """
        prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])

        return prompt

    def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply padding to prompt tokens if necessary.

        Token IDs are right-padded with `tokenizer.pad_token_id` up to
        `pad_prompt_tokens` (`-1` meaning `max_input_tokens`).

        Raises:
            ValueError: If padding is needed but there is no tokenizer, the
                input is an embeddings tensor, or the tokenizer has no pad
                token.
        """
        pad_length = self.pad_prompt_tokens
        if pad_length is not None and pad_length < 0:
            pad_length = self.max_input_tokens

        if pad_length is None or pad_length <= len(tokens):
            return tokens

        if tokenizer is None:
            raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
        if not isinstance(tokens, list):
            raise ValueError("Cannot pad tokens for embedding inputs")

        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            # Padding with `None` would produce an invalid token sequence.
            raise ValueError(
                "Cannot pad tokens because the tokenizer has no pad token"
            )

        return tokens + [pad_token_id] * (pad_length - len(tokens))

    def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply truncation to prompt tokens if necessary.

        Keeps the last `truncate_prompt_tokens` tokens when the tokenizer's
        `truncation_side` is `"left"` (also assumed when the tokenizer has no
        such attribute), otherwise the first ones.
        """
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens

        if max_length is None or max_length >= len(tokens):
            return tokens
        if max_length == 0:
            # `tokens[-0:]` would return everything, so special-case it.
            return tokens[:0]

        if getattr(tokenizer, "truncation_side", "left") == "left":
            return tokens[-max_length:]

        return tokens[:max_length]

    def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply length checks to prompt tokens if necessary.

        Raises:
            VLLMValidationError: If the sequence is longer than
                `max_input_tokens`.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return tokens

        if len(tokens) > max_input_tokens:
            token_count = len(tokens)
            # The tokenizer may have truncated the prompt to
            # max_input_tokens + 1 (see get_encode_kwargs), so the
            # actual prompt length could be larger.
            qualifier = "at least " if token_count == max_input_tokens + 1 else ""
            total = token_count + self.max_output_tokens
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{self.max_total_tokens} tokens. However, you requested "
                f"{self.max_output_tokens} output tokens and your prompt "
                f"contains {qualifier}{token_count} input tokens, "
                f"for a total of {qualifier}{total} tokens. "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="input_tokens",
                value=token_count,
            )

        return tokens

    def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply all validators to a token sequence: pad, truncate, check."""
        for validator in (
            self._token_padding,
            self._token_truncation,
            self._token_len_check,
        ):
            tokens = validator(tokenizer, tokens)

        return tokens

    def apply_post_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TokensPrompt | EmbedsPrompt,
    ) -> TokensPrompt | EmbedsPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run after tokenization occurs. Both
        `prompt_token_ids` and `prompt_embeds` are validated when present;
        the prompt is updated in place and also returned.
        """
        if "prompt_token_ids" in prompt:
            prompt["prompt_token_ids"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_token_ids"],  # type: ignore[typeddict-item]
            )
        if "prompt_embeds" in prompt:
            prompt["prompt_embeds"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_embeds"],  # type: ignore[typeddict-item]
            )

        return prompt