processor.py 19.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping
6
from typing import Any, Literal, Optional, Union
7

8
from vllm.config import VllmConfig
9
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
10
from vllm.inputs.parse import split_enc_dec_inputs
11
12
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
13
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
14
from vllm.multimodal.cache import processor_cache_from_config
15
from vllm.multimodal.inputs import MultiModalFeatureSpec
16
from vllm.multimodal.processing import EncDecMultiModalProcessor
17
from vllm.multimodal.utils import argsort_mm_positions
18
19
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
20
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
21
from vllm.v1.engine import EngineCoreRequest
22
23
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
24
25
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)
26
27
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
28
29
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)
30
31
32
33
34
35


class Processor:

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Keep the config sections used by the validators and build the
        input preprocessor.

        Args:
            vllm_config: Engine configuration; its model, cache, LoRA and
                decoding sections are stored on the instance.
            tokenizer: Tokenizer group handed to the input preprocessor and
                used by the validators.
            mm_registry: Multimodal registry used to create the processor
                cache and the input preprocessor.
        """
        model_config = vllm_config.model_config

        self.vllm_config = vllm_config
        self.model_config = model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            model_config.try_get_generation_config())

        self.mm_registry = mm_registry
        mm_processor_cache = processor_cache_from_config(
            vllm_config, mm_registry)
        self.mm_processor_cache = mm_processor_cache

        self.input_preprocessor = InputPreprocessor(
            model_config,
            tokenizer,
            mm_registry,
            mm_processor_cache=mm_processor_cache,
        )
61

62
63
    def _validate_logprobs(
        self,
64
        params: SamplingParams,
65
66
    ) -> None:
        max_logprobs = self.model_config.max_logprobs
67
68
        if max_logprobs == -1:
            return
69
        # Validate sample logprobs.
70
71
        if params.logprobs and (params.logprobs == -1
                                or params.logprobs > max_logprobs):
72
73
74
75
76
77
78
79
80
81
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

82
    def _validate_sampling_params(
83
        self,
84
        params: SamplingParams,
85
        lora_request: Optional[LoRARequest],
86
    ) -> None:
87
        self._validate_structured_output(params)
88
        self._validate_logit_bias(params)
89

90
91
        if params.allowed_token_ids is None:
            return
92
93
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
94
95
96
97
        if self.tokenizer is None:
            # When skip_tokenizer_init=True, we can't validate token IDs
            # Skip validation and let the model handle invalid tokens
            return
98
99
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        vocab_size = len(tokenizer)
100
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
101
            raise ValueError(
102
                "allowed_token_ids contains out-of-vocab token id!")
103

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = []

        for token_id in params.logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                invalid_token_ids.append(token_id)

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

124
125
126
127
    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
128
129
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
130
            raise ValueError("vLLM V1 does not yet support best_of.")
131
132
        # Logits processors not supported.
        if params.logits_processors:
133
            raise ValueError("vLLM V1 does not support per request "
134
135
136
137
138
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ):
        """
        Validate supported SamplingParam.
        Should raise ValueError if unsupported for API Server.

        PoolingParams are accepted without any checks. SamplingParams go
        through the logprobs, sampling-parameter and unsupported-feature
        validators, in that order.
        """
        if not isinstance(params, PoolingParams):
            self._validate_logprobs(params)
            self._validate_sampling_params(params, lora_request)
            self._validate_supported_sampling_params(params)

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

158
159
160
    def _validate_structured_output(self, params: SamplingParams) -> None:
        if not params.guided_decoding or not self.decoding_config:
            return
161

162
163
164
165
166
        if self.model_config.skip_tokenizer_init and params.guided_decoding:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

167
        engine_level_backend = self.decoding_config.backend
168
        if params.guided_decoding.backend:
169
170
171
172
173
174
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
175
176
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
177
178
179
180
181
182
183
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
184
185
        else:
            params.guided_decoding.backend = engine_level_backend
186

187
        # Request content validation
188
189
190
191
192
193
        if (isinstance(params.guided_decoding.choice, list)
                and not params.guided_decoding.choice):
            # It is invalid for choice to be an empty list
            raise ValueError(f"Choice '{params.guided_decoding.choice}' "
                             "cannot be an empty list")

194
        if engine_level_backend.startswith("xgrammar"):
195
            # xgrammar with no fallback
196
            validate_xgrammar_grammar(params)
197
198
199
200
201
202
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
203
204
205
        elif engine_level_backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
206
207
208
        elif engine_level_backend == "lm-format-enforcer":
            # lm format enforcer backend
            validate_structured_output_request_lm_format_enforcer(params)
209
210
211
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.
212
213
214
215
216
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
217
                validate_xgrammar_grammar(params)
218
219
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
220
221
                # The request either failed validation
                # or includes some jsonschema feature(s) that
222
                # are not supported in xgrammar. Fall back to guidance.
223
                validate_guidance_grammar(params, tokenizer=None)
224
                params.guided_decoding.backend = "guidance"
225
            # Remember that this backend was set automatically
226
            params.guided_decoding.backend_was_auto = True
227

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
    def _maybe_build_mm_hash_overrides(
        self,
        request_id: str,
        prompt: PromptType,
    ) -> Optional[dict[str, list[str]]]:
        """Build per-item multimodal hash overrides when enabled. In this case,
        multimodal data items are identified by their request id, modality and
        index rather than their content.

        Returns a dictionary of modality -> list[str] of overrides, or None if
        disabled or no multimodal data is present.
        """

        def _extract_mm_data(p: PromptType):
            if isinstance(p, dict) and "encoder_prompt" in p:
                enc = p.get("encoder_prompt")
                if isinstance(enc, dict):
                    return enc.get("multi_modal_data")
                return None
            if isinstance(p, dict):
                return p.get("multi_modal_data")
            return None

        mm_data = _extract_mm_data(prompt)
        if not mm_data:
            return None

        overrides: dict[str, list[str]] = {}
        for modality, data in mm_data.items():
            n = len(data) if isinstance(data, list) else 1
            overrides[modality] = [
                f"{request_id}-{modality}-{i}" for i in range(n)
            ]
        return overrides

263
264
265
266
267
    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate a request and turn it into an ``EngineCoreRequest``.

        The LoRA request and parameters are validated first, then the
        prompt is run through the input preprocessor, the platform and
        model-input validators, and finally packed together with a clone of
        ``params`` into the request object.

        Args:
            request_id: Identifier stored on the resulting request.
            prompt: Prompt passed to the input preprocessor.
            params: Sampling or pooling parameters; a clone is stored on
                the request.
            arrival_time: Request arrival time; ``time.time()`` is used
                when omitted.
            lora_request: Optional LoRA request.
            tokenization_kwargs: Forwarded to the input preprocessor.
            trace_headers: Must be ``None``.
            priority: Stored on the resulting request.
            data_parallel_rank: Optional rank; must lie in
                ``[0, data_parallel_size)`` when given.

        Returns:
            A tuple of the decoder prompt text (``None`` when the decoder
            inputs have no ``"prompt"`` entry) and the engine core request.

        Raises:
            ValueError: If validation fails, ``trace_headers`` is given, or
                ``data_parallel_rank`` is out of range.
            NotImplementedError: If the processed inputs contain encoder
                inputs.
        """
        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")

        dp_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and (data_parallel_rank < 0 or
                                               data_parallel_rank >= dp_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {dp_size}).")

        arrival_time = time.time() if arrival_time is None else arrival_time

        # Optionally generate multimodal hash overrides based on request id.
        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore hashing is no longer necessary.
        mm_config = self.model_config.multimodal_config
        mm_hash_overrides = None
        if (mm_config and mm_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            mm_hash_overrides = self._maybe_build_mm_hash_overrides(
                request_id, prompt)

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            mm_hash_overrides=mm_hash_overrides,
        )

        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        sampling_params = None
        pooling_params = None
        if not isinstance(params, SamplingParams):
            pooling_params = params.clone()
        else:
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                prompt_len = len(decoder_inputs["prompt_token_ids"])
                sampling_params.max_tokens = (
                    self.model_config.max_model_len - prompt_len)
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(
                    self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        mm_features: Optional[list[MultiModalFeatureSpec]] = None

        if decoder_inputs["type"] == "multimodal":
            mm_kwargs = decoder_inputs["mm_kwargs"]
            mm_placeholders = decoder_inputs["mm_placeholders"]
            mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            mm_features = [
                MultiModalFeatureSpec(
                    data=mm_kwargs[modality][idx],
                    modality=modality,
                    identifier=mm_hashes[modality][idx],
                    mm_position=mm_placeholders[modality][idx])
                for modality, idx in argsort_mm_positions(mm_placeholders)
            ]

        prompt_text = decoder_inputs.get("prompt")
        engine_request = EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )
        return prompt_text, engine_request
383

384
385
386
    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        """Validate the encoder (when present) and decoder parts of
        ``inputs``.

        The inputs are split with ``split_enc_dec_inputs`` and each part is
        passed to ``_validate_model_input``, encoder first.
        """
        encoder_part, decoder_part = split_enc_dec_inputs(inputs)

        pending = [("decoder", decoder_part)]
        if encoder_part is not None:
            pending.insert(0, ("encoder", encoder_part))

        for prompt_type, singleton in pending:
            self._validate_model_input(singleton,
                                       lora_request,
                                       prompt_type=prompt_type)
397

398
399
400
401
402
403
404
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
405
        model_config = self.model_config
406

407
408
        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
409
410
411
412
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")
413

414
415
416
417
418
419
420
421
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            max_input_id = max(prompt_ids, default=0)
            if max_input_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")
422
423

        max_prompt_len = self.model_config.max_model_len
424
        if len(prompt_ids) > max_prompt_len:
425
426
427
428
429
430
431
432
433
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
汪志鹏's avatar
汪志鹏 committed
434
                    return  # Skip encoder length check for Whisper and Donut
435
436

            if model_config.is_multimodal_model:
437
                suggestion = (
438
439
440
441
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
442
443
444
445
446
447
448
449
450
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")
451

452
453
454
            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens
455
456
457

    def clear_cache(self) -> None:
        """Clear the cache held by the input preprocessor.

        Delegates to ``self.input_preprocessor.clear_cache()``.
        """
        self.input_preprocessor.clear_cache()