preprocess.py 35.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Mapping
6
from typing import Any, Optional, Union, cast
7
8
9
10
11
12

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
13
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
14
from vllm.multimodal.cache import BaseMultiModalProcessorCache
15
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
16
                                    MultiModalInputs, MultiModalUUIDDict)
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
19

20
21
22
23
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                   EncoderDecoderInputs, ProcessorInputs, PromptType,
                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
                   TokensPrompt, embeds_inputs, token_inputs)
24
25
26
27
28
29
30
31
32
33
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

# Module-level logger used by InputPreprocessor for its warnings.
logger = init_logger(__name__)


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
34
        tokenizer: Optional[TokenizerGroup],
35
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
36
        mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
37
38
39
40
41
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer
42
        self.mm_registry = mm_registry
43
        self.mm_processor_cache = mm_processor_cache
44

45
    def get_tokenizer_group(self) -> TokenizerGroup:
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
        if self.tokenizer is None:
            raise ValueError("You cannot pass text prompts when "
                             "`skip_tokenizer_init` is True")

        return self.tokenizer

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def get_decoder_start_token_id(self) -> Optional[int]:
73
        """
74
75
76
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
77
        """
78

79
        if not self.model_config.is_encoder_decoder:
80
81
82
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
83
84
            return None

85
        if self.model_config is None or self.model_config.hf_config is None:
86
87
88
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
89
90
91
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
92
                                     "decoder_start_token_id", None)
93
        if dec_start_token_id is None:
94
95
96
97
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
98
99
100
101
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

102
    def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
103
        """
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
124
        other models may have different or more
125
126
127
128
129
130
131
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
132
        """
133
134
135
136
137
138
139

        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
140
141
        decoder_input_ids: Optional[list[int]],
    ) -> list[int]:
142
143
144
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

145
146
147
148
        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

167
168
        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
169
170
171
172
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    def _get_tokenization_kw(
        self,
        overrides: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        kwargs = dict[str, Any]()

        if self.model_config.hf_config.model_type == "whisper":
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if overrides:
            kwargs.update(overrides)

        return kwargs

190
191
192
193
    def _tokenize_prompt(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
194
        tokenization_kwargs: Optional[dict[str, Any]] = None,
195
    ) -> list[int]:
196
197
198
199
200
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
        tokenizer = self.get_tokenizer_group()
201
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
202

203
        encoder_config = self.model_config.encoder_config
204

205
        if encoder_config and encoder_config.get("do_lower_case", False):
206
207
            prompt = prompt.lower()

208
209
210
211
212
        if self.model_config.tokenizer_mode == "cpm":
                return [tokenizer.bos_id] + tokenizer.encode(prompt)
        else:
            return tokenizer.encode(prompt=prompt,
                                    lora_request=lora_request,
zhuwenwen's avatar
zhuwenwen committed
213
                                    **tokenization_kwargs)
214
215
216
217
218

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
219
        tokenization_kwargs: Optional[dict[str, Any]] = None,
220
    ) -> list[int]:
221
222
223
224
        """
        Async version of
        [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
        """
225
        tokenizer = self.get_tokenizer_group()
226
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
227
228
229
230

        return await tokenizer.encode_async(prompt=prompt,
                                            lora_request=lora_request,
                                            **tokenization_kwargs)
231

232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
    def _get_mm_tokenizer(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while using also multi-modal input
        if not self.tokenizer:
            return cast(AnyTokenizer, object())  # Dummy

        tokenizer_group = self.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)

    async def _get_mm_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while using also multi-modal input
        if not self.tokenizer:
            return cast(AnyTokenizer, object())  # Dummy

        tokenizer_group = self.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)
255

256
257
    def _process_multimodal(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a prompt (text or token
        IDs) and its multi-modal data. Returns the resulting token IDs
        together with the multi-modal metadata.

        Raises:

        * ValueError: if the processor produced non-string ``mm_hashes``
        """
        processor = self.mm_registry.create_processor(
            self.model_config,
            tokenizer=self._get_mm_tokenizer(lora_request),
            cache=self.mm_processor_cache,
        )

        hf_kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        mm_input = processor.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs=hf_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # Every multi-modal item must be identified by a string hash; any
        # other value points at a faulty custom processor implementation.
        mm_hashes = mm_input["mm_hashes"]
        if contains_only_strings(mm_hashes):
            return mm_input

        raise ValueError(
            f"mm_hashes must contain only strings, got: {mm_hashes}. "
            "This is likely due to an incorrect custom implementation of "
            "MultiModalProcessor.apply method.")

    async def _process_multimodal_async(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Async version of
        [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].

        Only the tokenizer lookup is awaited; the processor itself runs
        synchronously.
        """
        processor = self.mm_registry.create_processor(
            self.model_config,
            tokenizer=await self._get_mm_tokenizer_async(lora_request),
            cache=self.mm_processor_cache,
        )

        hf_kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        mm_input = processor.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs=hf_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # Every multi-modal item must be identified by a string hash; any
        # other value points at a faulty custom processor implementation.
        mm_hashes = mm_input["mm_hashes"]
        if contains_only_strings(mm_hashes):
            return mm_input

        raise ValueError(
            f"mm_hashes must contain only strings, got: {mm_hashes}. "
            "This is likely due to an incorrect custom implementation of "
            "MultiModalProcessor.apply method.")

    def _process_embeds(
343
        self,
344
345
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
346
347
348
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")
349

350
        prompt_embeds = parsed_content["prompt_embeds"]
351

352
353
354
355
356
357
        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)
358

359
360
361
        if prompt_embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")
362

363
364
        return embeds_inputs(prompt_embeds=prompt_embeds,
                             cache_salt=parsed_content.get("cache_salt"))
365

366
367
368
369
370
371
    async def _process_embeds_async(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        return self._process_embeds(parsed_content)

372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
    def _truncate_inputs(
            self,
            inputs: list[int],
            tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:

        if not tokenization_kwargs or "truncation" not in \
                tokenization_kwargs or self.tokenizer is None:
            return inputs

        max_length = tokenization_kwargs["max_length"]

        if self.tokenizer.truncation_side == "left":
            return inputs[-max_length:]
        else:
            return inputs[:max_length]

388
389
390
    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Turn a prompt given as token IDs into engine inputs.

        Steps:

        * truncate the IDs according to ``tokenization_kwargs``
        * route prompts carrying ``multi_modal_data`` through the
          multi-modal processor
        * copy any non-empty ``cache_salt`` onto the result
        """
        token_ids = self._truncate_inputs(parsed_content["prompt_token_ids"],
                                          tokenization_kwargs)

        mm_data = parsed_content.get("multi_modal_data")
        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            inputs = token_inputs(prompt_token_ids=token_ids)
        else:
            inputs = self._process_multimodal(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    async def _process_tokens_async(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of `_process_tokens`: truncate the token IDs, run the
        multi-modal processor when ``multi_modal_data`` is present, and copy
        any non-empty ``cache_salt`` onto the result.
        """
        token_ids = self._truncate_inputs(parsed_content["prompt_token_ids"],
                                          tokenization_kwargs)

        mm_data = parsed_content.get("multi_modal_data")
        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            inputs = token_inputs(prompt_token_ids=token_ids)
        else:
            inputs = await self._process_multimodal_async(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Turn a text prompt into engine inputs.

        Prompts carrying ``multi_modal_data`` go through the multi-modal
        processor. Others are tokenized, and the original text is kept
        alongside the token IDs. Any non-empty ``cache_salt`` is copied onto
        the result.
        """
        text = parsed_content["prompt"]

        mm_data = parsed_content.get("multi_modal_data")
        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = self._tokenize_prompt(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            inputs = self._process_multimodal(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    async def _process_text_async(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of `_process_text`: tokenize (or multi-modally
        process) a text prompt and copy any non-empty ``cache_salt`` onto
        the result.
        """
        text = parsed_content["prompt"]

        mm_data = parsed_content.get("multi_modal_data")
        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = await self._tokenize_prompt_async(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            inputs = await self._process_multimodal_async(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: tokenizer keyword arguments. They are used to
          tokenize text prompts and to truncate token prompts, and are
          passed on to the multi-modal processor.
        * lora_request: this is only valid for decoder prompts
        * mm_uuids: forwarded to the multi-modal processor

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return self._process_embeds(parsed["content"])
        if parsed["type"] == "tokens":
            # tokenization_kwargs must be forwarded here as for text
            # prompts; without it `_process_tokens` never truncates.
            return self._process_tokens(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "text":
            return self._process_text(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "str":
            return self._process_text(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        assert_never(parsed)

    async def _prompt_to_llm_inputs_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> SingletonInputs:
        """
        Async version of
        [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].

        ``tokenization_kwargs`` are passed to every branch that consumes
        them, including token prompts, which use them for truncation.
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return await self._process_embeds_async(parsed["content"])
        if parsed["type"] == "tokens":
            # tokenization_kwargs must be forwarded here as for text
            # prompts; without it `_process_tokens_async` never truncates.
            return await self._process_tokens_async(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "text":
            return await self._process_text_async(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "str":
            return await self._process_text_async(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        assert_never(parsed)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: Optional[SingletonInputs],
    ) -> EncoderDecoderInputs:
        """
        Pair encoder and decoder inputs, filling in or normalizing the
        decoder prompt.

        * Without decoder inputs, Whisper reuses a copy of the encoder's
          prompt token IDs; other models use the default decoder prompt.
        * Explicit decoder inputs get their token IDs passed through
          `_prepare_decoder_input_ids_for_generation`. The result is
          written back into ``decoder_inputs`` in place.

        Raises:

        * ValueError: for embedding inputs on either side, or for
          multi-modal decoder inputs
        """
        if encoder_inputs["type"] == "embeds" or (
                decoder_inputs and decoder_inputs["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
                              encoder_inputs)
        decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
                              decoder_inputs)

        if decoder_inputs is not None:
            if "multi_modal_data" in decoder_inputs:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")
            decoder_inputs["prompt_token_ids"] = (
                self._prepare_decoder_input_ids_for_generation(
                    decoder_inputs["prompt_token_ids"]))
        elif self.model_config.hf_config.model_type == "whisper":
            # Whisper's text prompt belongs to the decoder, so copy it over
            # from the encoder side. The encoder tokens are later overridden
            # by the audio features.
            decoder_inputs = token_inputs(
                encoder_inputs["prompt_token_ids"].copy())
        else:
            decoder_inputs = token_inputs(
                self._prepare_decoder_input_ids_for_generation(None))

        return EncoderDecoderInputs(encoder=encoder_inputs,
                                    decoder=decoder_inputs)

    def _split_enc_dec_mm_inputs(
        self,
        inputs: Union[SingletonInputs, MultiModalEncDecInputs],
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
    ) -> tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only:
        Split processor output into separate encoder and decoder inputs.

        * Token (text-only) inputs: the encoder gets an empty prompt, and
          the decoder gets ``decoder_inputs_to_override`` or, if that is
          not given, ``inputs`` itself.
        * Multi-modal inputs: the encoder prompt comes from the
          ``encoder_prompt*`` fields, and the decoder keeps the multi-modal
          metadata. The decoder's prompt text/IDs come from
          ``decoder_inputs_to_override`` when it is given.

        Raises:

        * ValueError: for embedding inputs
        * RuntimeError: if multi-modal inputs lack encoder prompt fields
          (i.e. no encoder-decoder multi-modal processor is registered)
        """
        if inputs["type"] == "embeds" or (
                decoder_inputs_to_override
                and decoder_inputs_to_override["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        inputs = cast(
            Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
            inputs,
        )
        decoder_inputs_to_override = cast(
            Optional[Union[TokenInputs, MultiModalInputs]],
            decoder_inputs_to_override,
        )

        encoder_inputs: SingletonInputs
        decoder_inputs: SingletonInputs

        if inputs["type"] == "token":  # Text-only inputs
            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
            decoder_inputs = decoder_inputs_to_override or inputs
        elif inputs["type"] == "multimodal":  # Multimodal data inputs
            if ("encoder_prompt" not in inputs
                    or "encoder_prompt_token_ids" not in inputs):
                raise RuntimeError("You should register an encoder-decoder "
                                   "multi-modal processor for encoder-decoder "
                                   "models.")
            mm_inputs = cast(MultiModalEncDecInputs, inputs)

            encoder_inputs = token_inputs(
                prompt=mm_inputs["encoder_prompt"],
                prompt_token_ids=mm_inputs["encoder_prompt_token_ids"],
            )

            prompt_source = decoder_inputs_to_override or mm_inputs
            decoder_inputs = MultiModalInputs(
                type="multimodal",
                prompt=prompt_source.get("prompt", ""),
                prompt_token_ids=prompt_source["prompt_token_ids"],
                mm_kwargs=mm_inputs["mm_kwargs"],
                mm_hashes=mm_inputs["mm_hashes"],
                mm_placeholders=mm_inputs["mm_placeholders"],
            )
            salt = mm_inputs.get("cache_salt")
            if salt:
                decoder_inputs["cache_salt"] = salt
        else:
            assert_never(inputs)  # type: ignore[arg-type]

        return encoder_inputs, decoder_inputs

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        Two prompt shapes are accepted:

        * a singleton prompt, which carries only the encoder prompt. The
          decoder prompt is inferred: multi-modal models split the
          processor output, and other models get a default decoder prompt.
        * an explicit encoder/decoder prompt. Its ``encoder_prompt`` and
          optional ``decoder_prompt`` may each be any singleton prompt type.

        Sub-prompts are converted with `_prompt_to_llm_inputs`, and the
        pair is finalized by `_build_enc_dec_llm_inputs`.

        Arguments:

        * prompt: an input prompt
        * tokenization_kwargs: tokenizer keyword arguments for the encoder
          (or singleton) prompt
        * mm_uuids: forwarded with the encoder (or singleton) prompt

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_inputs = self._prompt_to_llm_inputs(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
            # NOTE(review): the async variant also forwards
            # tokenization_kwargs/mm_uuids to the decoder prompt; this one
            # does not. Confirm which behavior is intended.
            decoder_prompt = prompt["decoder_prompt"]
            decoder_inputs = (None if decoder_prompt is None else
                              self._prompt_to_llm_inputs(decoder_prompt))

            # For multimodal models, the explicit decoder prompt overrides
            # the decoder prompt produced by the processor.
            if self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))
        else:
            singleton_inputs = self._prompt_to_llm_inputs(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-decoder multi-modal model: split processor output.
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(singleton_inputs))
            else:
                encoder_inputs, decoder_inputs = singleton_inputs, None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> EncoderDecoderInputs:
        """
        Async version of
        [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].

        Arguments:

        * prompt: an input prompt, either an explicit encoder/decoder
          prompt or a singleton prompt
        * tokenization_kwargs: forwarded to `_prompt_to_llm_inputs_async`
        * mm_uuids: forwarded to `_prompt_to_llm_inputs_async`

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        enc: SingletonInputs
        dec: Optional[SingletonInputs]
        is_multimodal = self.model_config.is_multimodal_model

        if not is_explicit_encoder_decoder_prompt(prompt):
            # A singleton prompt feeds the encoder. For multimodal models
            # the processor output is split into encoder/decoder parts.
            single = await self._prompt_to_llm_inputs_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
            if is_multimodal:
                # Encoder-Decoder Multimodal model
                enc, dec = self._split_enc_dec_mm_inputs(single)
            else:
                enc, dec = single, None
            return self._build_enc_dec_llm_inputs(enc, dec)

        # Explicit encoder/decoder prompt: build the encoder coroutine first.
        enc_coro = self._prompt_to_llm_inputs_async(
            prompt["encoder_prompt"],
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        decoder_prompt = prompt["decoder_prompt"]
        if decoder_prompt is None:
            enc = await enc_coro
            dec = None
        else:
            # NOTE(review): the sync `_process_encoder_decoder_prompt`
            # tokenizes the decoder prompt WITHOUT `tokenization_kwargs` /
            # `mm_uuids`, while this async path forwards both. Confirm which
            # behavior is intended and align the two paths.
            dec_coro = self._prompt_to_llm_inputs_async(
                decoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
            # Run encoder and decoder preprocessing concurrently.
            enc, dec = await asyncio.gather(enc_coro, dec_coro)

        if is_multimodal:
            # For multimodal models, the explicit decoder prompt overrides
            # the decoder prompt produced by the processor.
            enc, dec = self._split_enc_dec_mm_inputs(enc, dec)

        return self._build_enc_dec_llm_inputs(enc, dec)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_inputs: DecoderOnlyInputs,
    ) -> DecoderOnlyInputs:
        """
        Finalize decoder-only inputs.

        The inputs are returned unchanged. When they carry
        `prompt_token_ids`, they are narrowed via `cast` purely for the
        benefit of the type checker (mypy); `cast` has no runtime effect.
        """
        has_token_ids = "prompt_token_ids" in prompt_inputs
        if not has_token_ids:
            return prompt_inputs
        return cast(Union[TokenInputs, MultiModalInputs], prompt_inputs)

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:

        Convert a singleton prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        Arguments:

        * prompt: the singleton input prompt
        * tokenization_kwargs: forwarded to `_prompt_to_llm_inputs`
        * lora_request: forwarded to `_prompt_to_llm_inputs`
        * mm_uuids: forwarded to `_prompt_to_llm_inputs`

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        return self._build_decoder_only_llm_inputs(
            self._prompt_to_llm_inputs(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            ))

    async def _process_decoder_only_prompt_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> DecoderOnlyInputs:
        """
        Async version of
        [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].

        All arguments are forwarded to `_prompt_to_llm_inputs_async`; the
        result is passed through `_build_decoder_only_llm_inputs`.
        """
        return self._build_decoder_only_llm_inputs(
            await self._prompt_to_llm_inputs_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            ))

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        Encoder-decoder models are routed to
        `_process_encoder_decoder_prompt`; all other models are routed to
        `_process_decoder_only_prompt`.

        Raises:

        * ValueError: if an explicit encoder/decoder prompt is passed to a
          decoder-only model
        """
        if not self.model_config.is_encoder_decoder:
            # Decoder-only operation: explicit enc/dec prompts are rejected.
            if is_explicit_encoder_decoder_prompt(prompt):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")
            return self._process_decoder_only_prompt(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        # Encoder-decoder models need the prompt mapped onto separate
        # encoder and decoder inputs.
        # NOTE(review): `lora_request` is not forwarded on this path;
        # confirm LoRA is intentionally unsupported for encoder-decoder.
        return self._process_encoder_decoder_prompt(
            prompt,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    async def preprocess_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> ProcessorInputs:
        """
        Async version of
        [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].

        Raises:

        * ValueError: if an explicit encoder/decoder prompt is passed to a
          decoder-only model
        """
        if not self.model_config.is_encoder_decoder:
            # Decoder-only operation: explicit enc/dec prompts are rejected.
            if is_explicit_encoder_decoder_prompt(prompt):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")
            return await self._process_decoder_only_prompt_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_uuids=mm_uuids,
            )

        # Encoder-decoder models need the prompt mapped onto separate
        # encoder and decoder inputs.
        # NOTE(review): `lora_request` is not forwarded on this path;
        # confirm LoRA is intentionally unsupported for encoder-decoder.
        return await self._process_encoder_decoder_prompt_async(
            prompt,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def clear_cache(self) -> None:
        """Clear the multimodal processor cache, if one is configured."""
        cache = self.mm_processor_cache
        if cache is None:
            return
        cache.clear_cache()


def contains_only_strings(obj: object) -> bool:
    """Validate that every leaf value of a (possibly nested) dict is textual.

    Accepted shapes:

    * a ``str``;
    * a ``list`` whose items are all ``str`` (nested lists are rejected);
    * a ``dict`` whose values recursively satisfy this same check
      (keys are not inspected; an empty dict passes).

    Anything else, e.g. ``None``, numbers, tuples or bytes, yields False.
    """
    if isinstance(obj, dict):
        return all(map(contains_only_strings, obj.values()))
    if isinstance(obj, list):
        return all(isinstance(item, str) for item in obj)
    return isinstance(obj, str)