processor.py 17.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping
6
from typing import Any, Literal, Optional, Union
7

8
from vllm.config import VllmConfig
9
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
10
from vllm.inputs.parse import split_enc_dec_inputs
11
12
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
13
14
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
15
from vllm.multimodal.processing import EncDecMultiModalProcessor
16
from vllm.multimodal.utils import argsort_mm_positions
17
18
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
19
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
20
from vllm.v1.engine import EngineCoreRequest
21
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
22
23
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
24
25
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)
26
27
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
28
29
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)
30
31
32
33
34
35


class Processor:
    """Validate incoming requests and convert them to ``EngineCoreRequest``.

    Holds references to the relevant sub-configs of a ``VllmConfig``, the
    tokenizer group, an ``InputPreprocessor`` and a multimodal input cache
    client, and exposes ``process_inputs`` as the entry point.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: Optional[TokenizerGroup],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Create a processor.

        Args:
            vllm_config: Source of the model, cache, LoRA, decoding and
                parallel configs read by this class.
            tokenizer: Tokenizer group used for validation and preprocessing.
                May be ``None``; several methods below explicitly skip
                tokenizer-dependent work in that case.
            mm_registry: Multimodal registry shared by the input preprocessor
                and the multimodal input cache client.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config

        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())
        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.mm_input_cache_client = MultiModalInputCacheClient(
            self.model_config, mm_registry)

    @property
    def mm_registry(self):
        """The multimodal registry used by the input preprocessor."""
        return self.input_preprocessor.mm_registry

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Check requested logprob counts against ``max_logprobs``.

        A configured ``max_logprobs`` of -1 disables the check. Otherwise a
        requested value of -1 is rejected for both sample and prompt
        logprobs, as is any value above the limit.

        Raises:
            ValueError: If a requested count is -1 or exceeds the limit.
        """
        max_logprobs = self.model_config.max_logprobs
        if max_logprobs == -1:
            return

        # Validate sample logprobs.
        if params.logprobs and (params.logprobs == -1
                                or params.logprobs > max_logprobs):
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs. -1 has to be rejected explicitly (as for
        # sample logprobs above): it is truthy but never compares greater
        # than a finite limit, so a plain `>` test would let it through.
        prompt_logprobs = params.prompt_logprobs
        if prompt_logprobs and (prompt_logprobs == -1
                                or prompt_logprobs > max_logprobs):
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Validate structured output, logit bias and allowed token ids.

        ``allowed_token_ids`` is checked against the (LoRA-specific)
        tokenizer's vocabulary size, which is only possible when a tokenizer
        is available.

        Raises:
            ValueError: If ``allowed_token_ids`` is empty or contains an id
                outside ``[0, len(tokenizer))``, or if one of the delegated
                validators rejects the request.
        """
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
        if self.tokenizer is None:
            # When skip_tokenizer_init=True, we can't validate token IDs
            # Skip validation and let the model handle invalid tokens
            return
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        vocab_size = len(tokenizer)
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range.

        Unlike ``allowed_token_ids``, the range comes from the model config's
        vocab size, so no tokenizer is needed.

        Raises:
            ValueError: Listing every out-of-range token id in ``logit_bias``.
        """
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = [
            token_id for token_id in params.logit_bias
            if not 0 <= token_id < vocab_size
        ]

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features that the V1 engine does not implement.

        Raises:
            ValueError: If ``best_of > 1`` or per-request logits processors
                are supplied.
        """
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError("vLLM V1 does not support per request "
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ):
        """
        Validate request parameters before any input processing happens.

        ``PoolingParams`` are accepted without checks here; ``SamplingParams``
        go through the logprob, sampling and V1-support validators.
        Should raise ValueError if unsupported for API Server.
        """
        if isinstance(params, PoolingParams):
            return

        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        """Raise ValueError if a LoRA request is given but LoRA is disabled."""
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate guided-decoding settings and resolve the backend.

        Mutates ``params.guided_decoding``: an unset ``backend`` is filled in
        from the engine-level decoding config, and in "auto" mode the chosen
        backend ("xgrammar" or "guidance") is stored together with
        ``backend_was_auto = True``.

        Raises:
            ValueError: If no tokenizer is available (``skip_tokenizer_init``),
                the request selects a backend different from the engine's,
                ``choice`` is an empty list, or the backend validator rejects
                the request.
        """
        if not params.guided_decoding or not self.decoding_config:
            return

        # guided_decoding is known to be set past the early return above.
        if self.model_config.skip_tokenizer_init:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        engine_level_backend = self.decoding_config.backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_auto` option set on the backend in the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if (isinstance(params.guided_decoding.choice, list)
                and not params.guided_decoding.choice):
            # It is invalid for choice to be an empty list
            raise ValueError(f"Choice '{params.guided_decoding.choice}' "
                             "cannot be an empty list")

        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        elif engine_level_backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        elif engine_level_backend == "lm-format-enforcer":
            # lm format enforcer backend
            validate_structured_output_request_lm_format_enforcer(params)
        else:
            # NOTE(review): every backend name not matched above lands here
            # and is treated as "auto". This method does not itself verify
            # that; presumably unsupported values are rejected when the
            # decoding config is built -- TODO confirm.
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_xgrammar_grammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            # Remember that this backend was set automatically
            params.guided_decoding.backend_was_auto = True

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate a request and build the corresponding EngineCoreRequest.

        Returns:
            A tuple of the decoder inputs' ``"prompt"`` entry (``None`` if
            absent) and the ``EngineCoreRequest`` to send to the engine core.
            ``params`` is cloned, so the caller's object is not mutated by
            the max-token / generation-config / tokenizer updates.

        Raises:
            ValueError: On invalid LoRA, sampling params, tracing headers,
                data-parallel rank, or model inputs.
            NotImplementedError: If preprocessing yields encoder inputs.
        """
        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
        )
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            # (The prompt length was checked against max_model_len above.)
            if sampling_params.max_tokens is None:
                sampling_params.max_tokens = (
                    self.model_config.max_model_len -
                    len(decoder_inputs["prompt_token_ids"]))
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(
                    self.tokenizer.get_lora_tokenizer(lora_request))
        else:
            pooling_params = params.clone()

        # Multimodal related.
        sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            orig_sorted_mm_inputs = [
                decoder_mm_inputs[modality][idx]
                for modality, idx in sorted_mm_idxs
            ]
            sorted_mm_positions = [
                decoder_mm_positions[modality][idx]
                for modality, idx in sorted_mm_idxs
            ]
            sorted_mm_hashes = [
                decoder_mm_hashes[modality][idx]
                for modality, idx in sorted_mm_idxs
            ]

            sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
                orig_sorted_mm_inputs,
                sorted_mm_hashes,
            )

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_kwargs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        """Validate the encoder (if any) and decoder parts of ``inputs``."""
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        """Validate one singleton prompt's token ids.

        Checks, in order:
        1. The prompt is non-empty. Multimodal encoder prompts are exempt.
        2. The largest token id does not exceed the tokenizer's
           ``max_token_id``. Skipped under ``skip_tokenizer_init``.
        3. The prompt fits in ``max_model_len``. Multimodal encoder prompts
           whose processor sets ``pad_dummy_encoder_prompt`` are exempt.

        Raises:
            ValueError: If any of the checks fails.
        """
        model_config = self.model_config

        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            max_input_id = max(prompt_ids, default=0)
            if max_input_id > tokenizer.max_token_id:
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper and Donut

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens