input_processor.py 32.7 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections.abc import Mapping
from typing import Any, Literal, cast

from vllm.config import VllmConfig
9
from vllm.exceptions import VLLMValidationError
10
from vllm.inputs.data import (
11
12
13
14
15
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    SingletonPrompt,
)
16
from vllm.inputs.parse import split_enc_dec_inputs
17
18
19
20
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.multimodal.encoder_budget import MultiModalBudget
22
23
24
25
26
27
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFeatureSpec,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
28
from vllm.multimodal.processing.context import set_request_id
29
30
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
31
from vllm.renderers import BaseRenderer
32
from vllm.renderers.inputs import DictPrompt, TokPrompt
33
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
34
from vllm.tasks import POOLING_TASKS, SupportedTask
35
36
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
37
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
38
from vllm.utils.torch_utils import set_default_torch_num_threads
39
40
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
41
42
43
44
from vllm.v1.structured_output.backend_guidance import (
    has_guidance_unsupported_json_features,
    validate_guidance_grammar,
)
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer,
)
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines,
)
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

logger = init_logger(__name__)


class InputProcessor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """Store the relevant sub-configs and build the input preprocessor.

        Args:
            vllm_config: Engine configuration. Its model, cache, LoRA,
                scheduler, structured-output and observability sub-configs
                are kept on the instance.
            mm_registry: Registry used to create the multimodal processor
                cache and to query whether the model accepts multimodal
                inputs.
        """
        self.vllm_config = vllm_config
        self.model_config = model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
        self.structured_outputs_config = vllm_config.structured_outputs_config
        self.observability_config = vllm_config.observability_config

        self.generation_config_fields = model_config.try_get_generation_config()

        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)

        # Defaults for text-only models; overridden below for multimodal ones.
        self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config)
        self.mm_encoder_cache_size = 0
        self.skip_prompt_length_check = False
        if self.supports_mm_inputs:
            # The budget is only consulted for these two values.
            mm_budget = MultiModalBudget(vllm_config, mm_registry)
            self.mm_encoder_cache_size = mm_budget.encoder_cache_size
            self.skip_prompt_length_check = (
                mm_budget.processor.info.skip_prompt_length_check
            )
            mm_budget.reset_cache()  # Not used anymore

        self.input_preprocessor = InputPreprocessor(
            model_config,
            self.observability_config,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The input preprocessor's tokenizer, or ``None`` if it has none."""
        return self.input_preprocessor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer via ``InputPreprocessor.get_tokenizer``.

        Unlike the ``tokenizer`` property, the declared return type is not
        optional.
        """
        preprocessor = self.input_preprocessor
        return preprocessor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """The renderer held by the input preprocessor."""
        return self.input_preprocessor.renderer

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Check requested sample and prompt logprobs against the model limit.

        A configured ``max_logprobs`` of -1 means "vocabulary size". A
        requested value of -1 is likewise expanded to the vocabulary size.
        Falsy requests (``None`` or 0) are not checked.

        Raises:
            VLLMValidationError: if ``logprobs`` or ``prompt_logprobs``
                exceeds the allowed maximum.
        """
        max_logprobs = self.model_config.max_logprobs
        if max_logprobs == -1:
            max_logprobs = self.model_config.get_vocab_size()

        # Sample logprobs are validated first, then prompt logprobs.
        checks = (
            ("logprobs", "sample", params.logprobs),
            ("prompt_logprobs", "prompt", params.prompt_logprobs),
        )
        for parameter, label, requested in checks:
            if not requested:
                continue
            if requested == -1:
                requested = self.model_config.get_vocab_size()
            if requested > max_logprobs:
                raise VLLMValidationError(
                    f"Requested {label} logprobs of {requested}, "
                    f"which is greater than max allowed: {max_logprobs}",
                    parameter=parameter,
                    value=requested,
                )

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate structured outputs, logit_bias and allowed_token_ids.

        Raises:
            ValueError: if ``allowed_token_ids`` is an empty collection, or
                (when a tokenizer is available) contains an ID outside
                ``[0, len(tokenizer))``.
        """
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        token_ids = params.allowed_token_ids
        if token_ids is None:
            return
        if not token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")

        tokenizer = self.tokenizer
        # With skip_tokenizer_init=True there is no tokenizer to size the
        # vocabulary, so range checking is left to the model.
        if tokenizer is not None:
            upper = len(tokenizer)
            if any(tid < 0 or tid >= upper for tid in token_ids):
                raise ValueError("allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate that every logit_bias token ID is within the vocabulary.

        The vocabulary size comes from ``model_config.get_vocab_size()``.
        All offending IDs are collected so they can be reported together.

        Raises:
            VLLMValidationError: if any key of ``logit_bias`` is outside
                ``[0, vocab_size)``.
        """
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = [
            token_id
            for token_id in params.logit_bias
            if not 0 <= token_id < vocab_size
        ]

        if invalid_token_ids:
            raise VLLMValidationError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
                parameter="logit_bias",
                value=invalid_token_ids,
            )

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features that this engine does not support.

        Raises:
            ValueError: if per-request logits processors are given, or if
                ``min_tokens > 1``, ``min_p > _SAMPLING_EPS`` or a non-empty
                ``logit_bias`` is used while speculative decoding is
                configured.
        """
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError(
                "vLLM V1 does not support per request user-provided logits processors."
            )

        # Some sampling parameters are not yet compatible with spec decoding.
        if self.vllm_config.speculative_config is not None and (
            params.min_tokens > 1 or params.min_p > _SAMPLING_EPS or params.logit_bias
        ):
            raise ValueError(
                "The min_tokens, min_p, and logit_bias sampling parameters "
                "are not yet supported with speculative decoding."
            )

    def _validate_params(
        self,
        params: SamplingParams | PoolingParams,
        # TODO: Validate generation tasks as well once `supported_tasks`
        # is passed to all `process_inputs` calls
        supported_tasks: tuple[SupportedTask, ...] | None,
    ):
        """
        Validate request parameters before they are sent to the engine.

        For `PoolingParams`: resolve a default task when none is set (this
        mutates `params.task`), check the task against the pooling tasks in
        `supported_tasks`, then call `params.verify(model_config)`.

        For `SamplingParams`: validate logprobs, structured outputs,
        logit_bias, allowed_token_ids and engine-level feature support.

        Raises:
            RuntimeError: if pooling params are given without
                `supported_tasks`.
            ValueError: if the pooling task or a sampling parameter is
                unsupported. Some sampling checks raise
                `VLLMValidationError` instead.
        """
        if isinstance(params, PoolingParams):
            if supported_tasks is None:
                raise RuntimeError("`supported_tasks` must be passed for pooling")

            supported_pooling_tasks = [
                task for task in supported_tasks if task in POOLING_TASKS
            ]

            if params.task is None:
                if not supported_pooling_tasks:
                    raise ValueError("Pooling tasks are not supported")

                # Pick the first supported task in this preference order.
                # NOTE(review): if none of these is supported, `task` stays
                # None and the membership check below rejects the request.
                if "token_embed" in supported_pooling_tasks:
                    params.task = "token_embed"
                elif "token_classify" in supported_pooling_tasks:
                    params.task = "token_classify"
                elif "plugin" in supported_pooling_tasks:
                    params.task = "plugin"

            if params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

            params.verify(self.model_config)

            return

        self._validate_logprobs(params)
        self._validate_sampling_params(params)
        self._validate_supported_sampling_params(params)

    def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse raw multimodal data with the model's multimodal processor."""
        mm_processor = self.input_preprocessor._get_mm_processor()
        return mm_processor.info.parse_mm_data(mm_data)

    def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
        """Check that one prompt's multi_modal_uuids agree with its data.

        Non-dict prompts, and prompts with neither field set, are skipped.
        For every modality named in either field:

        - if data items exist and UUIDs are given, their counts must match;
        - every data item that is ``None`` needs a non-``None`` UUID at the
          same index;
        - if there are no data items, at least one UUID must be given.

        A bare string UUID is treated as a one-element list.

        Raises:
            ValueError: if any of the rules above is violated.
        """
        if not isinstance(prompt, dict):
            return

        mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
        mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
        if not mm_data and not mm_uuids:
            return

        # `None`-valued modalities are excluded from parsing.
        mm_data_parsed = self._parse_mm_items(
            {k: v for k, v in mm_data.items() if v is not None}
        )
        mm_uuids_parsed = {
            k: [v] if isinstance(v, str) else v
            for k, v in mm_uuids.items()
            if v is not None
        }

        # NOTE: Include the keys corresponding to `None`
        modalities = mm_data.keys() | mm_uuids.keys()

        for modality in modalities:
            data_items = cast(
                ModalityDataItems | list[Any], mm_data_parsed.get(modality, [])
            )
            uuid_items = cast(list[str | None], mm_uuids_parsed.get(modality, []))

            if len(data_items) > 0:
                if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
                    raise ValueError(
                        f"If given, multi_modal_uuids[{modality!r}] must have "
                        f"same length as multi_modal_data[{modality!r}], but "
                        f"got {len(uuid_items)} vs {len(data_items)}."
                    )

                # Lengths match here, so `uuid_items[i]` is in range.
                for i, item in enumerate(data_items):
                    if item is None:
                        if not uuid_items:
                            raise ValueError(
                                f"multi_modal_data[{modality!r}][{i}] is empty but "
                                f"multi_modal_uuids[{modality!r}] is missing."
                            )

                        if uuid_items[i] is None:
                            raise ValueError(
                                f"multi_modal_data[{modality!r}][{i}] is empty but "
                                f"multi_modal_uuids[{modality!r}][{i}] is missing."
                            )
            else:
                if len(uuid_items) == 0:
                    raise ValueError(
                        f"multi_modal_data[{modality!r}] is empty but "
                        f"multi_modal_uuids[{modality!r}] is missing."
                    )

    def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).

        For encoder/decoder prompts, the encoder prompt and (if not None)
        the decoder prompt are checked separately. Besides matching lengths,
        any data item given as `None` must have a UUID at the same index.
        `None` UUIDs for present data items are allowed and will be
        auto-hashed downstream.
        """
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            self._validate_singleton_mm_uuids(prompt["encoder_prompt"])  # type: ignore[typeddict-item]

            if (dec_prompt := prompt["decoder_prompt"]) is not None:  # type: ignore[typeddict-item]
                self._validate_singleton_mm_uuids(dec_prompt)
        else:
            self._validate_singleton_mm_uuids(prompt)

    def _validate_lora(self, lora_request: LoRARequest | None) -> None:
        """Reject LoRA requests when LoRA is disabled.

        When a tokenizer is present, also emit a one-time deprecation
        warning about per-LoRA tokenizers.

        Raises:
            ValueError: if a LoRA request is given but no LoRA config exists.
        """
        if lora_request is None:
            return

        if not self.lora_config:
            # An adapter was requested but the engine has no LoRA config.
            msg = f"Got lora_request {lora_request} but LoRA is not enabled!"
            raise ValueError(msg)

        if self.tokenizer is None:
            return

        logger.warning_once(
            "vLLM has deprecated support for supporting different "
            "tokenizers for different LoRAs. By default, vLLM uses base "
            "model's tokenizer. If you are using a LoRA "
            "with its own tokenizer, consider specifying `--tokenizer "
            "[lora_path]` to use the LoRA tokenizer."
        )

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate structured-output settings and resolve the backend.

        No-op when the request has no structured outputs or no
        structured-output config is present. Otherwise this mutates
        ``params.structured_outputs``:

        - it sets ``_backend`` to the engine's configured backend, or, under
          ``"auto"``, to the first backend whose validation succeeds;
        - it sets ``_backend_was_auto = True`` when the backend was chosen
          automatically;
        - it re-runs ``__post_init__``.

        Raises:
            ValueError: if the tokenizer is skipped, a request-level backend
                conflicts with the engine's, the choice list or grammar is
                empty, or a Mistral tokenizer is used with an incompatible
                backend. Backend-specific validators may raise as well.
        """
        if not params.structured_outputs or not self.structured_outputs_config:
            return

        # `params.structured_outputs` is known to be set at this point.
        if self.model_config.skip_tokenizer_init:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = self.structured_outputs_config.backend
        if _backend := params.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and params.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            params.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(params.structured_outputs.choice, list)
            and not params.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{params.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )
        # Reject empty string grammar early to avoid engine-side crashes
        if (
            isinstance(params.structured_outputs.grammar, str)
            and params.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            if isinstance(self.tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'guidance' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_guidance_grammar(params, tokenizer=None)
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            if isinstance(self.tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_structured_output_request_lm_format_enforcer(params)
        else:
            # NOTE: any backend value not matched above lands here; it is
            # expected to be "auto" (presumably enforced when the config is
            # built — this method itself does not check it).
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(params)
                params.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar.

                # Check if schema has features unsupported by guidance
                so_params = params.structured_outputs
                skip_guidance = False
                if so_params.json:
                    if isinstance(so_params.json, str):
                        import json

                        schema = json.loads(so_params.json)
                    else:
                        schema = so_params.json
                    skip_guidance = has_guidance_unsupported_json_features(schema)

                if isinstance(self.tokenizer, MistralTokenizer) or skip_guidance:
                    # Fall back to outlines if the tokenizer is Mistral
                    # or if schema contains features unsupported by guidance
                    validate_structured_output_request_outlines(params)
                    params.structured_outputs._backend = "outlines"
                else:
                    # Fall back to guidance by default.
                    validate_guidance_grammar(params, tokenizer=None)
                    params.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            params.structured_outputs._backend_was_auto = True

        # Run post-init validation. This is also important to ensure subsequent
        # roundtrip serialization/deserialization won't fail.
        params.structured_outputs.__post_init__()

    def _extract_singleton_mm_data(
        self, prompt: SingletonPrompt
    ) -> MultiModalDataDict | None:
        """Return the prompt's ``multi_modal_data`` entry, or ``None``.

        Non-dict prompts always yield ``None``.
        """
        if not isinstance(prompt, dict):
            return None

        return prompt.get("multi_modal_data")

    def _extract_mm_data(
        self, prompt: PromptType | DictPrompt | TokPrompt
    ) -> MultiModalDataDict | None:
        """Return the multimodal data attached to a prompt, if any.

        For encoder/decoder prompts only the encoder prompt is inspected.
        """
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            return self._extract_singleton_mm_data(prompt["encoder_prompt"])  # type: ignore[typeddict-item]

        return self._extract_singleton_mm_data(prompt)

    def _maybe_build_mm_uuids(
        self,
        request_id: str,
        prompt: PromptType | DictPrompt | TokPrompt,
    ) -> MultiModalUUIDDict | None:
        """Build per-item multimodal hash overrides from the request id.

        Each multimodal data item is identified as
        ``f"{request_id}-{modality}-{index}"`` rather than by its content.
        Whether this mode is enabled is decided by the caller.

        Returns:
            A mapping of modality -> list of override strings, or ``None``
            if the prompt carries no multimodal data.
        """
        mm_data = self._extract_mm_data(prompt)
        if not mm_data:
            return None

        # `None`-valued modalities are dropped before counting items.
        mm_items = self._parse_mm_items(
            {k: v for k, v in mm_data.items() if v is not None}
        )

        return {
            modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
            for modality, data_count in mm_items.get_all_counts().items()
        }

    def _get_mm_identifier(
        self,
        mm_hash: str,
        lora_request: LoRARequest | None,
    ) -> str:
        """
        Return the cache identifier for a multimodal item.

        When enable_tower_connector_lora is True, multi-modal embeddings
        vary depending on the LoRA request. Therefore, the mm_hash must be
        generated based on the LoRA request to prevent incorrect cache hits.
        In that case the identifier is ``f"{lora_name}:{mm_hash}"``;
        otherwise ``mm_hash`` is returned unchanged.
        """
        if (
            lora_request is None
            or self.lora_config is None
            or not self.lora_config.enable_tower_connector_lora
        ):
            return mm_hash
        return f"{lora_request.lora_name}:{mm_hash}"

    @staticmethod
    def assign_request_id(request: EngineCoreRequest):
        """Replace the externally supplied request ID with an internal one.

        The original ID is moved to ``external_req_id``, and ``request_id``
        becomes the original ID plus ``-`` and the first 8 characters of
        ``random_uuid()``, to ensure uniqueness.

        Raises:
            ValueError: if ``external_req_id`` is already set on the request.
        """
        if request.external_req_id is not None:
            raise ValueError(
                "The external_req_id field should not be set on EngineCoreRequests"
                " passed to vLLM; use the request_id field."
            )
        request.external_req_id = request.request_id
        request.request_id = f"{request.external_req_id}-{random_uuid():.8}"

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType | DictPrompt | TokPrompt,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        supported_tasks: tuple[SupportedTask, ...] | None = None,
        resumable: bool = False,
    ) -> EngineCoreRequest:
        """Validate a request and convert it into an `EngineCoreRequest`.

        The steps, in order:

        1. Validate the LoRA request, the params and the data-parallel rank.
        2. Resolve multimodal hash overrides.
        3. Preprocess the prompt (tokenization / multimodal processing).
        4. Let the current platform validate the request.
        5. Validate the model inputs.
        6. Clone the params. For sampling params, this also fills in
           `max_tokens` and generation-config/tokenizer defaults.
        7. Collect multimodal features sorted by their position in the
           prompt.

        `arrival_time` defaults to `time.time()`.

        Raises:
            ValueError: if `data_parallel_rank` is outside the valid range.
                The validation helpers may also raise `ValueError`,
                `RuntimeError` or `VLLMValidationError`.
        """
        self._validate_lora(lora_request)
        self._validate_params(params, supported_tasks)

        parallel_config = self.vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_local_size = parallel_config.data_parallel_size_local
        num_ranks = dp_local_size if parallel_config.local_engines_only else dp_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank < num_ranks):
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {num_ranks})."
            )

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        if (
            self.model_config.multimodal_config
            and self.model_config.multimodal_config.mm_processor_cache_gb == 0
            and not self.cache_config.enable_prefix_caching
        ):
            mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_mm_uuids(prompt)
            # NOTE(review): for encoder/decoder prompts the UUIDs sit under
            # "encoder_prompt", which this top-level lookup does not read —
            # confirm the preprocessor picks them up itself.
            if isinstance(prompt, dict):
                mm_uuids = cast(
                    MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
                )
            else:
                mm_uuids = None

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        with set_request_id(request_id), set_default_torch_num_threads():
            processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        from vllm.platforms import current_platform

        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy can be conservative for TypedDict unions; normalize access.
        if decoder_inputs["type"] == "embeds":
            prompt_token_ids = None
            prompt_embeds = decoder_inputs["prompt_embeds"]
        else:
            prompt_token_ids = decoder_inputs["prompt_token_ids"]
            prompt_embeds = None

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds
                )
                sampling_params.max_tokens = self.model_config.max_model_len - seq_len
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id
            )
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: list[MultiModalFeatureSpec] | None = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = []
            for modality, idx in sorted_mm_idxs:
                base_mm_hash = decoder_mm_hashes[modality][idx]
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=decoder_mm_inputs[modality][idx],
                        modality=modality,
                        # May be prefixed with the LoRA name for cache keys.
                        identifier=self._get_mm_identifier(
                            base_mm_hash,
                            lora_request,
                        ),
                        mm_position=decoder_mm_positions[modality][idx],
                        mm_hash=base_mm_hash,
                    )
                )

        return EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
            resumable=resumable,
        )

    def _validate_prompt_len(
        self,
        prompt_len: int,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Check that a prompt's length fits within the model's limits.

        This is skipped entirely when ``skip_prompt_length_check`` is set.
        Decoder prompts may not be empty and are limited by
        ``max_model_len``. Encoder prompts are limited by
        ``mm_encoder_cache_size``. When ``runner_type == "generate"``, a
        prompt exactly at the limit is also rejected, because at least one
        output token must fit.

        Raises:
            ValueError: if the prompt is empty (decoder only) or too long.
        """
        if self.skip_prompt_length_check:
            return

        if prompt_len == 0 and prompt_type == "decoder":
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

        model_config = self.model_config
        max_prompt_len = (
            model_config.max_model_len
            if prompt_type == "decoder"
            else self.mm_encoder_cache_size
        )
        if prompt_len > max_prompt_len:
            if self.supports_mm_inputs:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well."
                )
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens."
                )

            # NOTE(review): for encoder prompts the limit is the encoder
            # cache size, although the message says "maximum model length".
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}"
            )
        # NOTE(review): this equality check also applies to encoder prompts,
        # although its message refers to requested output tokens — confirm
        # this is intended.
        elif prompt_len == max_prompt_len and model_config.runner_type == "generate":
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        prompt_type: Literal["encoder", "decoder"],
    ) -> None:
        """Validate a single processed (encoder or decoder) prompt.

        Checks, in order:

        1. the prompt length, delegated to ``self._validate_prompt_len``;
        2. for ``"multimodal"`` prompts, that no placeholder needs more
           embeddings than ``self.mm_encoder_cache_size``;
        3. when token ids and a tokenizer are available, that the largest
           token id is within the larger of the tokenizer's and the model's
           vocabulary.

        Raises:
            ValueError: If any of the checks above fails.
        """
        model_config = self.model_config
        tokenizer = self.tokenizer

        # Embedding prompts carry `prompt_embeds` instead of token ids.
        prompt_ids = (
            None
            if prompt_inputs["type"] == "embeds"
            else prompt_inputs["prompt_token_ids"]
        )
        prompt_embeds = (
            prompt_inputs["prompt_embeds"]
            if prompt_inputs["type"] == "embeds"
            else None
        )

        prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)
        self._validate_prompt_len(prompt_len, prompt_type)

        if prompt_inputs["type"] == "multimodal":
            # Applies to encoder and decoder prompts alike, hence the neutral name.
            mm_placeholders = prompt_inputs["mm_placeholders"]
            for modality, mm_positions in mm_placeholders.items():
                for mm_position in mm_positions:
                    embed_length = mm_position.get_num_embeds()
                    if embed_length > self.mm_encoder_cache_size:
                        raise ValueError(
                            f"The {prompt_type} prompt contains a(n) {modality} item "
                            f"with length {embed_length}, which exceeds the "
                            f"pre-allocated encoder cache size "
                            f"{self.mm_encoder_cache_size}. Please reduce the input "
                            f"size or increase the encoder cache size "
                            f"by setting --limit-mm-per-prompt at startup."
                        )

        if prompt_ids and tokenizer is not None:
            max_input_id = max(prompt_ids, default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer's vocab size while
            # self.model_config.get_vocab_size() is the model's vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            model_vocab_size = model_config.get_vocab_size()
            if max_input_id > max(tokenizer.max_token_id, model_vocab_size - 1):
                raise ValueError(f"Token id {max_input_id} is out of vocabulary")

    def _validate_model_inputs(
        self,
        encoder_inputs: SingletonInputs | None,
        decoder_inputs: SingletonInputs,
    ):
        """Run per-prompt validation on the encoder (if any) and decoder inputs.

        The encoder prompt, when present, is validated before the decoder
        prompt, so its errors surface first. Each prompt goes through
        ``self._validate_model_input`` with the matching ``prompt_type``.
        """
        pending = [(decoder_inputs, "decoder")]
        if encoder_inputs is not None:
            # Put the encoder first to keep encoder-then-decoder ordering.
            pending.insert(0, (encoder_inputs, "encoder"))

        for inputs, kind in pending:
            self._validate_model_input(inputs, prompt_type=kind)

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        """Return the input preprocessor's multimodal cache statistics.

        Pure delegation to ``self.input_preprocessor.stat_mm_cache()``.
        """
        preprocessor = self.input_preprocessor
        return preprocessor.stat_mm_cache()

    def clear_mm_cache(self) -> None:
        """Clear the input preprocessor's multimodal cache (pure delegation)."""
        preprocessor = self.input_preprocessor
        preprocessor.clear_mm_cache()

    def close(self) -> None:
        if self.mm_processor_cache is not None:
            self.mm_processor_cache.close()