processor.py 18.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping, Sequence
6
from typing import Any, Literal, Optional, Union
7

8
from vllm.config import VllmConfig
9
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
10
from vllm.inputs.parse import split_enc_dec_inputs
11
12
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
13
14
15
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
16
from vllm.multimodal.processing import EncDecMultiModalProcessor
17
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
18
19
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
20
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
21
from vllm.v1.engine import EngineCoreRequest
22
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
23
24
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
25
26
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
27
28
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)
29
30
31
32
33
34


class Processor:

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """Build the request processor from the engine configuration.

        Args:
            vllm_config: Engine-wide configuration. Its model, cache, LoRA
                and decoding sub-configs are stored as attributes.
            tokenizer: Tokenizer group, stored and handed to the input
                preprocessor.
            mm_registry: Multi-modal registry handed to the input
                preprocessor.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())
        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)

        # Multi-modal hasher (for images): hashes are requested from the
        # preprocessor when either the input cache or prefix caching is on.
        self.use_hash = self.mm_input_cache_client.use_cache or \
            self.cache_config.enable_prefix_caching
58

59
60
61
62
    @property
    def mm_registry(self):
        return self.input_preprocessor.mm_registry

63
64
    def _validate_logprobs(
        self,
65
        params: SamplingParams,
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    ) -> None:
        max_logprobs = self.model_config.max_logprobs
        # Validate sample logprobs.
        if params.logprobs and params.logprobs > max_logprobs:
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

80
    def _validate_sampling_params(
81
        self,
82
        params: SamplingParams,
83
        lora_request: Optional[LoRARequest],
84
    ) -> None:
85
        self._validate_structured_output(params)
86
        self._validate_logit_bias(params)
87

88
89
        if params.allowed_token_ids is None:
            return
90
91
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
92
93
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        vocab_size = len(tokenizer)
94
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
95
            raise ValueError(
96
                "allowed_token_ids contains out-of-vocab token id!")
97

98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = []

        for token_id in params.logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                invalid_token_ids.append(token_id)

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

118
119
120
121
    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
122
123
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
124
            raise ValueError("vLLM V1 does not yet support best_of.")
125
126
        # Logits processors not supported.
        if params.logits_processors:
127
            raise ValueError("vLLM V1 does not support per request "
128
129
130
131
132
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """
        Validate supported SamplingParam.
        Should raise ValueError if unsupported for API Server.

        Pooling parameters are accepted without any checks. Sampling
        parameters go through the logprobs, sampling-parameter and
        supported-feature validators, in that order.
        """
        if isinstance(params, PoolingParams):
            return

        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

152
153
154
    def _validate_structured_output(self, params: SamplingParams) -> None:
        if not params.guided_decoding or not self.decoding_config:
            return
155

156
157
158
159
160
        if self.model_config.skip_tokenizer_init and params.guided_decoding:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

161
        engine_level_backend = self.decoding_config.backend
162
        if params.guided_decoding.backend:
163
164
165
166
167
168
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
169
170
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
171
172
173
174
175
176
177
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
178
179
        else:
            params.guided_decoding.backend = engine_level_backend
180

181
        # Request content validation
182
183
184
185
186
187
        if (isinstance(params.guided_decoding.choice, list)
                and not params.guided_decoding.choice):
            # It is invalid for choice to be an empty list
            raise ValueError(f"Choice '{params.guided_decoding.choice}' "
                             "cannot be an empty list")

188
        if engine_level_backend.startswith("xgrammar"):
189
            # xgrammar with no fallback
190
            validate_xgrammar_grammar(params)
191
192
193
194
195
196
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
197
198
199
        elif engine_level_backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
200
201
202
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.
203
204
205
206
207
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
208
                validate_xgrammar_grammar(params)
209
210
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
211
212
                # The request either failed validation
                # or includes some jsonschema feature(s) that
213
                # are not supported in xgrammar. Fall back to guidance.
214
                validate_guidance_grammar(params, tokenizer=None)
215
                params.guided_decoding.backend = "guidance"
216
            # Remember that this backend was set automatically
217
            params.guided_decoding.backend_was_auto = True
218

219
220
221
222
223
    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate a request and convert it into an ``EngineCoreRequest``.

        Args:
            request_id: Identifier copied into the engine request.
            prompt: Raw prompt, passed to the input preprocessor.
            params: Sampling or pooling parameters; cloned before use so the
                caller's object is not the one stored in the request.
            arrival_time: Request arrival timestamp; defaults to
                ``time.time()`` when omitted.
            lora_request: Optional LoRA request used for validation,
                tokenization and the EOS token lookup.
            tokenization_kwargs: Extra keyword arguments forwarded to the
                input preprocessor.
            trace_headers: Must be ``None``.
            priority: Priority copied into the engine request.
            data_parallel_rank: Optional data-parallel rank; must lie in
                ``[0, data_parallel_size)`` when given.

        Returns:
            A tuple of the decoder prompt text (``None`` if the processed
            inputs carry no ``"prompt"`` entry) and the engine request.

        Raises:
            ValueError: On invalid LoRA/params/model inputs, when
                ``trace_headers`` is given, or when ``data_parallel_rank``
                is out of range.
            NotImplementedError: If the processed inputs contain encoder
                inputs.
        """

        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            return_mm_hashes=self.use_hash,
        )
        # NOTE(review): imported at call time rather than module level,
        # presumably to avoid an import cycle -- confirm before moving it.
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                sampling_params.max_tokens = (
                    self.model_config.max_model_len -
                    len(decoder_inputs["prompt_token_ids"]))
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            sampling_params.update_from_tokenizer(
                self.tokenizer.get_lora_tokenizer(lora_request))
        else:
            pooling_params = params.clone()

        # Multimodal related.
        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            (
                sorted_item_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs["mm_placeholders"],
                decoder_inputs["mm_hashes"] if self.use_hash else None,
            )

            # The output of merged multi-modal processor (`decoder_mm_inputs`)
            # is a single MultiModalKwargs for all items from all modalities.
            # This code flattens kwargs for individual items in a list and
            # sorts them by each item's position in the input sequence if there
            # are multiple modalities.
            unique_modalities = set(sorted_item_modalities)
            if len(unique_modalities) > 1:
                orig_sorted_mm_inputs = []
                # Next unused item index per modality.
                used_indices = {modality: 0 for modality in unique_modalities}

                for modality in sorted_item_modalities:
                    items = decoder_mm_inputs.get_items(modality)
                    item = items[used_indices[modality]]

                    orig_sorted_mm_inputs.append(
                        MultiModalKwargs.from_items([item]))
                    used_indices[modality] += 1
            else:
                orig_sorted_mm_inputs = [
                    MultiModalKwargs.from_items([item]) for item in
                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
                ]

            if sorted_mm_hashes is not None:
                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
                    orig_sorted_mm_inputs, sorted_mm_hashes)
            else:
                sorted_mm_inputs = orig_sorted_mm_inputs

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )
354

355
356
357
    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None
                               ) -> None:
        """Validate the encoder (if any) and decoder parts of ``inputs``.

        Args:
            inputs: Preprocessed inputs, split into encoder and decoder
                parts by ``split_enc_dec_inputs``.
            lora_request: Optional LoRA request forwarded to the per-prompt
                validator.

        Raises:
            ValueError: Propagated from ``_validate_model_input``.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")
368

369
370
371
372
373
374
375
    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
376
        model_config = self.model_config
377

378
379
        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
380
381
382
383
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")
384

385
386
387
388
389
390
391
392
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            max_input_id = max(prompt_ids, default=0)
            if max_input_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")
393
394

        max_prompt_len = self.model_config.max_model_len
395
        if len(prompt_ids) > max_prompt_len:
396
397
398
399
400
401
402
403
404
405
406
407
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper

            if model_config.is_multimodal_model:
408
                suggestion = (
409
410
411
412
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
413
414
415
416
417
418
419
420
421
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")
422

423
424
425
            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens