# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar

from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    # Only static type checkers take this branch, so these imports are
    # never executed at runtime.
    import torch

    from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
else:
    # At runtime, bind `torch` to a `LazyLoader` instead of importing it here.
    torch = LazyLoader("torch", globals(), "torch")

    # Runtime placeholder for the type-checking-only name.
    ChatTemplateContentFormatOption = object

logger = init_logger(__name__)


# Constrained to the two prompt representations handled by the token
# validators: a list of token IDs or a `torch.Tensor`.
_S = TypeVar("_S", list[int], "torch.Tensor")


def merge_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
    /,
    *,
    unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
    if defaults is None:
        defaults = {}
    if overrides is None:
        overrides = {}

    return defaults | {k: v for k, v in overrides.items() if v not in unset_values}


@dataclass(frozen=True)
class ChatParams:
    """Configuration to control how to parse chat messages."""

    chat_template: str | None = None
    """The chat template to apply."""

49
    chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
    """The format of the chat template."""

    chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
    """The kwargs to pass to the chat template."""

    def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
        if not default_chat_template_kwargs:
            return self

        return ChatParams(
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                default_chat_template_kwargs,
                self.chat_template_kwargs,
            ),
        )

    def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.apply_chat_template`."""
        return merge_kwargs(
            self.chat_template_kwargs,
72
            dict(chat_template=self.chat_template, return_dict=False),
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
        )


@dataclass(frozen=True)
class TokenizeParams:
    """Configuration to control how prompts are tokenized."""

    max_total_tokens: int | None
    """
    Maximum allowed number of input + output tokens.

    Usually, this refers to the model's context length.
    """

    max_output_tokens: int = 0
    """Maximum requested number of output tokens."""

    pad_prompt_tokens: int | None = None
    """
    Number of tokens to pad to:
    - `None` means no padding.
    - `-1` maps to `max_input_tokens`.
    """

    truncate_prompt_tokens: int | None = None
    """
    Number of tokens to keep:
    - `None` means no truncation.
    - `-1` maps to `max_input_tokens`.
    """

    do_lower_case: bool = False
    """Whether to normalize text to lower case before tokenization."""

    add_special_tokens: bool = True
    """Whether to add special tokens."""

    needs_detokenization: bool = False
    """
    Whether the tokenized prompt needs to contain the original text.

    Not to be confused with `SamplingParams.detokenize` which deals
    with the output generated by the model.
    """

    max_total_tokens_param: str = "max_total_tokens"
    """Override this to edit the message for validation errors."""

    max_output_tokens_param: str = "max_output_tokens"
    """Override this to edit the message for validation errors."""

    truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
    """Override this to edit the message for validation errors."""

    @property
    def max_input_tokens(self) -> int | None:
        """Maximum allowed number of input tokens."""
        if self.max_total_tokens is None:
            return None

        return self.max_total_tokens - self.max_output_tokens

    def __post_init__(self) -> None:
        """
        Validate that the configured limits are mutually consistent.

        Raises:
            VLLMValidationError: If `max_output_tokens` exceeds
                `max_total_tokens`, or if `truncate_prompt_tokens` exceeds
                `max_input_tokens`.
        """
        max_total_tokens = self.max_total_tokens
        max_output_tokens = self.max_output_tokens
        max_input_tokens = self.max_input_tokens
        truncate_prompt_tokens = self.truncate_prompt_tokens

        if (
            max_output_tokens is not None
            and max_total_tokens is not None
            and max_output_tokens > max_total_tokens
        ):
            raise VLLMValidationError(
                f"{self.max_output_tokens_param}={max_output_tokens} "
                f"cannot be greater than "
                f"{self.max_total_tokens_param}={max_total_tokens}. "
                f"Please request fewer output tokens.",
                parameter=self.max_output_tokens_param,
                value=max_output_tokens,
            )

        if (
            max_input_tokens is not None
            and truncate_prompt_tokens is not None
            and truncate_prompt_tokens > max_input_tokens
        ):
            raise VLLMValidationError(
                f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
                f"cannot be greater than {self.max_total_tokens_param} - "
                f"{self.max_output_tokens_param} = {max_input_tokens}. "
                f"Please request a smaller truncation size.",
                parameter=self.truncate_prompt_tokens_param,
                value=truncate_prompt_tokens,
            )

    def with_kwargs(self, **tokenization_kwargs: Any) -> "TokenizeParams":
        """
        Return a copy of these params updated by tokenizer-style kwargs.

        Recognized keys are `max_length`, `pad_prompt_tokens`,
        `truncate_prompt_tokens`, `do_lower_case`, `add_special_tokens`,
        `needs_detokenization`, as well as the HuggingFace-style `padding`
        and `truncation` options. Keys that are absent fall back to the
        current values. Any other key is ignored with a warning.

        `max_length` is interpreted as the maximum number of input tokens;
        the number of output tokens of the result is derived from it.
        """
        max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
        pad_prompt_tokens = tokenization_kwargs.pop(
            "pad_prompt_tokens", self.pad_prompt_tokens
        )
        truncate_prompt_tokens = tokenization_kwargs.pop(
            "truncate_prompt_tokens", self.truncate_prompt_tokens
        )
        do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
        add_special_tokens = tokenization_kwargs.pop(
            "add_special_tokens", self.add_special_tokens
        )
        needs_detokenization = tokenization_kwargs.pop(
            "needs_detokenization", self.needs_detokenization
        )

        # https://huggingface.co/docs/transformers/en/pad_truncation
        # NOTE: Compare against `None` rather than relying on truthiness;
        # otherwise an explicit `False` would never reach the branches below
        # that are meant to disable padding/truncation.
        padding = tokenization_kwargs.pop("padding", None)
        if padding is not None:
            if padding == "max_length":
                pad_prompt_tokens = max_length
            elif padding in (False, "do_not_pad"):
                pad_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["padding"] = padding

        truncation = tokenization_kwargs.pop("truncation", None)
        if truncation is not None:
            if truncation in (True, "longest_first"):
                truncate_prompt_tokens = max_length
            elif truncation in (False, "do_not_truncate"):
                truncate_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["truncation"] = truncation

        if tokenization_kwargs:
            logger.warning(
                "The following tokenization arguments are not supported "
                "by vLLM Renderer and will be ignored: %s",
                tokenization_kwargs,
            )

        max_total_tokens = self.max_total_tokens

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            # Chosen so that `max_input_tokens` of the result equals `max_length`
            max_output_tokens=(
                0
                if max_total_tokens is None or max_length is None
                else max_total_tokens - max_length
            ),
            pad_prompt_tokens=pad_prompt_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            do_lower_case=do_lower_case,
            add_special_tokens=add_special_tokens,
            needs_detokenization=needs_detokenization,
            # Keep any customized names used in validation error messages
            max_total_tokens_param=self.max_total_tokens_param,
            max_output_tokens_param=self.max_output_tokens_param,
            truncate_prompt_tokens_param=self.truncate_prompt_tokens_param,
        )

    def get_encode_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.encode`."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            # A negative value maps to `max_input_tokens`
            max_length = self.max_input_tokens
        elif max_length is None and self.max_input_tokens is not None:
            # This prevents tokenization from taking up more resources than necessary
            # while still failing `self._token_len_check` as expected by users
            max_length = self.max_input_tokens + 1

        return {
            "truncation": max_length is not None,
            "max_length": max_length,
            "add_special_tokens": self.add_special_tokens,
        }

    def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """
        Apply length checks to prompt text if necessary.

        The check is skipped when there is no input limit, when truncation
        is configured, or when no tokenizer is available.

        Raises:
            VLLMValidationError: If the text has more characters than
                `max_input_tokens * tokenizer.max_chars_per_token`.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return text

        if self.truncate_prompt_tokens is None and tokenizer is not None:
            max_input_chars = max_input_tokens * tokenizer.max_chars_per_token

            if len(text) > max_input_chars:
                # To save resources, fail the request outright without even
                # attempting tokenization
                raise VLLMValidationError(
                    f"You passed {len(text)} input characters "
                    f"and requested {self.max_output_tokens} output tokens. "
                    f"However, the model's context length is only "
                    f"{self.max_total_tokens} tokens, resulting in a maximum "
                    f"input length of {max_input_tokens} tokens "
                    f"(at most {max_input_chars} characters). "
                    f"Please reduce the length of the input prompt.",
                    parameter="input_text",
                    value=len(text),
                )

        return text

    def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply lowercase to prompt text if necessary."""
        # `tokenizer` is unused; it is accepted so that all text validators
        # share the same call signature.
        return text.lower() if self.do_lower_case else text

    def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply all validators to prompt text."""
        # Each validator receives the output of the previous one
        for validator in (
            self._text_len_check,
            self._text_lowercase,
        ):
            text = validator(tokenizer, text)

        return text

    def apply_pre_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TextPrompt,
    ) -> TextPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run before tokenization occurs.

        Note that `prompt` is updated in place and then returned.
        """
        prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])

        return prompt

    def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply padding to prompt tokens if necessary.

        Raises:
            ValueError: If padding is required but there is no tokenizer,
                or `tokens` is not a list.
        """
        pad_length = self.pad_prompt_tokens
        if pad_length is not None and pad_length < 0:
            # A negative value maps to `max_input_tokens`
            pad_length = self.max_input_tokens

        if pad_length is None or pad_length <= len(tokens):
            return tokens

        if tokenizer is None:
            raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
        if not isinstance(tokens, list):
            raise ValueError("Cannot pad tokens for embedding inputs")

        return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))

    def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply truncation to prompt tokens if necessary."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            # A negative value maps to `max_input_tokens`
            max_length = self.max_input_tokens

        if max_length is None or max_length >= len(tokens):
            return tokens
        if max_length == 0:
            # Handled separately because `tokens[-0:]` would keep everything
            return tokens[:0]

        # Left truncation (keeping the last tokens) is assumed when the
        # tokenizer is missing or does not define `truncation_side`
        if getattr(tokenizer, "truncation_side", "left") == "left":
            return tokens[-max_length:]

        return tokens[:max_length]

    def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply length checks to prompt tokens if necessary.

        Raises:
            VLLMValidationError: If there are more tokens than
                `max_input_tokens`.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return tokens

        if len(tokens) > max_input_tokens:
            raise VLLMValidationError(
                f"You passed {len(tokens)} input tokens "
                f"and requested {self.max_output_tokens} output tokens. "
                f"However, the model's context length is only "
                f"{self.max_total_tokens} tokens, resulting in a maximum "
                f"input length of {max_input_tokens} tokens. "
                f"Please reduce the length of the input prompt.",
                parameter="input_tokens",
                value=len(tokens),
            )

        return tokens

    def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply all validators to a token sequence."""
        # Each validator receives the output of the previous one; the length
        # check runs last so that it sees the padded/truncated result
        for validator in (
            self._token_padding,
            self._token_truncation,
            self._token_len_check,
        ):
            tokens = validator(tokenizer, tokens)

        return tokens

    def apply_post_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TokensPrompt | EmbedsPrompt,
    ) -> TokensPrompt | EmbedsPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run after tokenization occurs.

        Note that `prompt` is updated in place and then returned.
        """
        if "prompt_token_ids" in prompt:
            prompt["prompt_token_ids"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_token_ids"],  # type: ignore[typeddict-item]
            )
        if "prompt_embeds" in prompt:
            prompt["prompt_embeds"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_embeds"],  # type: ignore[typeddict-item]
            )

        return prompt