input_processor.py 18.1 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections.abc import Mapping
6
from typing import Any, Literal
7

8
import vllm.envs as envs
9
from vllm.config import VllmConfig
10
from vllm.inputs.data import (
11
12
13
14
    ProcessorInputs,
    PromptType,
    SingletonInputs,
)
15
from vllm.inputs.parse import split_enc_dec_inputs
16
17
18
19
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
20
from vllm.multimodal.encoder_budget import MultiModalBudget
21
22
23
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
)
24
from vllm.multimodal.utils import argsort_mm_positions
25
from vllm.platforms import current_platform
26
from vllm.pooling_params import PoolingParams
27
from vllm.renderers import BaseRenderer, renderer_from_config
28
from vllm.sampling_params import SamplingParams
29
from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask
30
from vllm.tokenizers import TokenizerLike
31
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
32
from vllm.utils.jsontree import json_iter_leaves
33
34
35
36
37
38
39
40
41
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)


class InputProcessor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        *,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """Set up request validation and preprocessing for one engine.

        Args:
            vllm_config: Engine configuration. Its sub-configs are cached on
                ``self`` for use by the validation helpers.
            renderer: Renderer providing the tokenizer. If omitted, one is
                created with ``renderer_from_config(vllm_config)``.
            mm_registry: Multimodal registry used to detect multimodal
                support and to size the encoder cache.
        """
        self.vllm_config = vllm_config
        self.model_config = model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.structured_outputs_config = vllm_config.structured_outputs_config
        self.observability_config = vllm_config.observability_config

        # Defaults later applied to each request's SamplingParams via
        # `update_from_generation_config`.
        self.generation_config_fields = model_config.try_get_generation_config()

        self.renderer = renderer or renderer_from_config(vllm_config)

        # Text-only defaults: no encoder cache, prompt-length checks enabled.
        self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config)
        self.mm_encoder_cache_size = 0
        self.skip_prompt_length_check = False
        if self.supports_mm_inputs:
            mm_budget = MultiModalBudget(vllm_config, mm_registry)
            self.mm_encoder_cache_size = mm_budget.encoder_cache_size
            self.skip_prompt_length_check = (
                mm_budget.processor.info.skip_prompt_length_check
            )
            # The budget is only needed to read the sizes above.
            mm_budget.reset_cache()  # Not used anymore

        # Hand the *resolved* renderer to the preprocessor. Passing the raw
        # `renderer` argument would give it `None` whenever the caller omitted
        # one, so it could not share `self.renderer` (and its tokenizer) with
        # this class.
        self.input_preprocessor = InputPreprocessor(
            vllm_config,
            renderer=self.renderer,
            mm_registry=mm_registry,
        )

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer exposed by ``self.renderer``; may be ``None``."""
        active_renderer = self.renderer
        return active_renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer from ``self.renderer.get_tokenizer()``.

        Unlike the ``tokenizer`` property, the declared return type is
        non-optional (presumably the renderer raises when no tokenizer is
        available -- confirm in ``BaseRenderer``).
        """
        active_renderer = self.renderer
        return active_renderer.get_tokenizer()

    def _validate_params(
        self,
        params: SamplingParams | PoolingParams,
        supported_tasks: tuple[SupportedTask, ...],
    ) -> None:
        """Raise `ValueError` if SamplingParams or PoolingParams is not valid.

        Sampling requests need at least one supported generation task, must
        pass ``SamplingParams.verify``, and may only set
        ``thinking_token_budget`` when a reasoning config is present.

        Pooling requests need at least one supported pooling task. A missing
        ``params.task`` is filled in place with the first supported of
        ``token_embed``, ``token_classify`` or ``plugin``; the final task must
        be supported before ``PoolingParams.verify`` runs.

        Raises:
            ValueError: If the params are incompatible with the model/config.
            TypeError: If ``params`` is neither SamplingParams nor PoolingParams.
        """
        if isinstance(params, SamplingParams):
            if not any(task in GENERATION_TASKS for task in supported_tasks):
                raise ValueError("This model does not support generation")

            params.verify(
                self.model_config,
                self.speculative_config,
                self.structured_outputs_config,
                self.tokenizer,
            )

            needs_reasoning = params.thinking_token_budget is not None
            if needs_reasoning and self.vllm_config.reasoning_config is None:
                raise ValueError(
                    "thinking_token_budget is set but reasoning_config is "
                    "not configured. Please set --reasoning-config to use "
                    "thinking_token_budget."
                )
            return

        if not isinstance(params, PoolingParams):
            raise TypeError(
                f"params must be either SamplingParams or PoolingParams, "
                f"but got {type(params).__name__}"
            )

        pooling_tasks = [task for task in supported_tasks if task in POOLING_TASKS]
        if not pooling_tasks:
            raise ValueError("This model does not support pooling")

        # Pick a default task only when the caller left it unset; note this
        # mutates the caller's PoolingParams.
        if params.task is None:
            if "token_embed" in pooling_tasks:
                params.task = "token_embed"
            elif "token_classify" in pooling_tasks:
                params.task = "token_classify"
            elif "plugin" in pooling_tasks:
                params.task = "plugin"

        if params.task not in pooling_tasks:
            raise ValueError(
                f"Unsupported task: {params.task!r} "
                f"Supported tasks: {pooling_tasks}"
            )

        params.verify(self.model_config)

    def _validate_lora(self, lora_request: LoRARequest | None) -> None:
        """Check that a LoRA request can be served by this engine.

        ``None`` is always accepted. A non-``None`` request requires LoRA to be
        enabled (a truthy ``lora_config``). If a tokenizer is also present, a
        one-time deprecation warning about per-LoRA tokenizers is logged.

        Raises:
            ValueError: If a LoRA request is given while LoRA is not enabled.
        """
        if lora_request is None:
            return

        lora_enabled = bool(self.lora_config)
        if not lora_enabled:
            raise ValueError(
                f"Got lora_request {lora_request} but LoRA is not enabled!"
            )

        if self.tokenizer is None:
            return

        logger.warning_once(
            "vLLM has deprecated support for supporting different "
            "tokenizers for different LoRAs. By default, vLLM uses base "
            "model's tokenizer. If you are using a LoRA "
            "with its own tokenizer, consider specifying `--tokenizer "
            "[lora_path]` to use the LoRA tokenizer."
        )

    def _get_mm_identifier(
        self,
        mm_hash: str,
        lora_request: LoRARequest | None,
    ) -> str:
        """Return the cache identifier for one multimodal item.

        When ``enable_tower_connector_lora`` is on, multimodal embeddings vary
        with the LoRA request, so the identifier becomes
        ``"<lora_name>:<mm_hash>"`` to prevent incorrect cache hits across
        adapters. Without a LoRA request, LoRA config, or that flag, the plain
        ``mm_hash`` is returned.
        """
        tower_lora_enabled = (
            self.lora_config is not None
            and self.lora_config.enable_tower_connector_lora
        )
        if lora_request is not None and tower_lora_enabled:
            return f"{lora_request.lora_name}:{mm_hash}"
        return mm_hash

    @staticmethod
    def assign_request_id(request: EngineCoreRequest):
        """Replace the externally supplied request ID with an internal one.

        The caller's ``request.request_id`` is copied to
        ``request.external_req_id``. Unless ``VLLM_DISABLE_REQUEST_ID_RANDOMIZATION``
        is set, ``request.request_id`` then becomes that external ID followed by
        ``-`` and 8 random characters, so duplicate external IDs stay unique.
        When the env var is set, a one-time deprecation warning is logged and
        the ID is left as-is.

        Raises:
            ValueError: If ``external_req_id`` was already set by the caller.
        """
        if request.external_req_id is not None:
            raise ValueError(
                "The external_req_id field should not be set on EngineCoreRequests"
                " passed to vLLM; use the request_id field."
            )
        external_id = request.request_id
        request.external_req_id = external_id

        if not envs.VLLM_DISABLE_REQUEST_ID_RANDOMIZATION:
            # `:.8` truncates the random UUID string to its first 8 characters.
            request.request_id = f"{external_id}-{random_uuid():.8}"
            return

        logger.warning_once(
            "VLLM_DISABLE_REQUEST_ID_RANDOMIZATION is set and will be "
            "removed in a future release. Duplicate externally-provided "
            "request IDs may cause failures and/or subtle correctness errors."
        )

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams,
        supported_tasks: tuple[SupportedTask, ...],
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        resumable: bool = False,
    ) -> EngineCoreRequest:
        """Validate one request and package it as an ``EngineCoreRequest``.

        Args:
            request_id: ID stored on the resulting request.
            prompt: Either processed inputs (a dict with a ``"type"`` key) or a
                raw prompt; raw prompts go through the deprecated
                ``InputPreprocessor`` path.
            params: Sampling or pooling parameters. They are validated here
                (which may fill in ``task`` on the caller's PoolingParams) and
                a clone is attached to the result.
            supported_tasks: Tasks the model supports; used to validate
                ``params``.
            arrival_time: Arrival timestamp. When ``None``, taken from the
                processed inputs' ``"arrival_time"`` entry if present, else the
                current time.
            lora_request: Optional LoRA adapter for this request.
            tokenization_kwargs: Deprecated; only forwarded to the
                preprocessor for raw prompts.
            trace_headers: Copied onto the request.
            priority: Copied onto the request.
            data_parallel_rank: Optional target DP rank; must lie in
                ``[0, num_ranks)``.
            resumable: Copied onto the request.

        Returns:
            The assembled ``EngineCoreRequest``.

        Raises:
            ValueError: If params, LoRA, DP rank, prompt or multimodal
                validation fails.
            TypeError: If ``params`` is neither SamplingParams nor PoolingParams.
        """
        self._validate_params(params, supported_tasks)
        self._validate_lora(lora_request)

        # Ranks are counted locally when this front-end only talks to local
        # engines, otherwise across the whole data-parallel group.
        parallel_config = self.vllm_config.parallel_config
        num_ranks = (
            parallel_config.data_parallel_size_local
            if parallel_config.local_engines_only
            else parallel_config.data_parallel_size
        )
        rank_ok = data_parallel_rank is None or 0 <= data_parallel_rank < num_ranks
        if not rank_ok:
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {num_ranks})."
            )

        processed_inputs, arrival_time = self._resolve_processor_inputs(
            prompt, arrival_time, tokenization_kwargs
        )

        current_platform.validate_request(processed_inputs, params)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy can be conservative for TypedDict unions; normalize access.
        if decoder_inputs["type"] == "embeds":
            prompt_token_ids = None
            prompt_embeds = decoder_inputs["prompt_embeds"]
        else:
            prompt_token_ids = decoder_inputs["prompt_token_ids"]
            prompt_embeds = None

        sampling_params: SamplingParams | None = None
        pooling_params: PoolingParams | None = None
        if not isinstance(params, SamplingParams):
            pooling_params = params.clone()
        else:
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            if sampling_params.max_tokens is None:
                # Unset max_tokens: allow generation up to max_model_len.
                prompt_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds
                )
                sampling_params.max_tokens = (
                    self.model_config.max_model_len - prompt_len
                )
            sampling_params.update_from_generation_config(
                self.generation_config_fields,
                self.renderer.get_eos_token_id(),
            )
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)

        mm_features = self._collect_mm_features(decoder_inputs, lora_request)

        return EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
            resumable=resumable,
        )

    def _resolve_processor_inputs(
        self,
        prompt: PromptType | ProcessorInputs,
        arrival_time: float | None,
        tokenization_kwargs: dict[str, Any] | None,
    ) -> tuple[ProcessorInputs, float]:
        """Return ``(processed_inputs, arrival_time)`` for ``process_inputs``.

        Dicts carrying a ``"type"`` key are treated as already processed and
        passed through; any other prompt is run through the deprecated
        ``InputPreprocessor`` path. A ``None`` ``arrival_time`` is filled from
        the processed dict's ``"arrival_time"`` entry (if any) or the current
        time.
        """
        if isinstance(prompt, dict) and "type" in prompt:
            if tokenization_kwargs:
                logger.warning_once(
                    "Passing tokenization_kwargs to InputProcessor is deprecated "
                    "and will be removed in v0.18. You should instead pass "
                    "them to Renderer.render_cmpl() or Renderer.render_chat()."
                )
            if arrival_time is None:
                arrival_time = prompt.get("arrival_time", time.time())  # type: ignore[assignment]
            already_processed: ProcessorInputs = prompt  # type: ignore[assignment]
            return already_processed, arrival_time  # type: ignore[return-value]

        logger.warning_once(
            "Passing raw prompts to InputProcessor is deprecated "
            "and will be removed in v0.18. You should instead pass "
            "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
        )
        if arrival_time is None:
            arrival_time = time.time()
        preprocessed = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
        )
        return preprocessed, arrival_time

    def _collect_mm_features(
        self,
        decoder_inputs: SingletonInputs,
        lora_request: LoRARequest | None,
    ) -> list[MultiModalFeatureSpec] | None:
        """Flatten multimodal decoder inputs into position-sorted feature specs.

        Returns ``None`` for non-multimodal inputs. Otherwise the per-modality
        kwargs, placeholders and hashes are merged into one list ordered by
        each item's position in the input sequence (``argsort_mm_positions``).

        Raises:
            ValueError: If any leaf of ``mm_hashes`` is not a string.
        """
        if decoder_inputs["type"] != "multimodal":
            return None

        mm_kwargs = decoder_inputs["mm_kwargs"]
        mm_placeholders = decoder_inputs["mm_placeholders"]
        mm_hashes = decoder_inputs["mm_hashes"]

        if any(not isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)):
            raise ValueError(
                f"mm_hashes must contain only strings, got: {mm_hashes}. "
                "This is likely due to an incorrect custom implementation of "
                "MultiModalProcessor.apply method."
            )

        return [
            MultiModalFeatureSpec(
                data=mm_kwargs[modality][idx],
                modality=modality,
                identifier=self._get_mm_identifier(
                    mm_hashes[modality][idx],
                    lora_request,
                ),
                mm_position=mm_placeholders[modality][idx],
                mm_hash=mm_hashes[modality][idx],
            )
            for modality, idx in argsort_mm_positions(mm_placeholders)
        ]

    def _validate_prompt_len(
        self,
        prompt_len: int,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Raise ``ValueError`` if a prompt length violates the model limits.

        Decoder prompts are limited by ``max_model_len``; encoder prompts by
        the multimodal encoder cache size. All checks are skipped when the
        multimodal processor sets ``skip_prompt_length_check``.

        Args:
            prompt_len: Length of the prompt.
            prompt_type: Whether the prompt feeds the encoder or the decoder.

        Raises:
            ValueError: If a decoder prompt is empty; if the prompt exceeds its
                limit; or if a generative model's decoder prompt already fills
                ``max_model_len``, leaving no room for an output token.
        """
        if self.skip_prompt_length_check:
            return

        is_decoder = prompt_type == "decoder"
        if prompt_len == 0 and is_decoder:
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

        model_config = self.model_config
        max_prompt_len = (
            model_config.max_model_len if is_decoder else self.mm_encoder_cache_size
        )
        if prompt_len > max_prompt_len:
            if self.supports_mm_inputs:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well."
                )
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens."
                )

            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}"
            )

        # Only the decoder sequence grows with generated tokens, so the
        # "at least 1 output token" requirement applies to decoder prompts
        # only. An encoder prompt that exactly fills the encoder cache fits
        # and must not be rejected here.
        if (
            is_decoder
            and prompt_len == max_prompt_len
            and model_config.runner_type == "generate"
        ):
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        prompt_type: Literal["encoder", "decoder"],
    ) -> None:
        """Validate one (encoder or decoder) prompt against model limits.

        Checks, in order: the prompt length (via ``_validate_prompt_len``);
        that no multimodal item needs more embeddings than the encoder cache
        holds; and, when token IDs and a tokenizer are both available, that
        the largest token ID is within the larger of the tokenizer and model
        vocabularies.

        Raises:
            ValueError: If any of the checks above fails.
        """
        tokenizer = self.tokenizer
        model_config = self.model_config

        if prompt_inputs["type"] == "embeds":
            prompt_ids = None
            prompt_embeds = prompt_inputs["prompt_embeds"]
        else:
            prompt_ids = prompt_inputs["prompt_token_ids"]
            prompt_embeds = None

        self._validate_prompt_len(
            length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds),
            prompt_type,
        )

        if prompt_inputs["type"] == "multimodal":
            cache_size = self.mm_encoder_cache_size
            for modality, placeholders in prompt_inputs["mm_placeholders"].items():
                for placeholder in placeholders:
                    embed_length = placeholder.get_num_embeds()
                    if embed_length > cache_size:
                        raise ValueError(
                            f"The {prompt_type} prompt contains a(n) {modality} item "
                            f"with length {embed_length}, which exceeds the "
                            f"pre-allocated encoder cache size "
                            f"{cache_size}. Please reduce the input "
                            f"size or increase the encoder cache size "
                            f"by setting --limit-mm-per-prompt at startup."
                        )

        if not prompt_ids or tokenizer is None:
            return

        # NOTE: tokenizer.max_token_id is the tokenizer's vocab size while
        # model_config.get_vocab_size() is the model's vocab size. For Qwen3
        # models the language model has extra tokens that do not exist in the
        # tokenizer, and vice versa for multimodal placeholder tokens in some
        # multimodal models.
        # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
        # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501
        # Taking the max of the two decides whether an ID is truly
        # out-of-vocabulary.
        max_input_id = max(prompt_ids, default=0)
        highest_known_id = max(
            tokenizer.max_token_id, model_config.get_vocab_size() - 1
        )
        if max_input_id > highest_known_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

    def _validate_model_inputs(
        self,
        encoder_inputs: SingletonInputs | None,
        decoder_inputs: SingletonInputs,
    ):
        """Validate the encoder prompt (when present), then the decoder prompt.

        Raises:
            ValueError: Propagated from ``_validate_model_input``.
        """
        # Only encoder-decoder inputs carry a separate encoder prompt.
        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs, "encoder")
        self._validate_model_input(decoder_inputs, "decoder")