# processor.py (24.5 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping
6
from typing import Any, Literal, cast
7

8
from vllm.config import VllmConfig
9
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
10
from vllm.inputs.parse import split_enc_dec_inputs
11
from vllm.inputs.preprocess import InputPreprocessor
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
15
from vllm.multimodal.cache import processor_cache_from_config
16
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
17
from vllm.multimodal.processing import EncDecMultiModalProcessor
18
from vllm.multimodal.utils import argsort_mm_positions
19
20
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
21
from vllm.transformers_utils.tokenizer import AnyTokenizer
22
from vllm.utils import length_from_prompt_token_ids_or_embeds
23
from vllm.v1.engine import EngineCoreRequest
24
from vllm.v1.metrics.stats import MultiModalCacheStats
25
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
26
from vllm.v1.structured_output.backend_lm_format_enforcer import (
27
28
    validate_structured_output_request_lm_format_enforcer,
)
29
from vllm.v1.structured_output.backend_outlines import (
30
31
32
    validate_structured_output_request_outlines,
)
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar
33

34
35
# Module-level logger (used e.g. for the LoRA tokenizer deprecation warning).
logger = init_logger(__name__)

36
37
38
39

class Processor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: AnyTokenizer | None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """Cache config sections and build the input preprocessor.

        Args:
            vllm_config: Engine-wide configuration. Its model, cache, LoRA and
                structured-output sections are stored on ``self``.
            tokenizer: Tokenizer handed to the ``InputPreprocessor``. May be
                ``None`` (e.g. when tokenizer initialization is skipped).
            mm_registry: Multimodal registry used to build the processor
                cache and the preprocessor.
        """
        model_config = vllm_config.model_config

        self.vllm_config = vllm_config
        self.model_config = model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.structured_outputs_config = vllm_config.structured_outputs_config

        # Generation-config defaults, later applied to each request's
        # SamplingParams in process_inputs.
        self.generation_config_fields = model_config.try_get_generation_config()

        self.mm_registry = mm_registry
        mm_cache = processor_cache_from_config(vllm_config, mm_registry)
        self.mm_processor_cache = mm_cache

        self.input_preprocessor = InputPreprocessor(
            model_config,
            tokenizer,
            mm_registry,
            mm_processor_cache=mm_cache,
        )

    @property
    def tokenizer(self) -> AnyTokenizer | None:
        """Tokenizer stored on the underlying ``InputPreprocessor``."""
        preprocessor = self.input_preprocessor
        return preprocessor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
        # Storage is delegated, so this processor and its preprocessor
        # always observe the same tokenizer.
        preprocessor = self.input_preprocessor
        preprocessor.tokenizer = tokenizer

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sample/prompt logprob requests above the configured limit.

        A value of ``-1``, either for the configured ``max_logprobs`` or for
        the requested count, stands for the full vocabulary size. Falsy
        requests (``None`` or ``0``) are not checked.

        Raises:
            ValueError: if ``params.logprobs`` or ``params.prompt_logprobs``
                exceeds the allowed maximum.
        """
        limit = self.model_config.max_logprobs
        if limit == -1:
            limit = self.model_config.get_vocab_size()

        # Sample logprobs are checked before prompt logprobs.
        for kind, requested in (
            ("sample", params.logprobs),
            ("prompt", params.prompt_logprobs),
        ):
            if not requested:
                continue
            if requested == -1:
                requested = self.model_config.get_vocab_size()
            if requested > limit:
                raise ValueError(
                    f"Requested {kind} logprobs of {requested}, "
                    f"which is greater than max allowed: {limit}"
                )

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate structured outputs, logit bias and ``allowed_token_ids``.

        ``allowed_token_ids`` must be either ``None`` or a non-empty
        collection of ids in ``[0, vocab_size)``. ``vocab_size`` is the model
        vocabulary size, further capped by the tokenizer size when a
        tokenizer is loaded.

        Raises:
            ValueError: if structured-output or logit-bias validation fails,
                or ``allowed_token_ids`` is empty or contains an
                out-of-vocab id.
        """
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        allowed = params.allowed_token_ids
        if allowed is None:
            return
        if not allowed:
            raise ValueError("allowed_token_ids is not None and empty!")

        # allowed_token_ids restrict which of the model's output tokens may
        # be sampled. Every id must therefore lie inside the model
        # vocabulary, the same bound _validate_logit_bias uses. This bound
        # does not depend on a tokenizer, so ids are still validated when
        # skip_tokenizer_init=True. When a tokenizer is loaded, ids must also
        # be known to it (the previous, tokenizer-only bound).
        vocab_size = self.model_config.get_vocab_size()
        tokenizer = self.tokenizer
        if tokenizer is not None:
            vocab_size = min(vocab_size, len(tokenizer))
        if not all(0 <= tid < vocab_size for tid in allowed):
            raise ValueError("allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range.

        Every key of ``params.logit_bias`` must lie in ``[0, vocab_size)``,
        where ``vocab_size`` is the model's vocabulary size.

        Raises:
            ValueError: listing all offending token ids.
        """
        bias = params.logit_bias
        if not bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        out_of_range = [tid for tid in bias if not 0 <= tid < vocab_size]
        if not out_of_range:
            return

        raise ValueError(
            f"token_id(s) {out_of_range} in logit_bias contain "
            f"out-of-vocab token ids. Vocabulary size: {vocab_size}"
        )

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject SamplingParams features that vLLM V1 does not implement.

        Raises:
            ValueError: if ``best_of`` is greater than 1, or per-request
                ``logits_processors`` are supplied.
        """
        best_of = params.best_of
        # best_of is not yet supported in V1.
        if best_of is not None and best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")

        # Per-request, user-provided logits processors are not supported.
        if not params.logits_processors:
            return
        raise ValueError(
            "vLLM V1 does not support per request user provided logits processors."
        )

    def _validate_params(
        self,
        params: SamplingParams | PoolingParams,
    ) -> None:
        """Validate request parameters for the API server.

        ``PoolingParams`` are accepted as-is. ``SamplingParams`` go through
        the logprob, sampling and supported-feature checks, in that order.

        Raises:
            ValueError: if a SamplingParams setting is invalid or unsupported.
        """
        if not isinstance(params, PoolingParams):
            self._validate_logprobs(params)
            self._validate_sampling_params(params)
            self._validate_supported_sampling_params(params)

    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """Check that user-provided multi_modal_uuids match multi_modal_data.

        A prompt dict is checked only when it carries both
        ``multi_modal_data`` and ``multi_modal_uuids``. In that case, every
        modality in the data must also appear in the uuids, with the same
        number of entries. A non-list value counts as one entry. Only lengths
        are compared: ``None`` uuid entries are allowed and are auto-hashed
        downstream. For explicit encoder/decoder prompts, the encoder and
        decoder sub-prompts are checked independently.

        Raises:
            ValueError: if a modality's uuids are missing, or their count
                differs from the number of data items.
        """

        def _count(value) -> int:
            return len(value) if isinstance(value, list) else 1

        def _check(single_prompt: dict | str) -> None:
            if not isinstance(single_prompt, dict):
                return
            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not (mm_data and mm_uuids):
                return

            for modality, items in mm_data.items():
                if modality not in mm_uuids:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' must "
                        "be provided if multi_modal_data is provided."
                    )
                data_len = _count(items)
                uuid_len = _count(mm_uuids[modality])
                if uuid_len != data_len:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' "
                        "must have same length as data: got "
                        f"{uuid_len} uuids vs "
                        f"{data_len} items."
                    )

        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            # Explicit encoder/decoder prompt: validate each side that exists.
            for key in ("encoder_prompt", "decoder_prompt"):
                sub_prompt = prompt.get(key)
                if sub_prompt is not None:
                    _check(cast(dict | str, sub_prompt))
        else:
            _check(prompt)  # type: ignore[arg-type]

    def _validate_lora(self, lora_request: LoRARequest | None) -> None:
        """Reject LoRA requests when LoRA is disabled, and warn about tokenizers.

        Raises:
            ValueError: if ``lora_request`` is given but no LoRA config is set.
        """
        if lora_request is None:
            return

        if not self.lora_config:
            # A LoRA request was sent to an engine without LoRA enabled.
            raise ValueError(
                f"Got lora_request {lora_request} but LoRA is not enabled!"
            )

        if self.tokenizer is None:
            return
        # Per-LoRA tokenizers are deprecated; logged via warning_once.
        logger.warning_once(
            "vLLM has deprecated support for supporting different "
            "tokenizers for different LoRAs. By default, vLLM uses base "
            "model's tokenizer. If you are using a LoRA "
            "with its own tokenizer, consider specifying `--tokenizer "
            "[lora_path]` to use the LoRA tokenizer."
        )

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate a request's structured-output options and pin its backend.

        ``params.structured_outputs`` is updated in place. If the request has
        not chosen a backend, ``_backend`` is first set to the engine-level
        backend. When that backend is "auto", ``_backend`` is then set to
        whichever backend accepted the request, and ``_backend_was_auto`` is
        set. This lets a reused ``params`` object pass the request-level
        backend check on a later request.

        Raises:
            ValueError: if structured outputs are requested while tokenizer
                initialization is skipped, the request names a backend other
                than the engine's, ``choice`` is an empty list, ``grammar``
                is blank, or a backend validator rejects the request.
        """
        if not params.structured_outputs or not self.structured_outputs_config:
            return

        # `params.structured_outputs` is known to be set at this point, so
        # only the tokenizer setting needs checking.
        if self.model_config.skip_tokenizer_init:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = self.structured_outputs_config.backend
        if _backend := params.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and params.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            params.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(params.structured_outputs.choice, list)
            and not params.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{params.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )

        # Reject empty string grammar early to avoid engine-side crashes
        if (
            isinstance(params.structured_outputs.grammar, str)
            and params.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            validate_structured_output_request_lm_format_enforcer(params)
        else:
            # Every backend name not matched above is handled as "auto".
            # This method does not itself check that `backend` is a known
            # name. NOTE(review): presumably the structured-outputs config
            # restricts it to the supported values; confirm.
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(params)
                params.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            params.structured_outputs._backend_was_auto = True

    def _maybe_build_mm_uuids(
        self,
        request_id: str,
        prompt: PromptType,
    ) -> MultiModalUUIDDict | None:
        """Build request-scoped multimodal hash overrides for ``prompt``.

        Each multimodal item is identified by
        ``"{request_id}-{modality}-{index}"`` instead of a hash of its
        content. The caller decides whether this mode is enabled; this
        method only inspects the prompt. For explicit encoder/decoder
        prompts, only the encoder prompt's ``multi_modal_data`` is used.

        Returns:
            A mapping of modality to a list of override strings (one per
            item; a non-list value counts as one item), or ``None`` when the
            prompt carries no multimodal data.
        """
        source: object = prompt
        if isinstance(source, dict) and "encoder_prompt" in source:
            source = source.get("encoder_prompt")
        mm_data = source.get("multi_modal_data") if isinstance(source, dict) else None
        if not mm_data:
            return None

        def _n_items(items) -> int:
            return len(items) if isinstance(items, list) else 1

        return {
            modality: [
                f"{request_id}-{modality}-{idx}" for idx in range(_n_items(items))
            ]
            for modality, items in mm_data.items()
        }

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> EngineCoreRequest:
        """Validate and preprocess one request into an ``EngineCoreRequest``.

        Args:
            request_id: Request identifier. Also used to build per-request
                multimodal uuids when both the multimodal processor cache and
                prefix caching are disabled.
            prompt: The user prompt (singleton or explicit encoder/decoder).
            params: Sampling or pooling parameters. They are validated as
                given, then cloned before defaults (``max_tokens``,
                generation config, tokenizer settings) are applied. Note
                that ``_validate_structured_output`` may still annotate
                ``params.structured_outputs`` in place.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            lora_request: Optional LoRA adapter for this request.
            tokenization_kwargs: Extra kwargs forwarded to the preprocessor.
            trace_headers: Forwarded unchanged to the ``EngineCoreRequest``.
            priority: Forwarded unchanged to the ``EngineCoreRequest``.
            data_parallel_rank: Optional target DP rank. Must lie in
                ``[0, data_parallel_size)``.

        Returns:
            The ``EngineCoreRequest`` built from the processed inputs.

        Raises:
            ValueError: if the LoRA request, params, DP rank, multimodal
                uuids or processed model inputs fail validation. Errors from
                the preprocessor or the platform check propagate unchanged.
        """
        self._validate_lora(lora_request)
        self._validate_params(params)

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (
            0 <= data_parallel_rank < data_parallel_size
        ):
            raise ValueError(
                f"data_parallel_rank {data_parallel_rank} "
                f"is out of range [0, {data_parallel_size})."
            )

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        if (
            self.model_config.multimodal_config
            and self.model_config.multimodal_config.mm_processor_cache_gb == 0
            and not self.cache_config.enable_prefix_caching
        ):
            mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_multi_modal_uuids(prompt)
            mm_uuids = None
            if isinstance(prompt, dict):
                # Explicit encoder/decoder prompts keep their multimodal data,
                # and therefore their uuids, on the encoder prompt. That is
                # also where _maybe_build_mm_uuids reads the data and where
                # _validate_multi_modal_uuids checks the uuids. Reading only
                # the top-level dict would silently drop those uuids.
                uuid_source: object = (
                    prompt.get("encoder_prompt")
                    if "encoder_prompt" in prompt
                    else prompt
                )
                if isinstance(uuid_source, dict):
                    mm_uuids = cast(
                        MultiModalUUIDDict | None,
                        uuid_source.get("multi_modal_uuids"),
                    )

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        # Imported locally. NOTE(review): presumably to avoid an import cycle
        # or import-time platform detection; confirm before hoisting.
        from vllm.platforms import current_platform

        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy can be conservative for TypedDict unions; normalize access.
        if decoder_inputs["type"] == "embeds":
            prompt_token_ids = None
            prompt_embeds = decoder_inputs["prompt_embeds"]
        else:
            prompt_token_ids = decoder_inputs["prompt_token_ids"]
            prompt_embeds = None

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds
                )
                sampling_params.max_tokens = self.model_config.max_model_len - seq_len
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id
            )
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: list[MultiModalFeatureSpec] | None = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = [
                MultiModalFeatureSpec(
                    data=decoder_mm_inputs[modality][idx],
                    modality=modality,
                    identifier=decoder_mm_hashes[modality][idx],
                    mm_position=decoder_mm_positions[modality][idx],
                )
                for modality, idx in sorted_mm_idxs
            ]

        return EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )

    def _validate_model_inputs(
        self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs
    ) -> None:
        """Validate the processed encoder inputs (if any), then the decoder's.

        Raises:
            ValueError: propagated from ``_validate_model_input``.
        """
        if encoder_inputs is None:
            # No encoder side: only the decoder inputs are checked.
            self._validate_model_input(decoder_inputs, prompt_type="decoder")
            return
        self._validate_model_input(encoder_inputs, prompt_type="encoder")
        self._validate_model_input(decoder_inputs, prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        *,
        prompt_type: Literal["encoder", "decoder"],
    ) -> None:
        """Validate one processed (encoder or decoder) input.

        The checks run in this order:
        1. The prompt is non-empty. Multimodal encoder inputs and
           prompt-embeds inputs are exempt.
        2. When a tokenizer is loaded, token ids are within the vocabulary.
        3. The prompt fits in ``max_model_len``.
        4. A text-only, non-pooling decoder prompt leaves room for at least
           one output token.

        Raises:
            ValueError: if any of the checks fails.
        """
        model_config = self.model_config
        is_mm = model_config.is_multimodal_model
        is_embeds = prompt_inputs["type"] == "embeds"

        if is_embeds:
            prompt_ids = None
            prompt_embeds = prompt_inputs["prompt_embeds"]
        else:
            prompt_ids = prompt_inputs["prompt_token_ids"]
            prompt_embeds = None
        prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)

        # Empty token ids are tolerated for multimodal encoder inputs (Mllama
        # may have empty encoder inputs for text-only data) and for
        # prompt-embeds inputs, which carry no token ids by design.
        empty_allowed = (prompt_type == "encoder" and is_mm) or is_embeds
        if not prompt_ids and not empty_allowed:
            raise ValueError(f"The {prompt_type} prompt cannot be empty")

        tokenizer = self.tokenizer
        if tokenizer is not None:
            max_input_id = max(prompt_ids or [], default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
            # self.model_config.get_vocab_size() is the model’s vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            vocab_limit = max(
                tokenizer.max_token_id, model_config.get_vocab_size() - 1
            )
            if max_input_id > vocab_limit:
                raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = model_config.max_model_len
        if prompt_len > max_prompt_len:
            if prompt_type == "encoder" and is_mm:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)
                if mm_processor.pad_dummy_encoder_prompt:
                    # Skip encoder length check for Whisper
                    return

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
            if is_mm:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well."
                )
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens."
                )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}"
            )

        # A text-only generative decoder prompt that exactly fills the
        # context leaves no room for even one output token.
        fills_context = prompt_len == max_prompt_len
        if (
            fills_context
            and prompt_type == "decoder"
            and not is_mm
            and model_config.runner_type != "pooling"
        ):
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens (prompt + requested output tokens)."
            )
            raise ValueError(
                f"The {prompt_type} prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than the maximum "
                f"model length of {max_prompt_len}. {suggestion}"
            )

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        """Return the input preprocessor's multimodal cache statistics."""
        preprocessor = self.input_preprocessor
        return preprocessor.stat_mm_cache()

    def clear_mm_cache(self) -> None:
        """Clear the input preprocessor's multimodal cache."""
        preprocessor = self.input_preprocessor
        preprocessor.clear_mm_cache()