processor.py 22.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping
6
from typing import Any, Literal, Optional, Union
7

8
from vllm.config import VllmConfig
9
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
10
from vllm.inputs.parse import split_enc_dec_inputs
11
from vllm.inputs.preprocess import InputPreprocessor
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
15
from vllm.multimodal.cache import processor_cache_from_config
16
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
17
from vllm.multimodal.processing import EncDecMultiModalProcessor
18
from vllm.multimodal.utils import argsort_mm_positions
19
20
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
21
from vllm.transformers_utils.tokenizer import AnyTokenizer
22
from vllm.v1.engine import EngineCoreRequest
23
24
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
25
26
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)
27
28
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
29
30
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)
31

32
33
logger = init_logger(__name__)


class Processor:
    """Validate incoming requests and convert them into
    ``EngineCoreRequest`` objects (see ``process_inputs``)."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: AnyTokenizer,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        # Keep the full config plus the sub-configs this class reads.
        model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.model_config = model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.structured_outputs_config = vllm_config.structured_outputs_config

        self.tokenizer = tokenizer

        # Defaults merged into each request's sampling params later on.
        self.generation_config_fields = (
            model_config.try_get_generation_config())

        self.mm_registry = mm_registry
        self.mm_processor_cache = processor_cache_from_config(
            vllm_config, mm_registry)

        # The preprocessor shares the tokenizer, registry and processor
        # cache created above.
        self.input_preprocessor = InputPreprocessor(
            model_config,
            tokenizer,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )
64

65
66
    def _validate_logprobs(
        self,
67
        params: SamplingParams,
68
69
    ) -> None:
        max_logprobs = self.model_config.max_logprobs
70
        if max_logprobs == -1:
71
72
            max_logprobs = self.model_config.get_vocab_size()

73
        # Validate sample logprobs.
74
75
76
77
78
79
80
81
        if params.logprobs:
            num_logprobs = params.logprobs
            if num_logprobs == -1:
                num_logprobs = self.model_config.get_vocab_size()
            if num_logprobs > max_logprobs:
                raise ValueError(
                    f"Requested sample logprobs of {num_logprobs}, "
                    f"which is is greater than max allowed: {max_logprobs}")
82
83

        # Validate prompt logprobs.
84
85
86
87
88
89
90
91
        if params.prompt_logprobs:
            num_prompt_logprobs = params.prompt_logprobs
            if num_prompt_logprobs == -1:
                num_prompt_logprobs = self.model_config.get_vocab_size()
            if num_prompt_logprobs > max_logprobs:
                raise ValueError(
                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
                    f"which is is greater than max allowed: {max_logprobs}")
92

93
    def _validate_sampling_params(
94
        self,
95
        params: SamplingParams,
96
    ) -> None:
97
        self._validate_structured_output(params)
98
        self._validate_logit_bias(params)
99

100
101
        if params.allowed_token_ids is None:
            return
102
103
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
104
105
106
107
        if self.tokenizer is None:
            # When skip_tokenizer_init=True, we can't validate token IDs
            # Skip validation and let the model handle invalid tokens
            return
108
        vocab_size = len(self.tokenizer)
109
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
110
            raise ValueError(
111
                "allowed_token_ids contains out-of-vocab token id!")
112

113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = []

        for token_id in params.logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                invalid_token_ids.append(token_id)

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

133
134
135
136
    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
137
138
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
139
            raise ValueError("vLLM V1 does not yet support best_of.")
140
141
        # Logits processors not supported.
        if params.logits_processors:
142
            raise ValueError("vLLM V1 does not support per request "
143
144
145
146
147
148
149
150
151
152
153
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
    ):
        """Validate request parameters before building an engine request.

        Pooling params are accepted as-is. Sampling params go through the
        logprob, sampling and V1-support checks, each of which raises
        ``ValueError`` for anything the API server should reject.
        """
        if not isinstance(params, PoolingParams):
            self._validate_logprobs(params)
            self._validate_sampling_params(params)
            self._validate_supported_sampling_params(params)

161
162
163
164
    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).
165
        Only checks lengths; `None` entries are allowed and will be
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
        auto-hashed downstream.
        """

        def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
            if not isinstance(single_prompt, dict):
                return
            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not mm_data or not mm_uuids:
                return

            for modality, items in mm_data.items():
                if modality in mm_uuids:
                    data_len = len(items) if isinstance(items, list) else 1
                    uuid_len = len(mm_uuids[modality]) if isinstance(
                        mm_uuids[modality], list) else 1
                    if uuid_len != data_len:
                        raise ValueError(
                            f"multi_modal_uuids for modality '{modality}' "
                            "must have same length as data: got "
                            f"{uuid_len} uuids vs "
                            f"{data_len} items.")
                else:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' must "
                        "be provided if multi_modal_data is provided.")

        # Handle explicit encoder/decoder prompts or singleton prompt
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            enc = prompt.get("encoder_prompt")
            dec = prompt.get("decoder_prompt")
            if enc is not None:
                _validate_single_prompt(enc)
            if dec is not None:
                _validate_single_prompt(dec)
        else:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]

204
    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
205
206
207
208
209
        if lora_request is None:
            return

        # LoRA request passed in while LoRA is not enabled
        if not self.lora_config:
210
211
212
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

213
214
215
216
217
218
219
220
        if self.tokenizer is not None:
            logger.warning_once(
                "vLLM has deprecated support for supporting different "
                "tokenizers for different LoRAs. By default, vLLM uses base "
                "model's tokenizer. If you are using a LoRA "
                "with its own tokenizer, consider specifying `--tokenizer "
                "[lora_path]` to use the LoRA tokenizer.")

221
    def _validate_structured_output(self, params: SamplingParams) -> None:
222
        if not params.structured_outputs or not self.structured_outputs_config:
223
            return
224

225
        if self.model_config.skip_tokenizer_init and params.structured_outputs:
226
227
228
229
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

230
231
232
        backend = self.structured_outputs_config.backend
        if _backend := params.structured_outputs._backend:
            # Request-level backend selection is not supported.
233
234
235
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
236
237
238
239
            # using the `_backend_was_auto` field set in the params.
            if (backend != _backend
                    and not (backend == "auto"
                             and params.structured_outputs._backend_was_auto)):
240
                raise ValueError(
241
242
243
244
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request.")
245
        else:
246
            params.structured_outputs._backend = backend
247

248
        # Request content validation
249
250
        if (isinstance(params.structured_outputs.choice, list)
                and not params.structured_outputs.choice):
251
            # It is invalid for choice to be an empty list
252
253
254
            raise ValueError(
                f"Choice '{params.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )
255

256
        if backend.startswith("xgrammar"):
257
            # xgrammar with no fallback
258
            validate_xgrammar_grammar(params)
259
        elif backend.startswith("guidance"):
260
261
262
263
264
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
265
        elif backend == "outlines":
266
267
            # outlines backend
            validate_structured_output_request_outlines(params)
268
        elif backend == "lm-format-enforcer":
269
270
            # lm format enforcer backend
            validate_structured_output_request_lm_format_enforcer(params)
271
        else:
272
            # NOTE: backend must be "auto" here, because we have
273
            # checked supported_backends above.
274
275
276
277
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with any
            # other setting where a specific backend was specified.
278
            try:
279
                validate_xgrammar_grammar(params)
280
                params.structured_outputs._backend = "xgrammar"
281
            except ValueError:
282
283
                # The request either failed validation
                # or includes some jsonschema feature(s) that
284
                # are not supported in xgrammar. Fall back to guidance.
285
                validate_guidance_grammar(params, tokenizer=None)
286
                params.structured_outputs._backend = "guidance"
287
            # Remember that this backend was set automatically
288
            params.structured_outputs._backend_was_auto = True
289

290
    def _maybe_build_mm_uuids(
291
292
293
        self,
        request_id: str,
        prompt: PromptType,
294
    ) -> Optional[MultiModalUUIDDict]:
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
        """Build per-item multimodal hash overrides when enabled. In this case,
        multimodal data items are identified by their request id, modality and
        index rather than their content.

        Returns a dictionary of modality -> list[str] of overrides, or None if
        disabled or no multimodal data is present.
        """

        def _extract_mm_data(p: PromptType):
            if isinstance(p, dict) and "encoder_prompt" in p:
                enc = p.get("encoder_prompt")
                if isinstance(enc, dict):
                    return enc.get("multi_modal_data")
                return None
            if isinstance(p, dict):
                return p.get("multi_modal_data")
            return None

        mm_data = _extract_mm_data(prompt)
        if not mm_data:
            return None

317
        mm_uuids: MultiModalUUIDDict = {}
318
319
        for modality, data in mm_data.items():
            n = len(data) if isinstance(data, list) else 1
320
            mm_uuids[modality] = [
321
322
                f"{request_id}-{modality}-{i}" for i in range(n)
            ]
323
        return mm_uuids
324

325
326
327
328
329
    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate one request and convert it into an ``EngineCoreRequest``.

        Steps:
        1. Validate the LoRA request, the params and ``data_parallel_rank``.
        2. Pick the multimodal hash overrides.
        3. Preprocess the prompt.
        4. Validate the processed inputs.
        5. Package the decoder-side inputs with cloned params.

        Returns:
            ``(prompt, request)`` where ``prompt`` is the decoder inputs'
            ``"prompt"`` entry (``None`` when absent) and ``request`` is the
            engine-core request.

        Raises:
            ValueError: e.g. for unsupported or invalid params, a LoRA
                request without LoRA enabled, an out-of-range
                ``data_parallel_rank``, or invalid processed inputs.
        """
        # TODO(woosuk): Support pooling models.
        self._validate_lora(lora_request)
        self._validate_params(params)

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        mm_uuids = self._resolve_mm_uuids(request_id, prompt)

        # Process inputs, which includes:
        # 1. Tokenize the text prompt (lora_request is not passed here).
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id()

        self._validate_model_inputs(processed_inputs)

        # Only the decoder side is packaged into the engine-core request.
        _, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                sampling_params.max_tokens = (
                    self.model_config.max_model_len -
                    len(decoder_inputs["prompt_token_ids"]))
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        mm_features = self._build_mm_features(decoder_inputs)

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )

    def _resolve_mm_uuids(
        self,
        request_id: str,
        prompt: PromptType,
    ) -> Optional[MultiModalUUIDDict]:
        """Choose the multimodal hash overrides passed to preprocessing.

        When users explicitly turn off BOTH prefix caching and the
        multimodal processor cache (``mm_processor_cache_gb == 0``), no
        multimodal features or embeddings are reused across requests.
        Identifying items by their content is then unnecessary, so uuids of
        the form request id-modality-index are generated instead. Otherwise
        the user-provided ``multi_modal_uuids`` (if any) are validated and
        returned.
        """
        mm_config = self.model_config.multimodal_config
        if (mm_config and mm_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            return self._maybe_build_mm_uuids(request_id, prompt)

        self._validate_multi_modal_uuids(prompt)
        if isinstance(prompt, dict):
            return prompt.get("multi_modal_uuids")
        return None

    def _build_mm_features(
        self,
        decoder_inputs: SingletonInputs,
    ) -> Optional[list[MultiModalFeatureSpec]]:
        """Flatten multimodal decoder inputs into position-sorted specs.

        Merges the per-modality kwargs, hashes and placeholders into one
        list ordered by each item's position in the input sequence. Returns
        None when the decoder inputs are not multimodal.
        """
        if decoder_inputs["type"] != "multimodal":
            return None

        decoder_mm_inputs = decoder_inputs["mm_kwargs"]
        decoder_mm_positions = decoder_inputs["mm_placeholders"]
        decoder_mm_hashes = decoder_inputs["mm_hashes"]

        return [
            MultiModalFeatureSpec(
                data=decoder_mm_inputs[modality][idx],
                modality=modality,
                identifier=decoder_mm_hashes[modality][idx],
                mm_position=decoder_mm_positions[modality][idx])
            for modality, idx in argsort_mm_positions(decoder_mm_positions)
        ]
447

448
    def _validate_model_inputs(self, inputs: ProcessorInputs):
        """Validate the processed inputs: the encoder side first (when
        ``split_enc_dec_inputs`` yields one), then the decoder side."""
        enc_inputs, dec_inputs = split_enc_dec_inputs(inputs)
        if enc_inputs is not None:
            self._validate_model_input(enc_inputs, prompt_type="encoder")
        self._validate_model_input(dec_inputs, prompt_type="decoder")
455

456
457
458
459
460
461
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
462
        model_config = self.model_config
463

464
465
        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
466
467
468
469
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")
470

471
472
473
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
474
            tokenizer = self.tokenizer
475
            max_input_id = max(prompt_ids, default=0)
476
477
478
479
480
481
482
483
484
485
486
487
488

            # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
            # self.model_config.get_vocab_size() is the model’s vocab size.
            # For Qwen3 models, the language model has extra tokens that do
            # not exist in the tokenizer, and vice versa for multimodal
            # placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            if max_input_id > max(tokenizer.max_token_id,
                                  self.model_config.get_vocab_size() - 1):
489
490
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")
491
492

        max_prompt_len = self.model_config.max_model_len
493
        if len(prompt_ids) > max_prompt_len:
494
495
496
497
498
499
500
501
502
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
503
                    return  # Skip encoder length check for Whisper
504
505

            if model_config.is_multimodal_model:
506
                suggestion = (
507
508
509
510
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
511
512
513
514
515
516
517
518
519
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")
520

521
522
523
            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
524
525
526

    def clear_cache(self) -> None:
        self.input_preprocessor.clear_cache()