preprocess.py 18.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from typing import Any, overload
6
7
8

from typing_extensions import assert_never

9
from vllm.config import ModelConfig, ObservabilityConfig
10
from vllm.logger import init_logger
11
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
12
from vllm.multimodal.cache import BaseMultiModalProcessorCache
13
14
15
16
17
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalInputs,
    MultiModalUUIDDict,
)
18
from vllm.multimodal.processing import BaseMultiModalProcessor
19
from vllm.renderers import renderer_from_config
20
21
22
23
24
25
26
27
28
29
from vllm.renderers.inputs import (
    DecoderDictPrompt,
    DecoderOnlyDictPrompt,
    DictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDictPrompt,
    SingletonDictPrompt,
    TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
30
from vllm.tokenizers import TokenizerLike
31
from vllm.utils.jsontree import json_iter_leaves
32
from vllm.v1.metrics.stats import MultiModalCacheStats
33

34
from .data import (
35
    DecoderInputs,
36
37
38
39
    DecoderOnlyInputs,
    EmbedsInputs,
    EmbedsPrompt,
    EncoderDecoderInputs,
40
    EncoderInputs,
41
42
43
44
45
46
47
48
49
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    TextPrompt,
    TokenInputs,
    TokensPrompt,
    embeds_inputs,
    token_inputs,
)
50
51
52
53
54
55
56

logger = init_logger(__name__)


class InputPreprocessor:
    def __init__(
        self,
        model_config: ModelConfig,
        observability_config: ObservabilityConfig | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        mm_processor_cache: BaseMultiModalProcessorCache | None = None,
    ) -> None:
        """
        Store the configuration objects and build the renderer used for
        tokenization.

        Arguments:

        * model_config: model configuration; also passed to
          `renderer_from_config` to create `self.renderer`
        * observability_config: optional observability configuration,
          forwarded to the multi-modal processor when it is created
        * mm_registry: registry used to create the multi-modal processor
        * mm_processor_cache: optional multi-modal processor cache; cache
          statistics are only tracked when a cache is supplied
        """
        super().__init__()

        self.model_config = model_config
        self.observability_config = observability_config
        self.renderer = renderer_from_config(model_config)
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_processor_cache

        # Compare against `None` explicitly (as `clear_mm_cache` does): a
        # truthiness test would disable statistics for any cache object that
        # happens to evaluate as falsy, e.g. one defining `__len__`.
        self.mm_cache_stats = (
            MultiModalCacheStats() if mm_processor_cache is not None else None
        )

72
73
74
    @property
    def tokenizer(self) -> TokenizerLike | None:
        return self.renderer.tokenizer
75

76
77
    def get_tokenizer(self) -> TokenizerLike:
        return self.renderer.get_tokenizer()
78

79
    def get_bos_token_id(self) -> int | None:
80
        if self.tokenizer is None:
81
            logger.warning_once(
82
83
                "Using None for BOS token id because tokenizer is not initialized"
            )
84
85
            return None

86
        return self.tokenizer.bos_token_id
87

88
    def get_eos_token_id(self) -> int | None:
89
        if self.tokenizer is None:
90
            logger.warning_once(
91
92
                "Using None for EOS token id because tokenizer is not initialized"
            )
93
94
            return None

95
        return self.tokenizer.eos_token_id
96

97
    def get_decoder_start_token_id(self) -> int:
98
        """
99
        Obtain the decoder start token id employed by an encoder/decoder
100
        model. Raises an error if it is not available.
101
        """
102
103
104
        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )
105

106
        if dec_start_token_id is None:
107
108
109
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
110
111
                "available."
            )
112
113
            dec_start_token_id = self.get_bos_token_id()

114
115
        if dec_start_token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")
116

117
        return dec_start_token_id
118

119
    def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
120
121
122
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

123
124
125
126
        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
127
128
129
130
131
132
133
134
135
136
137

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """
        decoder_start_token_id = self.get_decoder_start_token_id()

138
139
140
141
        if (
            len(decoder_input_ids) == 0
            or decoder_input_ids[0] != decoder_start_token_id
        ):
142
143
144
145
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

146
147
    def _get_tokenization_kw(
        self,
148
        overrides: dict[str, Any] | None = None,
149
150
151
    ) -> dict[str, Any]:
        kwargs = dict[str, Any]()

152
        if self.model_config.is_encoder_decoder:
153
154
155
156
157
158
159
160
161
162
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if overrides:
            kwargs.update(overrides)

        return kwargs

163
164
165
    def _tokenize_prompt(
        self,
        prompt: str,
166
        tokenization_kwargs: dict[str, Any] | None = None,
167
    ) -> list[int]:
168
169
170
171
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
172
        tokenizer = self.get_tokenizer()
173
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
174

175
        encoder_config = self.model_config.encoder_config
176

177
        if encoder_config and encoder_config.get("do_lower_case", False):
178
179
            prompt = prompt.lower()

180
        return tokenizer.encode(prompt, **tokenization_kwargs)
181

182
183
184
    def _get_mm_processor(self) -> BaseMultiModalProcessor:
        if not hasattr(self, "_mm_processor"):
            self._mm_processor = self.mm_registry.create_processor(
185
                self.model_config,
186
                self.observability_config,
187
                tokenizer=self.tokenizer,
188
189
190
191
                cache=self.mm_processor_cache,
            )

        return self._mm_processor
192

193
194
    def _process_multimodal(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Run the model's multi-modal processor on a multi-modal prompt and
        return the resulting token IDs and metadata.

        Arguments:

        * prompt: prompt text or prompt token ids
        * mm_data: multi-modal data accompanying the prompt
        * mm_processor_kwargs: kwargs for the HF processor; `None` is
          treated as an empty dict
        * tokenization_kwargs: kwargs forwarded to the processor's
          tokenization step
        * mm_uuids: optional UUIDs forwarded to the processor

        Returns:

        * The `MultiModalInputs` produced by the processor

        Raises:

        * ValueError: if any leaf of the returned `mm_hashes` is not a
          string
        """
        processor = self._get_mm_processor()
        hf_mm_kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs

        parsed_items = processor.info.parse_mm_data(mm_data)
        mm_input = processor.apply(
            prompt,
            parsed_items,
            hf_processor_mm_kwargs=hf_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # Validate that all mm items have a string as their hash
        mm_hashes = mm_input["mm_hashes"]
        for leaf in json_iter_leaves(mm_hashes):
            if not isinstance(leaf, str):
                raise ValueError(
                    f"mm_hashes must contain only strings, got: {mm_hashes}. "
                    "This is likely due to an incorrect custom implementation of "
                    "MultiModalProcessor.apply method."
                )

        return mm_input
233

234
235
236
237
    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
238
        if not self.model_config.enable_prompt_embeds:
239
240
241
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )
242
243

        prompt_embeds = parsed_content["prompt_embeds"]
244

245
246
247
248
249
250
251
252
        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
253
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
254

255
256
257
258
259
        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

260
261
262
        return embeds_inputs(
            prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
        )
263

264
    def _truncate_inputs(
265
        self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
266
267
268
269
270
271
    ) -> list[int]:
        if (
            not tokenization_kwargs
            or "truncation" not in tokenization_kwargs
            or self.tokenizer is None
        ):
272
273
274
275
276
277
278
279
280
            return inputs

        max_length = tokenization_kwargs["max_length"]

        if self.tokenizer.truncation_side == "left":
            return inputs[-max_length:]
        else:
            return inputs[:max_length]

281
282
283
    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
284
        tokenization_kwargs: dict[str, Any] | None = None,
285
        *,
286
287
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
288
        prompt_token_ids = self._truncate_inputs(
289
290
            parsed_content["prompt_token_ids"], tokenization_kwargs
        )
291

292
        inputs: TokenInputs | MultiModalInputs
293
        if multi_modal_data := parsed_content.get("multi_modal_data"):
294
295
            inputs = self._process_multimodal(
                prompt_token_ids,
296
                multi_modal_data,
297
                parsed_content.get("mm_processor_kwargs") or {},
298
                tokenization_kwargs=tokenization_kwargs,
299
                mm_uuids=mm_uuids,
300
            )
301
        else:
302
            inputs = token_inputs(prompt_token_ids)
303
304
305
306
307
308
309
310
311

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
312
        tokenization_kwargs: dict[str, Any] | None = None,
313
        *,
314
315
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
316
317
        prompt_text = parsed_content["prompt"]

318
        inputs: TokenInputs | MultiModalInputs
319
        if multi_modal_data := parsed_content.get("multi_modal_data"):
320
321
            inputs = self._process_multimodal(
                prompt_text,
322
                multi_modal_data,
323
                parsed_content.get("mm_processor_kwargs") or {},
324
                tokenization_kwargs=tokenization_kwargs,
325
                mm_uuids=mm_uuids,
326
327
328
329
330
331
            )
        else:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )
332
            inputs = token_inputs(prompt_token_ids)
333
334
335
336
337

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs
338

339
    @overload
340
    def _prompt_to_llm_inputs(
341
        self,
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
        prompt: EncoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs: ...

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonDictPrompt,
369
        tokenization_kwargs: dict[str, Any] | None = None,
370
        *,
371
        mm_uuids: MultiModalUUIDDict | None = None,
372
    ) -> SingletonInputs:
373
374
        """
        Extract the singleton inputs from a prompt.
375
376
377

        Arguments:

378
        * prompt: single encoder or decoder input prompt
379
380
381

        Returns:

382
        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
383
        """
384
385
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]
386

387
        if "prompt_token_ids" in prompt:
388
            return self._process_tokens(
389
                prompt,  # type: ignore[arg-type]
390
                mm_uuids=mm_uuids,
391
            )
392
393

        if "prompt" in prompt:
394
            return self._process_text(
395
                prompt,  # type: ignore[arg-type]
396
                tokenization_kwargs=tokenization_kwargs,
397
                mm_uuids=mm_uuids,
398
            )
399

400
        assert_never(prompt)  # type: ignore[arg-type]
401

402
    def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
403
        if inputs["type"] == "embeds":
404
405
406
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )
407

408
409
410
411
        if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
            raise RuntimeError(
                "You should register an encoder-decoder "
                "multi-modal processor for encoder-decoder models."
412
            )
413

414
        return inputs  # type: ignore[return-value]
415

416
    def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
417
        if inputs["type"] == "embeds":
418
419
420
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )
421

422
        return inputs
423

424
425
426
427
428
429
    def _build_enc_dec_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: SingletonInputs | None = None,
    ) -> EncoderDecoderInputs:
        """
        Combine processed encoder and decoder inputs into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        Arguments:

        * encoder_inputs: processed encoder-side inputs
        * decoder_inputs: processed decoder-side inputs; when `None`, the
          validated encoder inputs are reused as the decoder inputs

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance whose decoder token ids start with the decoder start
          token
        """
        validated_enc = self._validate_enc_inputs(encoder_inputs)

        validated_dec: DecoderInputs
        if decoder_inputs is not None:
            validated_dec = self._validate_dec_inputs(decoder_inputs)
        else:
            validated_dec = validated_enc  # type: ignore[assignment]

        final_enc: EncoderInputs
        final_dec: DecoderInputs

        if validated_enc["type"] == "multimodal":
            # The encoder receives the processor's encoder token ids; the
            # multi-modal payload travels with the decoder inputs.
            final_enc = token_inputs(validated_enc["encoder_prompt_token_ids"])
            final_dec = MultiModalInputs(
                type="multimodal",
                prompt_token_ids=validated_dec["prompt_token_ids"],
                mm_kwargs=validated_enc["mm_kwargs"],
                mm_hashes=validated_enc["mm_hashes"],
                mm_placeholders=validated_enc["mm_placeholders"],
            )
        elif validated_enc["type"] == "token":
            # Token-only encoder inputs are replaced by an empty encoder
            # prompt; the decoder inputs are used as-is.
            final_enc = token_inputs(prompt_token_ids=[])
            final_dec = validated_dec
        else:
            assert_never(validated_enc)

        final_dec["prompt_token_ids"] = self._prepare_decoder_input_ids(
            final_dec["prompt_token_ids"]
        )

        cache_salt = validated_enc.get("cache_salt")
        if cache_salt:
            final_dec["cache_salt"] = cache_salt

        return EncoderDecoderInputs(encoder=final_enc, decoder=final_dec)
461

462
463
    def _process_encoder_decoder_prompt(
        self,
464
        prompt: EncoderDecoderDictPrompt,
465
        tokenization_kwargs: dict[str, Any] | None = None,
466
        *,
467
        mm_uuids: MultiModalUUIDDict | None = None,
468
    ) -> EncoderDecoderInputs:
469
        """
470
        For encoder/decoder models only:
471
472
473
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.
474
475
476

        Arguments:

477
        * prompt: an input prompt
478
479
480

        Returns:

481
482
        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
483
        """
484
485
        encoder_prompt = prompt["encoder_prompt"]
        decoder_prompt = prompt["decoder_prompt"]
486
487
488
489

        return self._build_enc_dec_inputs(
            encoder_inputs=self._prompt_to_llm_inputs(
                encoder_prompt,
490
                tokenization_kwargs=tokenization_kwargs,
491
                mm_uuids=mm_uuids,
492
493
494
495
496
497
498
            ),
            decoder_inputs=(
                None
                if decoder_prompt is None
                else self._prompt_to_llm_inputs(
                    decoder_prompt,
                    tokenization_kwargs=tokenization_kwargs,
499
                )
500
501
            ),
        )
502
503
504

    def _process_decoder_only_prompt(
        self,
505
        prompt: DecoderOnlyDictPrompt,
506
        tokenization_kwargs: dict[str, Any] | None = None,
507
        *,
508
        mm_uuids: MultiModalUUIDDict | None = None,
509
    ) -> DecoderOnlyInputs:
510
        """
511
        For decoder-only models:
512
513
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
514
515
516

        Arguments:

517
        * prompt: input prompt
518
519
520

        Returns:

521
        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
522
        """
523
        return self._prompt_to_llm_inputs(
524
            prompt,
525
            tokenization_kwargs=tokenization_kwargs,
526
            mm_uuids=mm_uuids,
527
528
        )

529
    def _preprocess(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Parse `prompt` and route it to the encoder-decoder or decoder-only
        processing path, depending on the model configuration.

        Arguments:

        * prompt: the input prompt
        * tokenization_kwargs: kwargs forwarded to prompt processing
        * mm_uuids: optional UUIDs forwarded to prompt processing

        Returns:

        * The processed inputs for the selected path
        """
        if not self.model_config.is_encoder_decoder:
            decoder_only_prompt = parse_dec_only_prompt(prompt)
            return self._process_decoder_only_prompt(
                decoder_only_prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder.
        enc_dec_prompt = parse_enc_dec_prompt(prompt)
        return self._process_encoder_decoder_prompt(
            enc_dec_prompt,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

551
552
    def preprocess(
        self,
553
        prompt: PromptType | DictPrompt | TokPrompt,
554
        tokenization_kwargs: dict[str, Any] | None = None,
555
        *,
556
        mm_uuids: MultiModalUUIDDict | None = None,
557
558
    ) -> ProcessorInputs:
        """Preprocess the input prompt."""
559
        res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
560
561
562
563
564
565
566
567
568

        if self.mm_processor_cache and self.mm_cache_stats is not None:
            delta = self.mm_processor_cache.make_stats(delta=True)
            self.mm_cache_stats.requests += 1
            self.mm_cache_stats.queries += delta.total
            self.mm_cache_stats.hits += delta.hits

        return res

569
    def stat_mm_cache(self) -> MultiModalCacheStats | None:
570
571
572
573
574
575
576
577
578
        mm_cache_stats = self.mm_cache_stats
        if mm_cache_stats is None:
            return None

        self.mm_cache_stats = MultiModalCacheStats()

        return mm_cache_stats

    def clear_mm_cache(self) -> None:
579
580
        if self.mm_processor_cache is not None:
            self.mm_processor_cache.clear_cache()
581
582
583

        if self.mm_cache_stats is not None:
            self.mm_cache_stats.reset = True