processor.py 22.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping
6
from typing import Any, Literal, Optional, Union
7

8
from vllm.config import VllmConfig
9
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
10
from vllm.inputs.parse import split_enc_dec_inputs
11
12
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
13
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
14
from vllm.multimodal.cache import processor_cache_from_config
15
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
16
from vllm.multimodal.processing import EncDecMultiModalProcessor
17
from vllm.multimodal.utils import argsort_mm_positions
18
19
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
20
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
21
from vllm.v1.engine import EngineCoreRequest
22
23
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
24
25
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)
26
27
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
28
29
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)
30
31
32
33
34
35


class Processor:
    """Validate user requests and turn them into ``EngineCoreRequest``s.

    The processor checks LoRA, sampling/pooling and structured-output
    parameters. It then runs the input preprocessor (tokenization and
    multimodal processing) and packs the result into an ``EngineCoreRequest``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Cache the config sections used here and build the preprocessor.

        Args:
            vllm_config: Full engine configuration.
            tokenizer: Tokenizer group used for validation. Other methods
                treat it as possibly ``None`` (e.g. when tokenizer init is
                skipped).
            mm_registry: Registry used to create multimodal processors.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config

        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.mm_registry = mm_registry
        self.mm_processor_cache = processor_cache_from_config(
            vllm_config, mm_registry)

        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            self.tokenizer,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Check requested sample/prompt logprobs against the configured max.

        A value of ``-1`` (for both the limit and the request) stands for
        the full vocabulary size.

        Raises:
            ValueError: If either requested count exceeds the maximum.
        """
        max_logprobs = self.model_config.max_logprobs
        if max_logprobs == -1:
            max_logprobs = self.model_config.get_vocab_size()

        # Sample logprobs are validated before prompt logprobs, so the first
        # violation reported matches the historical behavior.
        for kind, requested in (("sample", params.logprobs),
                                ("prompt", params.prompt_logprobs)):
            # None / 0 means "not requested".
            if not requested:
                continue
            if requested == -1:
                requested = self.model_config.get_vocab_size()
            if requested > max_logprobs:
                raise ValueError(
                    f"Requested {kind} logprobs of {requested}, "
                    f"which is greater than max allowed: {max_logprobs}")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Validate structured output, logit bias and allowed token ids.

        Raises:
            ValueError: On an empty or out-of-vocab ``allowed_token_ids``,
                or on errors raised by the structured-output / logit-bias
                checks.
        """
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
        if self.tokenizer is None:
            # When skip_tokenizer_init=True, we can't validate token IDs
            # Skip validation and let the model handle invalid tokens
            return

        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        vocab_size = len(tokenizer)
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = [
            token_id for token_id in params.logit_bias
            if not 0 <= token_id < vocab_size
        ]
        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features that this engine does not implement.

        Raises:
            ValueError: If ``best_of > 1`` or per-request logits processors
                are supplied.
        """
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError("vLLM V1 does not support per request "
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """
        Validate supported SamplingParam.
        Should raise ValueError if unsupported for API Server.

        PoolingParams are accepted without further checks.
        """
        if isinstance(params, PoolingParams):
            return

        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)

    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).
        Only checks lengths; `None` entries are allowed and will be
        auto-hashed downstream.

        Raises:
            ValueError: If a modality with data has no UUID entry, or the
                number of UUIDs differs from the number of data items.
        """

        def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
            if not isinstance(single_prompt, dict):
                return
            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not mm_data or not mm_uuids:
                return

            for modality, items in mm_data.items():
                if modality not in mm_uuids:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' must "
                        "be provided if multi_modal_data is provided.")
                # A non-list value counts as a single item / single UUID.
                modality_uuids = mm_uuids[modality]
                data_len = len(items) if isinstance(items, list) else 1
                uuid_len = (len(modality_uuids)
                            if isinstance(modality_uuids, list) else 1)
                if uuid_len != data_len:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' "
                        "must have same length as data: got "
                        f"{uuid_len} uuids vs "
                        f"{data_len} items.")

        # Handle explicit encoder/decoder prompts or singleton prompt
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            enc = prompt.get("encoder_prompt")
            dec = prompt.get("decoder_prompt")
            if enc is not None:
                _validate_single_prompt(enc)
            if dec is not None:
                _validate_single_prompt(dec)
        else:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        """Raise ValueError if a LoRA request is given but LoRA is disabled."""
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate guided-decoding settings and pin the backend to use.

        Side effects: may set ``params.guided_decoding.backend`` and
        ``params.guided_decoding.backend_was_auto``.

        Raises:
            ValueError: If the tokenizer is skipped, a conflicting
                request-level backend is given, ``choice`` is an empty list,
                or the selected backend rejects the request.
        """
        if not params.guided_decoding or not self.decoding_config:
            return

        # params.guided_decoding is known to be set past the early return.
        if self.model_config.skip_tokenizer_init:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        engine_level_backend = self.decoding_config.backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if (isinstance(params.guided_decoding.choice, list)
                and not params.guided_decoding.choice):
            # It is invalid for choice to be an empty list
            raise ValueError(f"Choice '{params.guided_decoding.choice}' "
                             "cannot be an empty list")

        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        elif engine_level_backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        elif engine_level_backend == "lm-format-enforcer":
            # lm format enforcer backend
            validate_structured_output_request_lm_format_enforcer(params)
        else:
            # Any backend not matched above is handled as "auto".
            # NOTE(review): this method does not itself check that the
            # backend name is supported; presumably that happens when the
            # decoding config is built — confirm.
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            # Remember that this backend was set automatically
            params.guided_decoding.backend_was_auto = True

    @staticmethod
    def _get_mm_field(prompt: PromptType, key: str) -> Any:
        """Read *key* from the prompt dict that carries multimodal inputs.

        For explicit encoder/decoder prompts this is the encoder prompt; for
        dict singleton prompts it is the prompt itself. Returns ``None`` for
        non-dict prompts, a non-dict encoder prompt, or a missing key.
        """
        if not isinstance(prompt, dict):
            return None
        source: Any = prompt
        if "encoder_prompt" in prompt:
            source = prompt.get("encoder_prompt")
            if not isinstance(source, dict):
                return None
        return source.get(key)

    def _maybe_build_mm_uuids(
        self,
        request_id: str,
        prompt: PromptType,
    ) -> Optional[MultiModalUUIDDict]:
        """Build per-item multimodal hash overrides when enabled. In this case,
        multimodal data items are identified by their request id, modality and
        index rather than their content.

        Returns a dictionary of modality -> list[str] of overrides, or None if
        disabled or no multimodal data is present.
        """
        mm_data = self._get_mm_field(prompt, "multi_modal_data")
        if not mm_data:
            return None

        mm_uuids: MultiModalUUIDDict = {}
        for modality, data in mm_data.items():
            # A non-list value counts as a single item.
            n = len(data) if isinstance(data, list) else 1
            mm_uuids[modality] = [
                f"{request_id}-{modality}-{i}" for i in range(n)
            ]
        return mm_uuids

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate a request and convert it into an ``EngineCoreRequest``.

        Args:
            request_id: Request identifier; also used to build multimodal
                UUIDs when content hashing is disabled.
            prompt: The user prompt (singleton or explicit encoder/decoder).
            params: Sampling or pooling parameters; cloned, not mutated
                beyond what validation sets on them.
            arrival_time: Defaults to ``time.time()`` when omitted.
            lora_request: Optional LoRA adapter request.
            tokenization_kwargs: Extra kwargs forwarded to tokenization.
            trace_headers: Tracing headers forwarded to the engine request.
            priority: Scheduling priority forwarded to the engine request.
            data_parallel_rank: Optional target rank in
                ``[0, data_parallel_size)``.

        Returns:
            The decoder prompt text (if any) and the engine request.

        Raises:
            ValueError: On any validation failure.
        """
        # TODO(woosuk): Support pooling models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        mm_config = self.model_config.multimodal_config
        if (mm_config and mm_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided. For explicit encoder/decoder prompts they are read
            # from the encoder prompt, where the multimodal data lives.
            self._validate_multi_modal_uuids(prompt)
            mm_uuids = self._get_mm_field(prompt, "multi_modal_uuids")

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            mm_uuids=mm_uuids,
        )
        # Imported locally, as in the original code; presumably to defer
        # platform detection / avoid an import cycle — confirm.
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        _, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                sampling_params.max_tokens = (
                    self.model_config.max_model_len -
                    len(decoder_inputs["prompt_token_ids"]))
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(
                    self.tokenizer.get_lora_tokenizer(lora_request))
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: Optional[list[MultiModalFeatureSpec]] = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = [
                MultiModalFeatureSpec(
                    data=decoder_mm_inputs[modality][idx],
                    modality=modality,
                    identifier=decoder_mm_hashes[modality][idx],
                    mm_position=decoder_mm_positions[modality][idx])
                for modality, idx in sorted_mm_idxs
            ]

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        """Validate the encoder inputs (if any) and the decoder inputs."""
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Check one prompt's token ids for emptiness, vocab range and length.

        Raises:
            ValueError: If the prompt is empty (except the multimodal encoder
                case), contains an out-of-vocab id, or exceeds
                ``max_model_len``.
        """
        model_config = self.model_config

        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            max_input_id = max(prompt_ids, default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
            # self.model_config.get_vocab_size() is the model’s vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            if max_input_id > max(tokenizer.max_token_id,
                                  self.model_config.get_vocab_size() - 1):
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

    def clear_cache(self) -> None:
        """Clear the input preprocessor's cache."""
        self.input_preprocessor.clear_cache()