# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections.abc import Mapping
from typing import Any, Literal, cast

from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer,
)
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines,
)
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

logger = init_logger(__name__)


class InputProcessor:
    """Validate incoming requests and turn prompts into ``EngineCoreRequest``s.

    Tokenization and multimodal processing are delegated to an
    ``InputPreprocessor``; this class adds request-level validation of
    sampling/pooling parameters, LoRA requests, structured-output settings,
    user-provided multimodal UUIDs and the processed model inputs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerLike | None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Args:
            vllm_config: Engine configuration. The model, cache, LoRA and
                structured-output sub-configs are cached on the instance.
            tokenizer: Tokenizer forwarded to the ``InputPreprocessor``. May be
                ``None`` (validation that needs a tokenizer is then skipped).
            mm_registry: Multimodal registry used to build the processor cache
                and the ``InputPreprocessor``.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.structured_outputs_config = vllm_config.structured_outputs_config

        # Model-provided generation defaults; merged into each request's
        # SamplingParams in process_inputs().
        self.generation_config_fields = self.model_config.try_get_generation_config()

        self.mm_registry = mm_registry
        self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry)

        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            tokenizer,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer held by the underlying ``InputPreprocessor`` (may be None)."""
        return self.input_preprocessor.tokenizer

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Check requested sample/prompt logprobs against ``max_logprobs``.

        A value of ``-1`` (either as the model limit or as the requested
        count) is interpreted as the full model vocabulary size.

        Raises:
            ValueError: If a requested count exceeds the allowed maximum.
        """
        max_logprobs = self.model_config.max_logprobs
        if max_logprobs == -1:
            max_logprobs = self.model_config.get_vocab_size()

        # Validate sample logprobs.
        if params.logprobs:
            num_logprobs = params.logprobs
            if num_logprobs == -1:
                num_logprobs = self.model_config.get_vocab_size()
            if num_logprobs > max_logprobs:
                raise ValueError(
                    f"Requested sample logprobs of {num_logprobs}, "
                    f"which is greater than max allowed: {max_logprobs}"
                )

        # Validate prompt logprobs.
        if params.prompt_logprobs:
            num_prompt_logprobs = params.prompt_logprobs
            if num_prompt_logprobs == -1:
                num_prompt_logprobs = self.model_config.get_vocab_size()
            if num_prompt_logprobs > max_logprobs:
                raise ValueError(
                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
                    f"which is greater than max allowed: {max_logprobs}"
                )

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate structured outputs, logit bias and ``allowed_token_ids``.

        Raises:
            ValueError: If ``allowed_token_ids`` is an empty collection, or
                contains ids outside ``[0, len(tokenizer))``; or if a nested
                validator rejects the params.
        """
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
        if self.tokenizer is None:
            # When skip_tokenizer_init=True, we can't validate token IDs
            # Skip validation and let the model handle invalid tokens
            return
        vocab_size = len(self.tokenizer)
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError("allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within the model vocabulary range.

        Raises:
            ValueError: Listing every token id outside ``[0, vocab_size)``.
        """
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        # Collect every offending id so the error reports all of them at once.
        invalid_token_ids = [
            token_id
            for token_id in params.logit_bias
            if token_id < 0 or token_id >= vocab_size
        ]

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}"
            )

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features this engine configuration cannot serve.

        Raises:
            ValueError: For per-request logits processors, or for penalties,
                bad words or structured outputs when speculative decoding is
                combined with async scheduling.
        """
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError(
                "vLLM V1 does not support per request user provided logits processors."
            )
        # Async scheduling + spec decode currently incompatible with some
        # sampling parameters.
        if (
            self.vllm_config.speculative_config is not None
            and self.vllm_config.scheduler_config.async_scheduling
            and (
                params.frequency_penalty != 0.0
                or params.presence_penalty != 0.0
                or params.repetition_penalty != 1.0
                or params.bad_words_token_ids
                or params.structured_outputs
            )
        ):
            raise ValueError(
                "async scheduling with spec decoding doesn't yet support "
                "penalties, bad words or structured outputs in sampling parameters."
            )

    def _validate_params(
        self,
        params: SamplingParams | PoolingParams,
    ):
        """Validate request parameters before any processing happens.

        ``PoolingParams`` are accepted as-is; ``SamplingParams`` go through the
        logprobs, sampling and supported-feature checks.

        Raises:
            ValueError: If the params are unsupported for the API server.
        """

        if isinstance(params, PoolingParams):
            return

        self._validate_logprobs(params)
        self._validate_sampling_params(params)
        self._validate_supported_sampling_params(params)

    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).

        Every modality present in multi_modal_data must also appear in
        multi_modal_uuids with the same number of items. `None` entries inside
        the uuid lists are allowed and will be auto-hashed downstream. Prompts
        that lack either field are not checked.

        Raises:
            ValueError: On a missing modality or an item-count mismatch.
        """

        def _validate_single_prompt(single_prompt: dict | str) -> None:
            if not isinstance(single_prompt, dict):
                return

            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not mm_data or not mm_uuids:
                return

            import torch

            def _get_len(items: object):
                """Number of items an mm input (or uuid entry) represents."""
                if isinstance(items, dict):  # Embedding inputs
                    return _get_len(next(iter(items.values()))) if items else 1

                if isinstance(items, list):
                    return len(items)
                if isinstance(items, torch.Tensor):
                    # To keep backwards compatibility for single item embedding input
                    return 1 if getattr(items, "_is_single_item", False) else len(items)

                return 1

            for modality, items in mm_data.items():
                if modality in mm_uuids:
                    data_len = _get_len(items)
                    uuid_len = _get_len(mm_uuids[modality])
                    if uuid_len != data_len:
                        raise ValueError(
                            f"multi_modal_uuids for modality {modality!r} "
                            "must have same length as data: got "
                            f"{uuid_len} uuids vs {data_len} items."
                        )
                else:
                    raise ValueError(
                        f"multi_modal_uuids for modality {modality!r} must "
                        "be provided if multi_modal_data is provided."
                    )

        # Handle explicit encoder/decoder prompts or singleton prompt
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            enc = prompt.get("encoder_prompt")
            dec = prompt.get("decoder_prompt")
            if enc is not None:
                _validate_single_prompt(cast(dict | str, enc))
            if dec is not None:
                _validate_single_prompt(cast(dict | str, dec))
        else:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]

    def _validate_lora(self, lora_request: LoRARequest | None) -> None:
        """Validate a LoRA request against the engine's LoRA configuration.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled.
        """
        if lora_request is None:
            return

        # LoRA request passed in while LoRA is not enabled
        if not self.lora_config:
            raise ValueError(
                f"Got lora_request {lora_request} but LoRA is not enabled!"
            )

        if self.tokenizer is not None:
            logger.warning_once(
                "vLLM has deprecated support for supporting different "
                "tokenizers for different LoRAs. By default, vLLM uses base "
                "model's tokenizer. If you are using a LoRA "
                "with its own tokenizer, consider specifying `--tokenizer "
                "[lora_path]` to use the LoRA tokenizer."
            )

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate structured-output settings and resolve the backend to use.

        Side effect: writes ``params.structured_outputs._backend`` (and, in
        auto mode, ``_backend_was_auto``) on the passed-in params.

        Raises:
            ValueError: If structured outputs are requested without a
                tokenizer, with a conflicting request-level backend, with an
                empty choice list or grammar, or with a tokenizer the selected
                backend does not support. Backend validators may raise too.
        """
        if not params.structured_outputs or not self.structured_outputs_config:
            return

        # `params.structured_outputs` is known to be truthy past the guard
        # above, so only the tokenizer setting needs checking here.
        if self.model_config.skip_tokenizer_init:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = self.structured_outputs_config.backend
        if _backend := params.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and params.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            params.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(params.structured_outputs.choice, list)
            and not params.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{params.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )
        # Reject empty string grammar early to avoid engine-side crashes
        if (
            isinstance(params.structured_outputs.grammar, str)
            and params.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            if isinstance(self.tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'guidance' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_guidance_grammar(params, tokenizer=None)
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            if isinstance(self.tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_structured_output_request_lm_format_enforcer(params)
        else:
            # Any backend not matched above is handled as "auto". (Presumably
            # the configured value is validated when the config is built;
            # this method does not check it.)
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(params)
                params.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar.
                if isinstance(self.tokenizer, MistralTokenizer):
                    # Fall back to outlines if the tokenizer is Mistral
                    validate_structured_output_request_outlines(params)
                    params.structured_outputs._backend = "outlines"
                else:
                    # Fall back to guidance by default.
                    validate_guidance_grammar(params, tokenizer=None)
                    params.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            params.structured_outputs._backend_was_auto = True

    def _maybe_build_mm_uuids(
        self,
        request_id: str,
        prompt: PromptType,
    ) -> MultiModalUUIDDict | None:
        """Build per-item multimodal hash overrides from the request id.

        Multimodal data items are identified by their request id, modality and
        index (``"{request_id}-{modality}-{i}"``) rather than by their content.
        For explicit encoder/decoder prompts only the encoder prompt's
        multimodal data is considered. Whether these overrides are used at all
        is decided by the caller (see ``process_inputs``).

        Returns a dictionary of modality -> list[str] of overrides, or None if
        no multimodal data is present.
        """

        def _extract_mm_data(p: PromptType):
            if isinstance(p, dict) and "encoder_prompt" in p:
                enc = p.get("encoder_prompt")
                if isinstance(enc, dict):
                    return enc.get("multi_modal_data")
                return None
            if isinstance(p, dict):
                return p.get("multi_modal_data")
            return None

        mm_data = _extract_mm_data(prompt)
        if not mm_data:
            return None

        mm_uuids: dict[str, list[str | None] | str] = {}
        for modality, data in mm_data.items():
            # Hash each item for embedding inputs.
            n = (
                len(data)
                if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
                else 1
            )
            mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
        return mm_uuids

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> EngineCoreRequest:
        """Validate a request and convert it into an ``EngineCoreRequest``.

        Steps: validate LoRA/params and ``data_parallel_rank``; choose
        multimodal hash overrides; preprocess the prompt (tokenization and
        multimodal processing); let the current platform validate the request;
        validate encoder/decoder inputs; clone the params (filling defaults
        for ``SamplingParams``); and collect multimodal features sorted by
        their position in the prompt.

        Args:
            request_id: Request identifier; also used to derive multimodal
                uuids when content hashing is disabled.
            prompt: The user prompt (singleton or encoder/decoder).
            params: Sampling or pooling parameters; cloned, not mutated by the
                default-filling done here.
            arrival_time: Defaults to ``time.time()`` when omitted.
            lora_request: Optional LoRA adapter request.
            tokenization_kwargs: Forwarded to the input preprocessor.
            trace_headers: Forwarded unchanged into the engine request.
            priority: Forwarded unchanged into the engine request.
            data_parallel_rank: Optional target rank; must be in
                ``[0, data_parallel_size)``.

        Raises:
            ValueError: If any validation step rejects the request.
        """
        self._validate_lora(lora_request)
        self._validate_params(params)

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (
            0 <= data_parallel_rank < data_parallel_size
        ):
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {data_parallel_size})."
            )

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        if (
            self.model_config.multimodal_config
            and self.model_config.multimodal_config.mm_processor_cache_gb == 0
            and not self.cache_config.enable_prefix_caching
        ):
            mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_multi_modal_uuids(prompt)
            if isinstance(prompt, dict):
                mm_uuids = cast(
                    MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
                )
            else:
                mm_uuids = None

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        from vllm.platforms import current_platform

        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy can be conservative for TypedDict unions; normalize access.
        if decoder_inputs["type"] == "embeds":
            prompt_token_ids = None
            prompt_embeds = decoder_inputs["prompt_embeds"]
        else:
            prompt_token_ids = decoder_inputs["prompt_token_ids"]
            prompt_embeds = None

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds
                )
                sampling_params.max_tokens = self.model_config.max_model_len - seq_len
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id
            )
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: list[MultiModalFeatureSpec] | None = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = [
                MultiModalFeatureSpec(
                    data=decoder_mm_inputs[modality][idx],
                    modality=modality,
                    identifier=decoder_mm_hashes[modality][idx],
                    mm_position=decoder_mm_positions[modality][idx],
                )
                for modality, idx in sorted_mm_idxs
            ]

        return EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )

    def _validate_model_inputs(
        self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs
    ):
        """Validate the encoder inputs (if any), then the decoder inputs."""
        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs, prompt_type="encoder")

        self._validate_model_input(decoder_inputs, prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate a single processed prompt against model limits.

        Checks that the prompt is non-empty (with exceptions for multimodal
        encoder prompts and embedding prompts), that token ids are in
        vocabulary (only when a tokenizer is available), and that the prompt
        fits within ``max_model_len`` (for text-only, non-pooling decoder
        prompts, leaving room for at least one output token).

        Raises:
            ValueError: If any of the checks above fails.
        """
        model_config = self.model_config

        is_embeds = prompt_inputs["type"] == "embeds"
        prompt_ids = None if is_embeds else prompt_inputs["prompt_token_ids"]
        prompt_embeds = prompt_inputs["prompt_embeds"] if is_embeds else None
        prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            elif is_embeds:
                pass  # Prompt embeds should not have prompt_ids.
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        tokenizer = self.tokenizer
        if tokenizer is not None:
            max_input_id = max(prompt_ids or [], default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
            # model_config.get_vocab_size() is the model’s vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            if max_input_id > max(
                tokenizer.max_token_id, model_config.get_vocab_size() - 1
            ):
                raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = model_config.max_model_len
        if prompt_len > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well."
                )
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens."
                )

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}"
            )

        # A text-only generative decoder prompt that exactly fills the context
        # leaves no room for even one output token.
        if (
            prompt_len == max_prompt_len
            and prompt_type == "decoder"
            and not model_config.is_multimodal_model
            and model_config.runner_type != "pooling"
        ):
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        """Return multimodal processor-cache statistics from the preprocessor."""
        return self.input_preprocessor.stat_mm_cache()

    def clear_mm_cache(self) -> None:
        """Clear the preprocessor's multimodal processor cache."""
        self.input_preprocessor.clear_mm_cache()