"vscode:/vscode.git/clone" did not exist on "7de18d541b0da661685d481d7306cbe5e9f7960b"
preprocess.py 18.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from typing import Any
6
7
8

from typing_extensions import assert_never

9
from vllm.config import ModelConfig, ObservabilityConfig
10
from vllm.inputs.parse import split_enc_dec_prompt
11
from vllm.logger import init_logger
12
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
13
from vllm.multimodal.cache import BaseMultiModalProcessorCache
14
15
16
17
18
19
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalEncDecInputs,
    MultiModalInputs,
    MultiModalUUIDDict,
)
20
from vllm.multimodal.processing import BaseMultiModalProcessor
21
from vllm.renderers import renderer_from_config
22
from vllm.tokenizers import TokenizerLike
23
from vllm.utils.jsontree import json_iter_leaves
24
from vllm.v1.metrics.stats import MultiModalCacheStats
25

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from .data import (
    DecoderOnlyInputs,
    EmbedsInputs,
    EmbedsPrompt,
    EncoderDecoderInputs,
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    SingletonPrompt,
    TextPrompt,
    TokenInputs,
    TokensPrompt,
    embeds_inputs,
    token_inputs,
)
41
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
42
43
44
45
46
47
48

logger = init_logger(__name__)


class InputPreprocessor:
    """
    Turns user prompts into the processed inputs consumed by the engine.

    Supports text, token, embedding and multi-modal prompts for both
    decoder-only and encoder-decoder models. When a multi-modal processor
    cache is supplied, cache statistics are accumulated on every
    `preprocess()` call and can be collected with `stat_mm_cache()`.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        observability_config: ObservabilityConfig | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        mm_processor_cache: BaseMultiModalProcessorCache | None = None,
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.observability_config = observability_config
        # Tokenizer access is delegated to the renderer (see `tokenizer`).
        self.renderer = renderer_from_config(model_config)
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_processor_cache

        # Statistics are only tracked when a processor cache is configured.
        self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's tokenizer, or `None` if it does not have one."""
        return self.renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """
        Return the renderer's tokenizer.

        Unlike the `tokenizer` property, the return type is non-optional;
        what happens when no tokenizer exists is decided by the renderer.
        """
        return self.renderer.get_tokenizer()

    def get_bos_token_id(self) -> int | None:
        """Return the tokenizer's BOS token id, or `None` without a tokenizer."""
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for BOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.bos_token_id

    def get_eos_token_id(self) -> int | None:
        """Return the tokenizer's EOS token id, or `None` without a tokenizer."""
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for EOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.eos_token_id

    def get_decoder_start_token_id(self) -> int:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model.

        The id is read from `hf_config.decoder_start_token_id`; if that is
        missing, the BOS token id is used instead.

        Raises:
            RuntimeError: if neither a decoder start token id nor a BOS
                token id is available.
        """
        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )

        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available."
            )
            dec_start_token_id = self.get_bos_token_id()

        if dec_start_token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")

        return dec_start_token_id

    def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list: the input ids, prefixed with the decoder
          start token id unless they already begin with it. A new list is
          returned when a prefix is added; the input list is not mutated.
        """
        decoder_start_token_id = self.get_decoder_start_token_id()

        if (
            len(decoder_input_ids) == 0
            or decoder_input_ids[0] != decoder_start_token_id
        ):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _get_tokenization_kw(
        self,
        overrides: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Build the keyword arguments passed to `tokenizer.encode`.

        Model-specific defaults are applied first; entries in `overrides`
        take precedence over them.
        """
        kwargs = dict[str, Any]()

        if self.model_config.is_encoder_decoder:
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if overrides:
            kwargs.update(overrides)

        return kwargs

    def _tokenize_prompt(
        self,
        prompt: str,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.

        The prompt is lower-cased first when the model's encoder config
        sets `do_lower_case`.
        """
        tokenizer = self.get_tokenizer()
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        encoder_config = self.model_config.encoder_config

        if encoder_config and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        return tokenizer.encode(prompt, **tokenization_kwargs)

    def _get_mm_processor(self) -> BaseMultiModalProcessor:
        """
        Return the multi-modal processor, creating it on first use.

        The processor is built lazily from the registry and then reused for
        the lifetime of this preprocessor.
        """
        if not hasattr(self, "_mm_processor"):
            self._mm_processor = self.mm_registry.create_processor(
                self.model_config,
                self.observability_config,
                tokenizer=self.tokenizer,
                cache=self.mm_processor_cache,
            )

        return self._mm_processor

    def _process_multimodal(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        Raises:
            ValueError: if the processor returns any non-string leaf in
                `mm_hashes`.
        """
        mm_processor = self._get_mm_processor()

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        mm_items = mm_processor.info.parse_mm_data(mm_data)
        mm_input = mm_processor.apply(
            prompt,
            mm_items,
            hf_processor_mm_kwargs=mm_processor_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        mm_hashes = mm_input["mm_hashes"]

        # Validate that all mm items have a string as their hash
        contains_only_strings = all(
            isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
        )
        if not contains_only_strings:
            raise ValueError(
                f"mm_hashes must contain only strings, got: {mm_hashes}. "
                "This is likely due to an incorrect custom implementation of "
                "MultiModalProcessor.apply method."
            )

        return mm_input

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Validate and normalize a prompt given as embeddings.

        Raises:
            ValueError: if prompt embeddings are not enabled for the model,
                or if `prompt_embeds` is not 2-D after squeezing a leading
                batch dimension of a 3-D input.
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        prompt_embeds = parsed_content["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

        return embeds_inputs(
            prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
        )

    def _truncate_inputs(
        self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
    ) -> list[int]:
        """
        Truncate already-tokenized inputs as requested by `tokenization_kwargs`.

        Truncation is applied only when a tokenizer is available (its
        `truncation_side` decides which end is dropped), `truncation` is
        enabled, and `max_length` is given. Otherwise `inputs` is returned
        unchanged.
        """
        if not tokenization_kwargs or self.tokenizer is None:
            return inputs

        # `truncation=False` / "do_not_truncate" explicitly disable truncation
        # in HF tokenizers; only the key's presence used to be checked here.
        truncation = tokenization_kwargs.get("truncation")
        if not truncation or truncation == "do_not_truncate":
            return inputs

        max_length = tokenization_kwargs.get("max_length")
        if max_length is None:
            return inputs

        if self.tokenizer.truncation_side == "left":
            # Avoid `inputs[-0:]`, which would keep everything for max_length=0.
            return inputs[max(len(inputs) - max_length, 0) :]
        else:
            return inputs[:max_length]

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Build inputs from a prompt that is already tokenized.

        The token ids are truncated per `tokenization_kwargs`, then routed
        through the multi-modal processor if the prompt carries
        `multi_modal_data`. A truthy `cache_salt` is copied onto the result.
        """
        prompt_token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs
        )

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        else:
            inputs = token_inputs(prompt_token_ids)

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Build inputs from a text prompt.

        Multi-modal prompts go through the multi-modal processor; plain text
        is tokenized with the model's tokenizer. A truthy `cache_salt` is
        copied onto the result.
        """
        prompt_text = parsed_content["prompt"]

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_text,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        else:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt_token_ids)

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: forwarded to text and token processing (the
          latter uses them for truncation)
        * mm_uuids: optional user-provided multi-modal item identifiers

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return self._process_embeds(parsed["content"])
        if parsed["type"] == "tokens":
            # Forward tokenization_kwargs so that `_process_tokens` can apply
            # truncation, matching the text-prompt path below.
            return self._process_tokens(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "text":
            return self._process_text(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "str":
            return self._process_text(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        assert_never(parsed)

    def _validate_enc_inputs(
        self,
        inputs: SingletonInputs,
    ) -> TokenInputs | MultiModalEncDecInputs:
        """
        Check that `inputs` can be used as encoder inputs.

        Raises:
            ValueError: for embedding inputs.
            RuntimeError: for multi-modal inputs lacking
                `encoder_prompt_token_ids`, i.e. produced by a processor that
                is not encoder-decoder aware.
        """
        if inputs["type"] == "embeds":
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )

        if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
            raise RuntimeError(
                "You should register an encoder-decoder "
                "multi-modal processor for encoder-decoder models."
            )

        return inputs  # type: ignore[return-value]

    def _validate_dec_inputs(
        self,
        inputs: SingletonInputs,
    ) -> TokenInputs | MultiModalInputs:
        """
        Check that `inputs` can be used as decoder inputs.

        Raises:
            ValueError: for embedding inputs.
        """
        if inputs["type"] == "embeds":
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )

        return inputs

    def _build_enc_dec_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: SingletonInputs | None = None,
    ) -> EncoderDecoderInputs:
        """
        Combine encoder and decoder singleton inputs into encoder-decoder
        inputs.

        When `decoder_inputs` is omitted, the encoder inputs are reused for
        the decoder side. For multi-modal encoder inputs, the encoder
        receives `encoder_prompt_token_ids` and the multi-modal metadata is
        attached to the decoder. For token encoder inputs, the encoder token
        list is empty. The decoder token ids always get the decoder start
        token prepended if missing, and the encoder's `cache_salt` (if
        truthy) is propagated to the decoder.
        """
        if decoder_inputs is None:
            decoder_inputs = encoder_inputs

        enc_inputs = self._validate_enc_inputs(encoder_inputs)
        dec_inputs = self._validate_dec_inputs(decoder_inputs)

        enc_inputs_new: TokenInputs | MultiModalEncDecInputs
        dec_inputs_new: TokenInputs | MultiModalInputs

        if enc_inputs["type"] == "multimodal":
            enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
            dec_inputs_new = MultiModalInputs(
                type="multimodal",
                prompt_token_ids=dec_inputs["prompt_token_ids"],
                mm_kwargs=enc_inputs["mm_kwargs"],
                mm_hashes=enc_inputs["mm_hashes"],
                mm_placeholders=enc_inputs["mm_placeholders"],
            )
        elif enc_inputs["type"] == "token":
            enc_inputs_new = token_inputs(prompt_token_ids=[])
            dec_inputs_new = dec_inputs
        else:
            assert_never(enc_inputs)

        dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
            dec_inputs_new["prompt_token_ids"]
        )
        if cache_salt := enc_inputs.get("cache_salt"):
            dec_inputs_new["cache_salt"] = cache_salt

        return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * prompt: an input prompt
        * tokenization_kwargs: forwarded to both sub-prompts
        * mm_uuids: forwarded to the encoder sub-prompt only

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt)

        return self._build_enc_dec_inputs(
            encoder_inputs=self._prompt_to_llm_inputs(
                encoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            ),
            decoder_inputs=(
                None
                if decoder_prompt is None
                else self._prompt_to_llm_inputs(
                    decoder_prompt,
                    tokenization_kwargs=tokenization_kwargs,
                )
            ),
        )

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        Arguments:

        * prompt: input prompt

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        return self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Dispatch `prompt` to the encoder-decoder or decoder-only path based
        on the model config.

        Raises:
            ValueError: if an explicit encoder/decoder prompt is given to a
                decoder-only model.
        """
        if self.model_config.is_encoder_decoder:
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            return self._process_encoder_decoder_prompt(
                prompt,
                tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError(
                "Cannot pass encoder-decoder prompt to decoder-only models"
            )

        return self._process_decoder_only_prompt(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        When a multi-modal processor cache is configured, the cache's
        statistics delta for this call is added to `mm_cache_stats`
        (one request, plus the delta's query and hit counts).
        """
        res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)

        if self.mm_processor_cache and self.mm_cache_stats is not None:
            delta = self.mm_processor_cache.make_stats(delta=True)
            self.mm_cache_stats.requests += 1
            self.mm_cache_stats.queries += delta.total
            self.mm_cache_stats.hits += delta.hits

        return res

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        """
        Return the multi-modal cache statistics accumulated so far and start
        a fresh accumulator.

        Returns `None` when no processor cache was configured.
        """
        mm_cache_stats = self.mm_cache_stats
        if mm_cache_stats is None:
            return None

        self.mm_cache_stats = MultiModalCacheStats()

        return mm_cache_stats

    def clear_mm_cache(self) -> None:
        """
        Clear the multi-modal processor cache, if any, and mark the current
        statistics as reset.
        """
        if self.mm_processor_cache is not None:
            self.mm_processor_cache.clear_cache()

        if self.mm_cache_stats is not None:
            self.mm_cache_stats.reset = True