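# processor.py (14.5 KB) -- copied from a repository file viewer; its
# "Newer"/"Older" labels and line-number gutter leaked into this text.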
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
5
from collections.abc import Mapping
from typing import Optional, Union
6

7
from vllm.config import VllmConfig
8
9
10
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                         PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
11
12
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
13
14
15
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
16
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
17
18
19
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
20
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
21
from vllm.v1.engine import EngineCoreRequest
22
23
24
25
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
    validate_structured_output_request_xgrammar)
26
27
28
29
30
31


class Processor:
    """Validate user requests and convert them into ``EngineCoreRequest``s.

    For each request the processor:

    * rejects parameters and features the V1 engine does not support,
    * runs the ``InputPreprocessor`` (tokenization and, for multimodal
      models, multimodal preprocessing),
    * clones and fills in the ``SamplingParams`` (``max_tokens``,
      generation-config defaults, tokenizer-derived fields), and
    * flattens and sorts multimodal kwargs, placeholders and hashes by
      their position in the prompt.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: BaseTokenizerGroup,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Create a processor bound to one engine configuration.

        Args:
            vllm_config: Engine configuration; the model, cache, LoRA and
                decoding sub-configs are cached on ``self``.
            tokenizer: Tokenizer group used for preprocessing and for
                per-LoRA tokenizer lookups.
            input_registry: Accepted for interface compatibility; it is
                not referenced by this class.
            mm_registry: Multimodal registry handed to the
                ``InputPreprocessor``.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config

        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())
        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        # Multi-modal hasher (for images).
        # Hashes are requested from the preprocessor, and forwarded to the
        # engine, when the multimodal preprocessor cache is enabled or when
        # prefix caching is enabled.
        self.use_hash = (
            not self.model_config.disable_mm_preprocessor_cache) or \
            self.cache_config.enable_prefix_caching

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sample/prompt logprob counts above the model's limit.

        Raises:
            ValueError: If ``params.logprobs`` or ``params.prompt_logprobs``
                exceeds ``model_config.max_logprobs``.
        """
        max_logprobs = self.model_config.max_logprobs
        # Validate sample logprobs.
        if params.logprobs and params.logprobs > max_logprobs:
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate structured-output settings and ``allowed_token_ids``.

        Raises:
            ValueError: If structured output is misconfigured, or if
                ``allowed_token_ids`` is empty or contains ids outside
                ``[0, vocab_size)``.
        """
        self._validate_structured_output(params)

        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is not None and empty!")
        vocab_size = self.model_config.get_vocab_size()
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features the V1 engine does not implement.

        Raises:
            ValueError: If ``best_of > 1`` or per-request logits processors
                are supplied.
        """
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError("vLLM V1 does not support per request "
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
    ):
        """
        Validate supported SamplingParam.
        Should raise ValueError if unsupported for API Server.
        """

        if not isinstance(params, SamplingParams):
            raise ValueError("V1 does not yet support Pooling models.")

        self._validate_logprobs(params)
        self._validate_sampling_params(params)
        self._validate_supported_sampling_params(params)

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        """Raise ``ValueError`` if a LoRA request arrives without LoRA enabled."""
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        """Resolve and validate the structured-output backend for a request.

        This method mutates ``params.guided_decoding.backend`` in place so
        that it names the concrete backend the request will use.

        Raises:
            ValueError: If the engine-level backend is unsupported, if the
                request-level backend differs from the engine-level one, or
                if the backend-specific validators reject the request.
        """
        if not params.guided_decoding or not self.decoding_config:
            return

        supported_backends = [
            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
        ]
        engine_level_backend = self.decoding_config.guided_decoding_backend
        if engine_level_backend not in supported_backends:
            raise ValueError(f"Only {supported_backends} structured output is "
                             "supported in V1.")
        if params.guided_decoding.backend:
            if params.guided_decoding.backend != engine_level_backend:
                raise ValueError("Request-level structured output backend "
                                 "must match engine-level backend. "
                                 f"{params.guided_decoding.backend}"
                                 f" != {engine_level_backend}")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback. This covers option-suffixed variants
            # such as "xgrammar:disable-any-whitespace", which previously
            # skipped validation entirely. The full backend string is kept
            # so that its options are preserved.
            validate_structured_output_request_xgrammar(params)
            params.guided_decoding.backend = engine_level_backend
        elif engine_level_backend == "auto":
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_structured_output_request_xgrammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request includes some jsonschema feature(s) that
                # are not supported in xgrammar. Fall back to guidance.
                params.guided_decoding.backend = "guidance"

        if params.guided_decoding.backend == "guidance":
            # TODO ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> EngineCoreRequest:
        """Validate one request and build the ``EngineCoreRequest`` for it.

        Args:
            request_id: Identifier copied into the resulting request.
            prompt: Raw prompt passed to the input preprocessor.
            params: Must be ``SamplingParams``; pooling is rejected.
            arrival_time: Defaults to ``time.time()`` when omitted.
            lora_request: Optional LoRA adapter; requires LoRA to be enabled.
            trace_headers: Must be ``None`` (tracing unsupported in V1).
            prompt_adapter_request: Must be ``None`` (unsupported in V1).
            priority: Must be ``0`` (priority unsupported in V1).

        Returns:
            An ``EngineCoreRequest`` carrying the prompt tokens, a cloned and
            completed ``SamplingParams``, and position-sorted multimodal
            inputs, placeholders and hashes (``None`` when no multimodal
            data is present).

        Raises:
            ValueError: For unsupported or invalid parameters and inputs.
            NotImplementedError: For encoder-decoder inputs.
        """

        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.

        self._validate_lora(lora_request)
        self._validate_params(params)
        if priority != 0:
            raise ValueError("V1 does not support priority yet.")
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")

        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = SingletonInputsAdapter(
                processed_inputs["decoder"])
            encoder_inputs = SingletonInputsAdapter(
                processed_inputs["encoder"])
        else:
            decoder_inputs = SingletonInputsAdapter(processed_inputs)
            encoder_inputs = None

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        # Guaranteed by _validate_params above.
        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        # _validate_model_inputs ensured the prompt is shorter than
        # max_model_len, so this is at least 1.
        if sampling_params.max_tokens is None:
            sampling_params.max_tokens = (self.model_config.max_model_len -
                                          len(decoder_inputs.prompt_token_ids))
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if (decoder_mm_inputs := decoder_inputs.multi_modal_data):
            assert isinstance(decoder_mm_inputs, MultiModalKwargs)

            # The output of merged multi-modal processor (`decoder_mm_inputs`)
            # contains the kwargs for all items from all modalities.
            # This code separates them so that there is one set of kwargs
            # per item per modality.
            individual_mm_inputs = [
                MultiModalKwargs.from_items([item])
                for modality in decoder_mm_inputs.modalities
                for item in decoder_mm_inputs.get_items(modality)
            ]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            # NOTE: interleaved modalities are not supported.
            (
                sorted_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs.multi_modal_placeholders,
                decoder_inputs.multi_modal_hashes if self.use_hash else None,
            )

            # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
            # modalities involved.
            if len(sorted_modalities) > 1:
                modality_order_dict = {
                    modality: order
                    for order, modality in enumerate(sorted_modalities)
                }

                # Sanity check to make sure each multimodal input has only one
                # modality key.
                for mm_input in individual_mm_inputs:
                    assert len(mm_input.modalities) == 1

                # Sort MultiModalKwargs to match sorted_mm_positions
                sorted_mm_inputs = sorted(
                    individual_mm_inputs,
                    key=lambda mm_input: modality_order_dict[list(
                        mm_input.modalities)[0]])
            else:
                sorted_mm_inputs = individual_mm_inputs

        return EngineCoreRequest(
            request_id=request_id,
            prompt=decoder_inputs.prompt,
            prompt_token_ids=decoder_inputs.prompt_token_ids,
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        """Check that the preprocessed prompt is non-empty, in-vocab and fits.

        Raises:
            ValueError: If the prompt is empty, contains a token id above the
                (LoRA) tokenizer's ``max_token_id``, or is not strictly
                shorter than ``max_model_len``. Multimodal models get a
                message that explains multimodal token expansion.
        """
        if is_encoder_decoder_inputs(inputs):
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            prompt_inputs = inputs["decoder" if self.model_config.
                                   is_multimodal_model else "encoder"]
        else:
            prompt_inputs = inputs

        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids

        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        max_input_id = max(prompt_ids)
        max_allowed = self.tokenizer.get_lora_tokenizer(
            lora_request).max_token_id
        if max_input_id > max_allowed:
            raise ValueError(
                "Token id {} is out of vocabulary".format(max_input_id))

        # The prompt must be strictly shorter than max_model_len so that at
        # least one token can be generated. The multimodal-specific message
        # used to sit behind a separate `>` check that could never fire,
        # because this `>=` check raised first.
        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) >= max_prompt_len:
            if self.model_config.is_multimodal_model:
                raise ValueError(
                    f"The prompt (total length {len(prompt_ids)}) is too long "
                    f"to fit into the model (context length {max_prompt_len}). "
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            raise ValueError(
                f"Prompt length of {len(prompt_ids)} is longer than the "
                f"maximum model length of {self.model_config.max_model_len}.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens