# SPDX-License-Identifier: Apache-2.0

3
import time
4
5
from collections.abc import Mapping
from typing import Optional, Union
6

7
from vllm.config import VllmConfig
8
from vllm.inputs import ProcessorInputs, PromptType
9
from vllm.inputs.parse import split_enc_dec_inputs
10
11
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
12
13
14
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
15
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
16
17
18
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
19
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
20
from vllm.v1.engine import EngineCoreRequest
21
22
23
24
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
    validate_structured_output_request_xgrammar)
25
26
27
28
29
30


class Processor:

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: BaseTokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Store the configuration sections and helpers used per request.

        Args:
            vllm_config: Engine configuration. Its model, cache, LoRA and
                decoding sub-configs are kept as instance attributes.
            tokenizer: Tokenizer group kept on the instance and passed to
                the input preprocessor.
            mm_registry: Multi-modal registry passed to the input
                preprocessor.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())
        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        # Multi-modal hasher (for images): hashes are requested from the
        # preprocessor unless the mm preprocessor cache is disabled, or
        # whenever prefix caching is enabled.
        self.use_hash = (
            not self.model_config.disable_mm_preprocessor_cache
            or self.cache_config.enable_prefix_caching)

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject logprob requests above the model's configured maximum.

        A falsy value (``None`` or ``0``) for either field skips the
        corresponding check.

        Args:
            params: Sampling parameters whose ``logprobs`` and
                ``prompt_logprobs`` are checked.

        Raises:
            ValueError: If either requested count exceeds
                ``model_config.max_logprobs``.
        """
        max_logprobs = self.model_config.max_logprobs
        # Validate sample logprobs.
        if params.logprobs and params.logprobs > max_logprobs:
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate structured-output settings and ``allowed_token_ids``.

        Args:
            params: Sampling parameters to check.

        Raises:
            ValueError: If structured-output validation fails, if
                ``allowed_token_ids`` is provided but empty, or if it
                contains an id outside ``[0, vocab_size)``.
        """
        self._validate_structured_output(params)

        # ``None`` means "no restriction"; nothing further to check.
        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
        vocab_size = self.model_config.get_vocab_size()
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features that V1 does not support.

        Args:
            params: Sampling parameters to check.

        Raises:
            ValueError: If ``best_of`` is greater than 1 or any per-request
                logits processors are supplied.
        """
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError("vLLM V1 does not support per request "
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
    ):
        """
        Run every per-request parameter check.

        Only SamplingParams are accepted; anything else raises ValueError
        so that the API server can report the request as unsupported.
        """
        if isinstance(params, SamplingParams):
            # Checks run in a fixed order; the first failure propagates.
            checks = (
                self._validate_logprobs,
                self._validate_sampling_params,
                self._validate_supported_sampling_params,
            )
            for check in checks:
                check(params)
            return

        raise ValueError("V1 does not yet support Pooling models.")

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        """Reject a LoRA request when no LoRA config is set.

        Args:
            lora_request: The request's LoRA adapter, or ``None``.

        Raises:
            ValueError: If ``lora_request`` is given while
                ``self.lora_config`` is falsy.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Validate a structured-output request and settle its backend.

        Does nothing when the request has no ``guided_decoding`` or the
        engine has no decoding config. Otherwise it checks that the
        engine-level backend is supported, that any request-level backend
        matches it, and that the platform supports structured output, and
        then validates the request content against the chosen backend.
        ``params.guided_decoding.backend`` is updated in place.

        Args:
            params: Sampling parameters carrying the ``guided_decoding``
                settings; mutated in place.

        Raises:
            ValueError: If the engine backend is unsupported, the request
                backend differs from the engine backend, or the platform
                does not support structured output. The content validators
                may also raise (the "auto" branch relies on the xgrammar
                validator raising ValueError for unsupported requests).
        """
        if not params.guided_decoding or not self.decoding_config:
            return

        supported_backends = [
            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
            "guidance:disable-any-whitespace", "auto"
        ]
        engine_level_backend = self.decoding_config.guided_decoding_backend
        if engine_level_backend not in supported_backends:
            raise ValueError(f"Only {supported_backends} structured output is "
                             "supported in V1.")
        if params.guided_decoding.backend:
            if params.guided_decoding.backend != engine_level_backend:
                raise ValueError("Request-level structured output backend "
                                 "must match engine-level backend. "
                                 f"{params.guided_decoding.backend}"
                                 f" != {engine_level_backend}")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Imported lazily, as in the original code, rather than at module
        # import time.
        from vllm.platforms import current_platform
        if not current_platform.supports_structured_output():
            raise ValueError("Structured output is not supported on "
                             f"{current_platform.device_name}.")

        # Request content validation
        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_structured_output_request_xgrammar(params)
            params.guided_decoding.backend = engine_level_backend
        elif engine_level_backend.startswith("guidance"):
            # TODO ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
            params.guided_decoding.backend = engine_level_backend
        elif engine_level_backend == "auto":
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_structured_output_request_xgrammar(params)
            except ValueError:
                # The request includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance,
                # and validate against it so that a request neither
                # backend accepts is rejected here instead of being
                # forwarded unvalidated.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            else:
                params.guided_decoding.backend = "xgrammar"

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> EngineCoreRequest:
        """Validate a request, preprocess its prompt and build the
        ``EngineCoreRequest`` for it.

        Args:
            request_id: Identifier copied into the returned request.
            prompt: Prompt handed to the input preprocessor.
            params: Request parameters; must be ``SamplingParams``. They
                are cloned, so the caller's object is not modified by the
                defaults applied here.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            lora_request: Optional LoRA adapter for this request.
            trace_headers: Must be ``None`` (tracing is rejected).
            prompt_adapter_request: Must be ``None`` (rejected).
            priority: Must be ``0`` (priorities are rejected).

        Returns:
            The request, with multi-modal inputs, hashes and placeholder
            positions ordered consistently when the prompt is multimodal.

        Raises:
            ValueError: If any validation step rejects the request.
            NotImplementedError: If preprocessing yields encoder inputs.
        """
        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.

        self._validate_lora(lora_request)
        self._validate_params(params)
        if priority != 0:
            raise ValueError("V1 does not support priority yet.")
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        # _validate_params() has already rejected anything else; this
        # narrows the type for the code below.
        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        if sampling_params.max_tokens is None:
            sampling_params.max_tokens = (
                self.model_config.max_model_len -
                len(decoder_inputs["prompt_token_ids"]))
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            (
                sorted_item_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs["mm_placeholders"],
                decoder_inputs["mm_hashes"] if self.use_hash else None,
            )

            # The output of merged multi-modal processor (`decoder_mm_inputs`)
            # is a single MultiModalKwargs for all items from all modalities.
            # This code flattens kwargs for individual items in a list and
            # sorts them by each item's position in the input sequence if there
            # are multiple modalities.
            unique_modalities = set(sorted_item_modalities)
            if len(unique_modalities) > 1:
                sorted_mm_inputs = []
                # Next unconsumed item index for each modality.
                used_indices = {modality: 0 for modality in unique_modalities}
                for modality in sorted_item_modalities:
                    items = decoder_mm_inputs.get_items(modality)
                    item = items[used_indices[modality]]
                    sorted_mm_inputs.append(
                        MultiModalKwargs.from_items([item]))
                    used_indices[modality] += 1
            else:
                # Single modality: the items are taken in the order
                # get_items() returns them.
                sorted_mm_inputs = [
                    MultiModalKwargs.from_items([item]) for item in
                    decoder_mm_inputs.get_items(sorted_item_modalities[0])
                ]

        return EngineCoreRequest(
            request_id=request_id,
            prompt=decoder_inputs.get("prompt"),
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        """Check that the processed prompt is non-empty, in-vocabulary and
        shorter than the model's maximum length.

        Args:
            inputs: Preprocessed inputs; split into encoder and decoder
                parts before checking.
            lora_request: Optional LoRA adapter, used to select the
                tokenizer whose ``max_token_id`` bounds the prompt ids.

        Raises:
            ValueError: If the prompt is empty, contains a token id above
                the tokenizer's ``max_token_id``, or its length is greater
                than or equal to ``model_config.max_model_len``.
        """
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        # For encoder-decoder multimodal models, the max_prompt_len
        # restricts the decoder prompt length
        if self.model_config.is_multimodal_model:
            prompt_inputs = decoder_inputs
        else:
            prompt_inputs = encoder_inputs or decoder_inputs

        prompt_ids = prompt_inputs["prompt_token_ids"]

        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        max_input_id = max(prompt_ids)
        max_allowed = self.tokenizer.get_lora_tokenizer(
            lora_request).max_token_id
        if max_input_id > max_allowed:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) >= max_prompt_len:
            # The multimodal explanation used to sit behind a separate
            # `>` check that could never fire, because this `>=` check had
            # already raised. Pick the message here so it is reachable.
            if self.model_config.is_multimodal_model:
                raise ValueError(
                    f"The prompt (total length {len(prompt_ids)}) is too long "
                    f"to fit into the model (context length {max_prompt_len}). "
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            raise ValueError(
                f"Prompt length of {len(prompt_ids)} is longer than the "
                f"maximum model length of {max_prompt_len}.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens