processor.py 12.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
5
from collections.abc import Mapping
from typing import Optional, Union
6

7
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
8
9
10
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                         PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
11
12
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
13
14
15
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
                             MultiModalKwargs, MultiModalRegistry)
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
16
17
18
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
19
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
20
from vllm.v1.engine import EngineCoreRequest
21
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
22
23
24
25
26
27
28


class Processor:

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        tokenizer: BaseTokenizerGroup,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Build the helpers that turn user prompts into engine requests.

        Args:
            model_config: Source of the generation config, the input
                processor, the multimodal input cache client and the
                ``disable_mm_preprocessor_cache`` flag.
            cache_config: Its ``enable_prefix_caching`` flag is checked when
                validating logprobs and when deciding whether to hash
                multimodal inputs.
            lora_config: Stored as-is; a falsy value makes ``_validate_lora``
                reject LoRA requests.
            tokenizer: Handed to the ``InputPreprocessor``.
            input_registry: Creates the (legacy) input processor.
            mm_registry: Multimodal registry handed to the preprocessor.
        """
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(
            model_config, self.tokenizer, mm_registry)
        self.input_processor = input_registry.create_input_processor(
            model_config)

        # Multi-modal (huggingface) input mapper
        self.mm_input_cache_client = MMInputCacheClient(model_config)

        # Multimodal inputs are hashed whenever the preprocessor cache is
        # enabled or prefix caching is enabled.
        preprocessor_cache_enabled = (
            not model_config.disable_mm_preprocessor_cache)
        self.use_hash = (preprocessor_cache_enabled
                         or cache_config.enable_prefix_caching)

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        """Check the logprob settings of a request against V1 limits.

        Raises:
            ValueError: if ``logprobs`` or ``prompt_logprobs`` exceeds the
                model's ``max_logprobs``, or if prompt logprobs are requested
                while prefix caching is enabled.
        """
        limit = self.model_config.max_logprobs

        sample_k = params.logprobs
        if sample_k and sample_k > limit:
            raise ValueError(
                f"Requested sample logprobs of {sample_k}, "
                f"which is greater than max allowed: {limit}")

        prompt_k = params.prompt_logprobs
        if prompt_k and prompt_k > limit:
            raise ValueError(
                f"Requested prompt logprobs of {prompt_k}, "
                f"which is greater than max allowed: {limit}")

        # TODO(andy): enable this in follow up by recomputing.
        # Note: any non-None value (including 0) conflicts with prefix
        # caching here.
        if prompt_k is not None and self.cache_config.enable_prefix_caching:
            raise ValueError("Prefix caching with prompt logprobs not yet "
                             "supported on VLLM V1.")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Check ``allowed_token_ids`` when the request restricts the vocab.

        Nothing is checked when ``allowed_token_ids`` is ``None``.

        Raises:
            ValueError: if ``allowed_token_ids`` is set but empty, or if any
                id falls outside ``[0, vocab_size)``.
        """
        allowed = params.allowed_token_ids
        if allowed is None:
            return

        if not allowed:
            raise ValueError("allowed_token_ids is not None and empty!")

        vocab_size = self.model_config.get_vocab_size()
        if any(not 0 <= token_id < vocab_size for token_id in allowed):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        """Reject sampling features that VLLM V1 does not implement yet.

        Raises:
            ValueError: if ``best_of`` requests more than one candidate, or
                if ``bad_words`` or per-request ``logits_processors`` are set.
        """
        # Best of not yet supported. ``best_of=1`` asks for a single
        # candidate, i.e. the same as not using best_of at all, so only
        # larger values are rejected.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("VLLM V1 does not yet support best_of.")
        # Bad words not yet supported.
        if params.bad_words:
            raise ValueError("VLLM V1 does not yet support bad_words.")
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError("VLLM V1 does not support per request "
                             "user provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
    ):
        """Check that V1 can serve a request with these parameters.

        Only ``SamplingParams`` are accepted. They are then run through the
        logprob, allowed-token-id and unsupported-feature checks, in that
        order.

        Raises:
            ValueError: for anything that is not ``SamplingParams`` (e.g.
                ``PoolingParams``), or when one of the sampling checks fails.
        """
        if isinstance(params, SamplingParams):
            for check in (self._validate_logprobs,
                          self._validate_sampling_params,
                          self._validate_supported_sampling_params):
                check(params)
            return
        raise ValueError("V1 does not yet support Pooling models.")

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        """Fail fast when a LoRA request arrives but LoRA is not enabled.

        Raises:
            ValueError: if ``lora_request`` is given while ``lora_config``
                is falsy.
        """
        if lora_request is None or self.lora_config:
            return
        raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                         "not enabled!")

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> EngineCoreRequest:
        """Validate a request and convert it into an ``EngineCoreRequest``.

        Args:
            request_id: Forwarded to the preprocessor and stored on the
                returned request.
            prompt: The raw prompt. Its multimodal data is hashed directly
                when hashing is enabled and the processed inputs carry no
                hashes.
            params: Must be ``SamplingParams``; ``PoolingParams`` is rejected.
            arrival_time: Defaults to ``time.time()`` when omitted.
            lora_request: Optional LoRA adapter; LoRA must be enabled.
            trace_headers: Must be ``None`` (tracing is not supported).
            prompt_adapter_request: Must be ``None`` (not supported).
            priority: Must be ``0`` (priority is not supported).

        Returns:
            A request holding a clone of ``params`` updated from the model's
            generation config. Multimodal inputs, hashes and placeholders are
            flattened into lists sorted by position in the input sequence,
            or are ``None`` when the prompt has no multimodal placeholders.

        Raises:
            ValueError: on any unsupported or invalid request feature.
            NotImplementedError: for encoder-decoder inputs.
        """
        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.

        # Reject everything V1 cannot serve before doing any real work; the
        # order of these checks decides which error the caller sees first.
        self._validate_lora(lora_request)
        self._validate_params(params)
        unsupported_features = (
            (priority != 0, "V1 does not support priority yet."),
            (trace_headers is not None, "V1 does not support tracing yet."),
            (prompt_adapter_request is not None,
             "V1 does not support prompt_adapter_request."),
        )
        for requested, message in unsupported_features:
            if requested:
                raise ValueError(message)

        if arrival_time is None:
            arrival_time = time.time()

        # Preprocessing covers:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #   multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        # Only applicable to multimodal models with a legacy input processor.
        processed_inputs = self.input_processor(preprocessed_inputs)
        self._validate_model_inputs(processed_inputs)

        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = SingletonInputsAdapter(
                processed_inputs["decoder"])
            encoder_inputs = SingletonInputsAdapter(
                processed_inputs["encoder"])
        else:
            decoder_inputs = SingletonInputsAdapter(processed_inputs)
            encoder_inputs = None

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        # `_validate_params` has already rejected non-sampling params.
        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case
        sampling_params = params.clone()
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)

        # Prefer hashes produced by a merged input processor; otherwise fall
        # back to hashing the prompt's multimodal data directly.
        mm_hashes = None
        if self.use_hash:
            mm_hashes = (decoder_inputs.multi_modal_hashes
                         or MultiModalHasher.hash_prompt_mm_data(prompt))

        # A merged processor already yields MultiModalKwargs holding every
        # item of every modality; split them into one kwargs per item.
        decoder_mm_data = decoder_inputs.multi_modal_data
        precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None
        if isinstance(decoder_mm_data, MultiModalKwargs):
            precomputed_mm_inputs = [
                MultiModalKwargs.from_items([item])
                for modality in decoder_mm_data.modalities
                for item in decoder_mm_data.get_items(modality)
            ]

        sorted_mm_inputs = None
        sorted_mm_hashes = None
        sorted_mm_positions = None
        mm_positions = decoder_inputs.multi_modal_placeholders
        if mm_positions:
            # Flatten the per-modality placeholders (and hashes) into lists
            # ordered by each item's position in the input sequence.
            # NOTE: interleaved modalities are not supported.
            (sorted_modalities, sorted_mm_positions,
             sorted_mm_hashes) = merge_and_sort_multimodal_metadata(
                 mm_positions, mm_hashes)

            # Reorder the precomputed kwargs to match sorted_mm_positions;
            # only done when several modalities are involved.
            if len(sorted_modalities) > 1 and precomputed_mm_inputs:
                modality_rank = {
                    modality: rank
                    for rank, modality in enumerate(sorted_modalities)
                }
                # Each split kwargs must belong to exactly one modality.
                assert all(
                    len(kwargs.modalities) == 1
                    for kwargs in precomputed_mm_inputs)
                # `sorted` is stable, so items keep their relative order
                # within a modality.
                precomputed_mm_inputs = sorted(
                    precomputed_mm_inputs,
                    key=lambda kwargs: modality_rank[list(
                        kwargs.modalities)[0]])

            # Update the mm input cache and run the legacy input mapper, if
            # one exists.
            sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
                mm_data=decoder_mm_data,
                mm_hashes=sorted_mm_hashes,
                mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
                precomputed_mm_inputs=precomputed_mm_inputs,
            )

        return EngineCoreRequest(
            request_id=request_id,
            prompt=decoder_inputs.prompt,
            prompt_token_ids=decoder_inputs.prompt_token_ids,
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    def _validate_model_inputs(self, inputs: ProcessorInputs):
        """Ensure the prompt that will be fed to the model is usable.

        For encoder-decoder inputs, the decoder prompt is checked for
        multimodal models and the encoder prompt otherwise.

        Raises:
            ValueError: if the prompt has no token ids, or if its length is
                not strictly smaller than ``max_model_len``. Multimodal
                models get a more detailed message, because their length
                includes multimodal placeholder tokens.
        """
        if is_encoder_decoder_inputs(inputs):
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            prompt_inputs = inputs["decoder" if self.model_config.
                                   is_multimodal_model else "encoder"]
        else:
            prompt_inputs = inputs

        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids

        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        max_prompt_len = self.model_config.max_model_len
        # A single `>=` check, with the message chosen by model type. A
        # separate `>` check for multimodal models placed after this one
        # could never fire, which would hide the more helpful message.
        if len(prompt_ids) >= max_prompt_len:
            if self.model_config.is_multimodal_model:
                raise ValueError(
                    f"The prompt (total length {len(prompt_ids)}) is too long "
                    f"to fit into the model (context length {max_prompt_len}). "
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            raise ValueError(
                f"Prompt length of {len(prompt_ids)} is longer than the "
                f"maximum model length of {max_prompt_len}.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens