processor.py 16.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections.abc import Mapping, Sequence
5
from typing import Literal, Optional, Union
6

7
from vllm.config import VllmConfig
8
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
9
from vllm.inputs.parse import split_enc_dec_inputs
10
11
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
12
13
14
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
15
from vllm.multimodal.processing import EncDecMultiModalProcessor
16
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
17
18
19
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
20
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
21
from vllm.v1.engine import EngineCoreRequest
22
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
23
24
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
25
26
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)
27
28
29
30
31
32


class Processor:

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Capture the engine configuration and build the input pipeline.

        Args:
            vllm_config: Aggregate engine configuration. Its model, cache,
                LoRA and decoding sub-configs are stored as attributes.
            tokenizer: Tokenizer group. It is handed to the
                ``InputPreprocessor`` and used for per-LoRA tokenizer lookup.
            mm_registry: Multi-modal registry forwarded to the
                ``InputPreprocessor``.
        """
        self.vllm_config = vllm_config

        # Sub-configs consulted by the validation helpers of this class.
        model_config = vllm_config.model_config
        self.model_config = model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config

        self.tokenizer = tokenizer

        # Generation-config defaults that are merged into every request's
        # sampling params.
        self.generation_config_fields = model_config.try_get_generation_config()
        self.input_preprocessor = InputPreprocessor(model_config, tokenizer,
                                                    mm_registry)

        self.mm_input_cache_client = MirroredProcessingCache(model_config)

        # Ask preprocessing for multi-modal hashes whenever either the
        # mirrored input cache or prefix caching is enabled.
        self.use_hash = (self.mm_input_cache_client.use_cache
                         or self.cache_config.enable_prefix_caching)

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject requests that ask for more logprobs than the model allows.

        Both the sampled-token logprobs and the prompt logprobs are compared
        against ``model_config.max_logprobs``. Unset (falsy) values pass.

        Raises:
            ValueError: If either requested count exceeds the limit.
        """
        limit = self.model_config.max_logprobs

        sample_count = params.logprobs
        if sample_count and sample_count > limit:
            raise ValueError(
                f"Requested sample logprobs of {sample_count}, "
                f"which is greater than max allowed: {limit}")

        prompt_count = params.prompt_logprobs
        if prompt_count and prompt_count > limit:
            raise ValueError(
                f"Requested prompt logprobs of {prompt_count}, "
                f"which is greater than max allowed: {limit}")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate request-specific sampling options.

        First runs the structured-output and logit-bias checks. It then
        verifies that ``allowed_token_ids``, when given, is non-empty and
        lies entirely within ``[0, vocab_size)``.

        Raises:
            ValueError: If ``allowed_token_ids`` is empty or contains an
                out-of-range id. The delegated validators may raise as well.
        """
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        token_ids = params.allowed_token_ids
        if token_ids is None:
            return  # no restriction requested
        if not token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")

        vocab_size = self.model_config.get_vocab_size()
        out_of_range = any(tid < 0 or tid >= vocab_size for tid in token_ids)
        if out_of_range:
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate that every ``logit_bias`` token id is in vocabulary range.

        Args:
            params: Sampling parameters of the request. Iterating
                ``params.logit_bias`` yields the token ids that are checked.

        Raises:
            ValueError: If any token id lies outside ``[0, vocab_size)``.
                All offending ids are reported together, in iteration order.
        """
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        # Collect every offender so the client can fix them in one go; the
        # chained comparison mirrors the allowed_token_ids range check.
        invalid_token_ids = [
            token_id for token_id in params.logit_bias
            if not 0 <= token_id < vocab_size
        ]

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features that the V1 engine does not implement.

        Raises:
            ValueError: If ``best_of`` is greater than 1, or if the request
                carries its own logits processors.
        """
        wants_best_of = params.best_of is not None and params.best_of > 1
        if wants_best_of:
            raise ValueError("vLLM V1 does not yet support best_of.")

        if not params.logits_processors:
            return
        raise ValueError("vLLM V1 does not support per request "
                         "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
    ):
        """Check that the request parameters are usable by the V1 engine.

        Only ``SamplingParams`` are accepted. They are then passed through
        the logprob, sampling-option and unsupported-feature validators, in
        that order. Unsupported requests are meant to surface to the API
        server as ``ValueError``.

        Raises:
            ValueError: For non-sampling (pooling) params, or when one of the
                delegated validators rejects the request.
        """
        if not isinstance(params, SamplingParams):
            raise ValueError("V1 does not yet support Pooling models.")

        validators = (
            self._validate_logprobs,
            self._validate_sampling_params,
            self._validate_supported_sampling_params,
        )
        for validate in validators:
            validate(params)

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        """Refuse a LoRA request when LoRA is not configured.

        Raises:
            ValueError: If ``lora_request`` is given while ``lora_config`` is
                falsy.
        """
        lora_enabled = self.lora_config
        if lora_request is None or lora_enabled:
            return
        raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                         "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Resolve and validate the structured-output backend for a request.

        Request-level backend selection is not supported. A request's backend
        must either be unset, in which case it is set to the engine-level
        backend, or equal the engine-level backend. The one exception is a
        backend chosen earlier by ``auto`` mode, recognised through the
        ``_auto`` backend option.

        The guided-decoding spec is then validated against the backend that
        will serve it. In ``auto`` mode that backend is picked per request:
        xgrammar if it accepts the spec, otherwise guidance.

        Note: mutates ``params.guided_decoding`` in place. It sets
        ``backend`` and, in auto mode, adds the ``_auto`` option.

        Raises:
            ValueError: If the request names a different backend than the
                engine, or the spec is rejected by the backend that will
                serve it.
        """
        if not params.guided_decoding or not self.decoding_config:
            return

        engine_level_backend = self.decoding_config.guided_decoding_backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto" and "_auto"
                             in params.guided_decoding.backend_options())):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        else:
            # NOTE(review): every engine-level backend that does not start
            # with "xgrammar" or "guidance" is treated as "auto" here. This
            # method does not check the backend name itself; presumably it
            # is validated when the engine is configured -- confirm.
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_xgrammar_grammar(params)
            except ValueError:
                # xgrammar rejected the request (e.g. it uses jsonschema
                # features xgrammar does not support). Fall back to guidance,
                # but validate for guidance too, so a spec that neither
                # backend accepts is rejected here instead of being passed
                # on unvalidated.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            else:
                params.guided_decoding.backend = "xgrammar"
            # Remember that this backend was set automatically
            params.guided_decoding.add_option("_auto")

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate a request and turn it into an ``EngineCoreRequest``.

        Steps, in order:

        * Reject unsupported options.
        * Preprocess the prompt.
        * Let the current platform validate the request.
        * Validate the processed token ids.
        * Finalize a clone of the sampling params.
        * Flatten any multi-modal inputs into per-item lists sorted by
          prompt position.

        Args:
            request_id: Identifier stored on the returned engine request.
            prompt: Raw prompt, handed to ``InputPreprocessor.preprocess``.
            params: Request parameters; only ``SamplingParams`` are accepted.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            lora_request: Optional LoRA adapter. Used for preprocessing, the
                EOS token lookup and tokenizer selection.
            trace_headers: Must be ``None``; tracing is not supported in V1.
            prompt_adapter_request: Must be ``None``; not supported in V1.
            priority: Must be ``0``; priorities are not supported in V1.

        Returns:
            The decoder prompt text (``None`` when the processed inputs carry
            no ``"prompt"``) paired with the engine request.

        Raises:
            ValueError: For unsupported options or invalid inputs.
            NotImplementedError: For encoder-decoder inputs.
        """
        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params)
        if priority != 0:
            raise ValueError("V1 does not support priority yet.")
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        # Function-local import kept from the original; presumably it avoids
        # an import cycle at module load time (unverified).
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        if sampling_params.max_tokens is None:
            sampling_params.max_tokens = (
                self.model_config.max_model_len -
                len(decoder_inputs["prompt_token_ids"]))
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            (
                sorted_item_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs["mm_placeholders"],
                decoder_inputs["mm_hashes"] if self.use_hash else None,
            )

            orig_sorted_mm_inputs = self._split_mm_inputs_by_item(
                decoder_mm_inputs, sorted_item_modalities)

            if sorted_mm_hashes is not None:
                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
                    orig_sorted_mm_inputs, sorted_mm_hashes)
            else:
                sorted_mm_inputs = orig_sorted_mm_inputs

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    @staticmethod
    def _split_mm_inputs_by_item(
        decoder_mm_inputs: MultiModalKwargs,
        sorted_item_modalities: Sequence[str],
    ) -> list[MultiModalKwargs]:
        """Flatten merged multi-modal kwargs into one entry per item.

        The merged multi-modal processor returns a single ``MultiModalKwargs``
        covering every item of every modality. This helper returns one
        ``MultiModalKwargs`` per item.

        With several modalities, entries follow ``sorted_item_modalities``:
        one modality name per item, in prompt-position order. Each
        modality's items are consumed in the order ``get_items`` yields
        them. With a single modality, all of its items are returned in
        ``get_items`` order.
        """
        unique_modalities = set(sorted_item_modalities)
        if len(unique_modalities) > 1:
            # Look up each modality's items once, not once per item.
            items_by_modality = {
                modality: decoder_mm_inputs.get_items(modality)
                for modality in unique_modalities
            }
            next_index = dict.fromkeys(unique_modalities, 0)
            per_item_inputs: list[MultiModalKwargs] = []
            for modality in sorted_item_modalities:
                item = items_by_modality[modality][next_index[modality]]
                next_index[modality] += 1
                per_item_inputs.append(MultiModalKwargs.from_items([item]))
            return per_item_inputs

        # Single modality: there is no interleaving to undo.
        # NOTE(review): an empty `sorted_item_modalities` raises IndexError
        # here (behavior unchanged from the inline version).
        return [
            MultiModalKwargs.from_items([item])
            for item in decoder_mm_inputs.get_items(sorted_item_modalities[0])
        ]

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        """Validate each prompt part of the processed ``inputs``.

        The encoder part (only when present) is checked first, then the
        decoder part. Both checks go through ``_validate_model_input``.
        """
        encoder_part, decoder_part = split_enc_dec_inputs(inputs)

        parts_to_check = []
        if encoder_part is not None:
            parts_to_check.append(("encoder", encoder_part))
        parts_to_check.append(("decoder", decoder_part))

        for prompt_type, part in parts_to_check:
            self._validate_model_input(part,
                                       lora_request,
                                       prompt_type=prompt_type)

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Check one prompt's token ids against tokenizer and model limits.

        Args:
            prompt_inputs: Processed prompt; its ``prompt_token_ids`` are
                checked.
            lora_request: Selects the (possibly LoRA-specific) tokenizer.
            prompt_type: Which prompt is being checked. It appears in error
                messages and relaxes some checks for multi-modal encoder
                prompts.

        Raises:
            ValueError: In any of these cases:

                * The prompt is empty, unless it is the encoder prompt of a
                  multi-modal model.
                * The prompt contains a token id above
                  ``tokenizer.max_token_id``.
                * The prompt is longer than ``max_model_len``, unless it is
                  the encoder prompt and the encoder-decoder multi-modal
                  processor pads a dummy encoder prompt.
        """
        model_config = self.model_config
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        prompt_ids = prompt_inputs["prompt_token_ids"]

        if not prompt_ids:
            # Mllama may have empty encoder inputs for text-only data.
            empty_allowed = (prompt_type == "encoder"
                             and model_config.is_multimodal_model)
            if not empty_allowed:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        max_input_id = max(prompt_ids, default=0)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens
        max_prompt_len = model_config.max_model_len
        if len(prompt_ids) <= max_prompt_len:
            return

        if prompt_type == "encoder" and model_config.is_multimodal_model:
            registry = self.input_preprocessor.mm_registry
            mm_processor = registry.create_processor(
                model_config,
                tokenizer=tokenizer,
            )
            assert isinstance(mm_processor, EncDecMultiModalProcessor)

            if mm_processor.pad_dummy_encoder_prompt:
                return  # Skip encoder length check for Whisper

        if not model_config.is_multimodal_model:
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens.")
        else:
            suggestion = (
                "Make sure that `max_model_len` is no smaller than the "
                "number of text tokens plus multimodal tokens. For image "
                "inputs, the number of image tokens depends on the number "
                "of images, and possibly their aspect ratios as well.")

        raise ValueError(
            f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
            f"longer than the maximum model length of {max_prompt_len}. "
            f"{suggestion}")