processor.py 26.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping
6
from typing import Any, Literal, cast
7

8
from vllm.config import VllmConfig
9
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
10
from vllm.inputs.parse import split_enc_dec_inputs
11
from vllm.inputs.preprocess import InputPreprocessor
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
15
from vllm.multimodal.cache import processor_cache_from_config
16
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
17
from vllm.multimodal.parse import MultiModalDataParser
18
from vllm.multimodal.processing import EncDecMultiModalProcessor
19
from vllm.multimodal.utils import argsort_mm_positions
20
21
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
22
from vllm.transformers_utils.tokenizer import AnyTokenizer
23
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
24
from vllm.utils import length_from_prompt_token_ids_or_embeds
25
from vllm.v1.engine import EngineCoreRequest
26
from vllm.v1.metrics.stats import MultiModalCacheStats
27
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
28
from vllm.v1.structured_output.backend_lm_format_enforcer import (
29
30
    validate_structured_output_request_lm_format_enforcer,
)
31
from vllm.v1.structured_output.backend_outlines import (
32
33
34
    validate_structured_output_request_outlines,
)
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
35

36
37
logger = init_logger(__name__)

38
39
40
41

class Processor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: AnyTokenizer | None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """Cache the config sections used for validation and build the
        input preprocessor.

        Args:
            vllm_config: Engine-wide configuration. Its model, cache, LoRA and
                structured-output sections are stored as attributes.
            tokenizer: Tokenizer handed to the input preprocessor; may be None
                (e.g. when tokenizer initialization is skipped).
            mm_registry: Multimodal registry used both for the processor cache
                and for the input preprocessor.
        """
        self.vllm_config = vllm_config
        model_config = vllm_config.model_config
        self.model_config = model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.structured_outputs_config = vllm_config.structured_outputs_config

        self.generation_config_fields = model_config.try_get_generation_config()

        # The processor cache is stored here and also handed to the
        # preprocessor, so it must be created first.
        self.mm_registry = mm_registry
        mm_cache = processor_cache_from_config(vllm_config, mm_registry)
        self.mm_processor_cache = mm_cache

        self.input_preprocessor = InputPreprocessor(
            model_config,
            tokenizer,
            mm_registry,
            mm_processor_cache=mm_cache,
        )
63

64
    @property
    def tokenizer(self) -> AnyTokenizer | None:
        """The tokenizer held by the input preprocessor (None if absent)."""
        preprocessor = self.input_preprocessor
        return preprocessor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
        """Replace the tokenizer on the underlying input preprocessor."""
        preprocessor = self.input_preprocessor
        preprocessor.tokenizer = tokenizer

72
73
    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject requests asking for more logprobs than the model allows.

        Both ``params.logprobs`` (sampled tokens) and ``params.prompt_logprobs``
        are compared against ``model_config.max_logprobs``. A value of -1, on
        either the request or the config side, stands for the full vocabulary
        size as reported by ``model_config.get_vocab_size()``.

        Raises:
            ValueError: if a requested count exceeds the allowed maximum.
        """

        def _expand(count: int) -> int:
            # -1 is shorthand for "the whole vocabulary".
            return self.model_config.get_vocab_size() if count == -1 else count

        limit = _expand(self.model_config.max_logprobs)

        for label, requested in (
            ("sample", params.logprobs),
            ("prompt", params.prompt_logprobs),
        ):
            # Falsy values (None / 0) are not validated.
            if not requested:
                continue
            wanted = _expand(requested)
            if wanted > limit:
                raise ValueError(
                    f"Requested {label} logprobs of {wanted}, "
                    f"which is greater than max allowed: {limit}"
                )
101

102
    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate structured outputs, logit bias and ``allowed_token_ids``.

        ``allowed_token_ids`` must be either None or a non-empty collection of
        ids in ``[0, vocab_size)``. ``vocab_size`` is ``len(tokenizer)`` when a
        tokenizer is available, otherwise the model's vocabulary size.

        Raises:
            ValueError: if ``allowed_token_ids`` is empty or contains an
                out-of-range id, or if a delegated validator rejects the
                request.
        """
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        allowed_token_ids = params.allowed_token_ids
        if allowed_token_ids is None:
            return
        if not allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")

        if self.tokenizer is not None:
            vocab_size = len(self.tokenizer)
        else:
            # With skip_tokenizer_init=True there is no tokenizer, but the
            # model's vocabulary size is still known (it is the bound that
            # _validate_logit_bias uses), so check against it instead of
            # forwarding unchecked ids to the engine.
            vocab_size = self.model_config.get_vocab_size()
        if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
            raise ValueError("allowed_token_ids contains out-of-vocab token id!")
120

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Ensure every ``logit_bias`` token id lies in the model vocabulary.

        The valid range is ``[0, model_config.get_vocab_size())``. All offending
        ids are collected (in iteration order) and reported in one error.

        Raises:
            ValueError: if any token id is out of range.
        """
        bias = params.logit_bias
        if not bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        bad_ids = [tid for tid in bias if not 0 <= tid < vocab_size]
        if not bad_ids:
            return
        raise ValueError(
            f"token_id(s) {bad_ids} in logit_bias contain "
            f"out-of-vocab token ids. Vocabulary size: {vocab_size}"
        )
141

142
143
144
145
146
147
    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features the current engine setup cannot serve.

        Raises:
            ValueError: if per-request logits processors are supplied, or if
                speculative decoding runs with async scheduling while the
                request uses penalties, bad words or structured outputs.
        """
        if params.logits_processors:
            raise ValueError(
                "vLLM V1 does not support per request user provided logits processors."
            )

        config = self.vllm_config
        spec_decode_with_async = (
            config.speculative_config is not None
            and config.scheduler_config.async_scheduling
        )
        if not spec_decode_with_async:
            return

        # Async scheduling + spec decode is currently incompatible with these
        # sampling features.
        uses_unsupported_feature = (
            params.frequency_penalty != 0.0
            or params.presence_penalty != 0.0
            or params.repetition_penalty != 1.0
            or params.bad_words_token_ids
            or params.structured_outputs
        )
        if uses_unsupported_feature:
            raise ValueError(
                "async scheduling with spec decoding doesn't yet support "
                "penalties, bad words or structured outputs in sampling parameters."
            )
168
169
170

    def _validate_params(
        self,
        params: SamplingParams | PoolingParams,
    ):
        """
        Validate request parameters; PoolingParams are accepted unchecked.
        Raises ValueError for SamplingParams the API server should reject.
        """
        if not isinstance(params, PoolingParams):
            # Checks run in a fixed order: logprob limits, general sampling
            # parameters, then engine feature support.
            self._validate_logprobs(params)
            self._validate_sampling_params(params)
            self._validate_supported_sampling_params(params)

185
186
187
188
    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).

        Only lengths are checked: every modality present in
        ``multi_modal_data`` needs an entry in ``multi_modal_uuids`` with the
        same number of items. `None` entries inside a uuid list are allowed
        and will be auto-hashed downstream.

        Data items are counted the same way as in ``_maybe_build_mm_uuids``:
        a list, or an input recognised by ``MultiModalDataParser.is_embeddings``,
        contributes ``len(data)`` items; anything else counts as one item.

        For an explicit encoder/decoder prompt both sub-prompts are checked;
        non-dict prompts are ignored.

        Raises:
            ValueError: if a modality has data but no uuids, or the uuid count
                differs from the item count.
        """

        def _num_data_items(items: Any) -> int:
            # Batched embedding inputs hold one item per leading entry, so
            # they must be counted like lists, not as a single item.
            if isinstance(items, list) or MultiModalDataParser.is_embeddings(items):
                return len(items)
            return 1

        def _validate_single_prompt(single_prompt: dict | str) -> None:
            if not isinstance(single_prompt, dict):
                return
            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not mm_data or not mm_uuids:
                return

            for modality, items in mm_data.items():
                if modality not in mm_uuids:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' must "
                        "be provided if multi_modal_data is provided."
                    )
                data_len = _num_data_items(items)
                uuids = mm_uuids[modality]
                uuid_len = len(uuids) if isinstance(uuids, list) else 1
                if uuid_len != data_len:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' "
                        "must have same length as data: got "
                        f"{uuid_len} uuids vs "
                        f"{data_len} items."
                    )

        # Handle explicit encoder/decoder prompts or singleton prompt
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            enc = prompt.get("encoder_prompt")
            dec = prompt.get("decoder_prompt")
            if enc is not None:
                _validate_single_prompt(cast(dict | str, enc))
            if dec is not None:
                _validate_single_prompt(cast(dict | str, dec))
        else:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]

233
    def _validate_lora(self, lora_request: LoRARequest | None) -> None:
        """Check that a LoRA request (if any) can be served.

        When a tokenizer is present, also emit a one-time deprecation warning
        about per-LoRA tokenizers.

        Raises:
            ValueError: if ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None:
            if not self.lora_config:
                # A LoRA adapter was requested but no LoRA config exists.
                raise ValueError(
                    f"Got lora_request {lora_request} but LoRA is not enabled!"
                )
            if self.tokenizer is not None:
                logger.warning_once(
                    "vLLM has deprecated support for supporting different "
                    "tokenizers for different LoRAs. By default, vLLM uses base "
                    "model's tokenizer. If you are using a LoRA "
                    "with its own tokenizer, consider specifying `--tokenizer "
                    "[lora_path]` to use the LoRA tokenizer."
                )
251

252
    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Resolve and validate the structured-output backend for a request.

        Does nothing when the request has no ``structured_outputs`` or the
        engine has no ``structured_outputs_config``. Otherwise:

        1. rejects the request if ``skip_tokenizer_init`` is set;
        2. rejects a request-level ``_backend`` that differs from the engine
           backend (unless it was recorded by a previous ``auto`` selection),
           or records the engine backend on the request;
        3. rejects an empty ``choice`` list or a blank ``grammar`` string;
        4. runs the backend-specific validator. Any backend name not matched
           explicitly is handled as ``auto``: xgrammar is tried first, with a
           fallback to outlines (Mistral tokenizer) or guidance, and the chosen
           backend plus ``_backend_was_auto=True`` are written back.

        Note: ``params.structured_outputs`` is mutated in place.

        Raises:
            ValueError: for any rejection above, or when the selected backend's
                validator rejects the request.
        """
        if not params.structured_outputs or not self.structured_outputs_config:
            return

        # `params.structured_outputs` is known to be set here, so only the
        # tokenizer setting needs checking.
        if self.model_config.skip_tokenizer_init:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = self.structured_outputs_config.backend
        if _backend := params.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and params.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            params.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(params.structured_outputs.choice, list)
            and not params.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{params.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )

        # Reject empty string grammar early to avoid engine-side crashes
        if (
            isinstance(params.structured_outputs.grammar, str)
            and params.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            if isinstance(self.tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'guidance' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_guidance_grammar(params, tokenizer=None)
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            if isinstance(self.tokenizer, MistralTokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_structured_output_request_lm_format_enforcer(params)
        else:
            # NOTE: every backend value not matched above is handled as
            # "auto". This method does not itself verify that `backend` is a
            # recognised name; presumably the config layer restricts it —
            # TODO(review): confirm.
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(params)
                params.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar.
                if isinstance(self.tokenizer, MistralTokenizer):
                    # Fall back to outlines if the tokenizer is Mistral
                    validate_structured_output_request_outlines(params)
                    params.structured_outputs._backend = "outlines"
                else:
                    # Fall back to guidance by default.
                    validate_guidance_grammar(params, tokenizer=None)
                    params.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            params.structured_outputs._backend_was_auto = True
347

348
    def _maybe_build_mm_uuids(
        self,
        request_id: str,
        prompt: PromptType,
    ) -> MultiModalUUIDDict | None:
        """Derive per-item multimodal hash overrides from the request id.

        Each multimodal item is identified as ``"{request_id}-{modality}-{i}"``
        rather than by its content. For an explicit encoder/decoder prompt only
        the encoder prompt's ``multi_modal_data`` is used. Whether overrides
        should be used at all is decided by the caller.

        Returns a dict mapping modality -> list of override strings, or None
        when the prompt carries no multimodal data.
        """
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            source = prompt.get("encoder_prompt")
        else:
            source = prompt
        mm_data = source.get("multi_modal_data") if isinstance(source, dict) else None
        if not mm_data:
            return None

        mm_uuids: dict[str, list[str | None] | str] = {}
        for modality, data in mm_data.items():
            # Lists and embedding inputs get one identifier per entry;
            # anything else is a single item.
            if isinstance(data, list) or MultiModalDataParser.is_embeddings(data):
                count = len(data)
            else:
                count = 1
            mm_uuids[modality] = [
                f"{request_id}-{modality}-{idx}" for idx in range(count)
            ]
        return mm_uuids
385

386
387
388
389
    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> EngineCoreRequest:
        """Validate a request and convert it into an ``EngineCoreRequest``.

        Steps, in order:
        1. validate the LoRA request and the sampling/pooling params;
        2. check ``data_parallel_rank`` against the data-parallel size;
        3. default ``arrival_time`` to the current time;
        4. choose multimodal hash overrides;
        5. preprocess the prompt (tokenization / multimodal processing);
        6. run platform and model-input validation;
        7. clone the params. For sampling, ``max_tokens`` is filled in when it
           is unset, and generation-config and tokenizer defaults are applied;
        8. gather multimodal features sorted by position.

        Raises:
            ValueError: when any validation step rejects the request.
        """
        self._validate_lora(lora_request)
        self._validate_params(params)

        dp_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank < dp_size):
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {dp_size})."
            )

        if arrival_time is None:
            arrival_time = time.time()

        mm_uuids = self._resolve_mm_uuids(request_id, prompt)

        # Tokenize the text prompt and, for multimodal models with a merged
        # preprocessor, process multimodal data and expand the prompt token
        # ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # Imported inside the method, as in the original layout.
        from vllm.platforms import current_platform

        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy can be conservative for TypedDict unions; normalize access.
        if decoder_inputs["type"] != "embeds":
            prompt_token_ids = decoder_inputs["prompt_token_ids"]
            prompt_embeds = None
        else:
            prompt_token_ids = None
            prompt_embeds = decoder_inputs["prompt_embeds"]

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            if sampling_params.max_tokens is None:
                # No explicit limit: allow generation up to max_model_len.
                prompt_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds
                )
                sampling_params.max_tokens = (
                    self.model_config.max_model_len - prompt_len
                )
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id
            )
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        mm_features = self._collect_mm_features(decoder_inputs)

        return EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )

    def _resolve_mm_uuids(
        self, request_id: str, prompt: PromptType
    ) -> MultiModalUUIDDict | None:
        """Choose the multimodal hash overrides passed to the preprocessor.

        When users explicitly turn off BOTH prefix caching and input
        processing caching, no multimodal features or embeddings will be
        reused across requests. Identifying multimodal data items by their
        content is then unnecessary, so uuids of the form
        request id-modality-index are generated as overrides. Otherwise the
        user-provided ``multi_modal_uuids`` are validated and used, if present.
        """
        mm_config = self.model_config.multimodal_config
        content_hashing_unneeded = (
            mm_config
            and mm_config.mm_processor_cache_gb == 0
            and not self.cache_config.enable_prefix_caching
        )
        if content_hashing_unneeded:
            return self._maybe_build_mm_uuids(request_id, prompt)

        self._validate_multi_modal_uuids(prompt)
        if not isinstance(prompt, dict):
            return None
        return cast(MultiModalUUIDDict | None, prompt.get("multi_modal_uuids"))

    @staticmethod
    def _collect_mm_features(
        decoder_inputs: SingletonInputs,
    ) -> list[MultiModalFeatureSpec] | None:
        """Flatten the decoder's multimodal inputs into feature specs.

        The per-modality dicts of kwargs, hashes and placeholders are merged
        into one list, ordered by each item's position in the input sequence.
        Returns None when the decoder inputs are not multimodal.
        """
        if decoder_inputs["type"] != "multimodal":
            return None

        kwargs_by_modality = decoder_inputs["mm_kwargs"]
        positions = decoder_inputs["mm_placeholders"]
        hashes = decoder_inputs["mm_hashes"]
        return [
            MultiModalFeatureSpec(
                data=kwargs_by_modality[modality][idx],
                modality=modality,
                identifier=hashes[modality][idx],
                mm_position=positions[modality][idx],
            )
            for modality, idx in argsort_mm_positions(positions)
        ]
526

527
    def _validate_model_inputs(
        self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs
    ):
        """Validate the encoder inputs (when present), then the decoder inputs."""
        checks: list[tuple[Literal["encoder", "decoder"], SingletonInputs]] = [
            ("decoder", decoder_inputs)
        ]
        if encoder_inputs is not None:
            checks.insert(0, ("encoder", encoder_inputs))
        for prompt_type, inputs in checks:
            self._validate_model_input(inputs, prompt_type=prompt_type)
534

535
536
537
538
539
540
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate one processed (encoder or decoder) prompt.

        Checks, in order:

        * The prompt is not empty. A multimodal model's encoder prompt may be
          empty (e.g. Mllama with text-only data). Prompt-embeds inputs carry
          no token ids, so their emptiness is judged by their length.
        * When a tokenizer is available, no token id exceeds
          ``max(tokenizer.max_token_id, model vocab size - 1)``.
        * The prompt length does not exceed ``max_model_len``. This is skipped
          for encoder prompts whose processor pads a dummy encoder prompt.
        * A non-multimodal, non-pooling decoder prompt is strictly shorter
          than ``max_model_len``, leaving room for at least one output token.

        Raises:
            ValueError: if any check fails.
        """
        model_config = self.model_config

        prompt_ids = (
            None
            if prompt_inputs["type"] == "embeds"
            else prompt_inputs["prompt_token_ids"]
        )
        prompt_embeds = (
            prompt_inputs["prompt_embeds"]
            if prompt_inputs["type"] == "embeds"
            else None
        )
        prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)

        # Token prompts are empty when they have no ids. Prompt-embeds inputs
        # never carry ids, so they count as empty only when their length is 0.
        is_empty = not prompt_ids and not (
            prompt_inputs["type"] == "embeds" and prompt_len > 0
        )
        # Mllama may have empty encoder inputs for text-only data, hence the
        # exemption for multimodal encoder prompts.
        if is_empty and not (
            prompt_type == "encoder" and model_config.is_multimodal_model
        ):
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

        tokenizer = self.tokenizer
        if tokenizer is not None:
            max_input_id = max(prompt_ids or [], default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
            # self.model_config.get_vocab_size() is the model’s vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            if max_input_id > max(
                tokenizer.max_token_id, self.model_config.get_vocab_size() - 1
            ):
                raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = self.model_config.max_model_len
        if prompt_len > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well."
                )
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens."
                )

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}"
            )

        # A decoder prompt that fills max_model_len leaves no room for even
        # one generated token.
        if (
            prompt_len == max_prompt_len
            and prompt_type == "decoder"
            and not model_config.is_multimodal_model
            and self.model_config.runner_type != "pooling"
        ):
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )

633
    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        """Return the input preprocessor's multimodal cache statistics."""
        preprocessor = self.input_preprocessor
        return preprocessor.stat_mm_cache()

    def clear_mm_cache(self) -> None:
        """Clear the input preprocessor's multimodal cache."""
        preprocessor = self.input_preprocessor
        preprocessor.clear_mm_cache()