preprocess.py 36.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Mapping
6
from typing import Any, Optional, Union, cast
7
8
9
10
11
12

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
13
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
14
from vllm.multimodal.cache import BaseMultiModalProcessorCache
15
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
16
                                    MultiModalInputs, MultiModalUUIDDict)
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
19

20
21
22
23
24
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                   EncoderDecoderInputs, ProcessorInputs, PromptType,
                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
                   TokensPrompt, embeds_inputs, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
25
26
27
28
29
30
31
32
33

# Module-level logger; InputPreprocessor uses it for its warnings.
logger = init_logger(__name__)


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
34
        tokenizer: Optional[TokenizerGroup],
35
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
36
        mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
37
38
39
40
41
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer
42
        self.mm_registry = mm_registry
43
        self.mm_processor_cache = mm_processor_cache
44

45
    def get_tokenizer_group(self) -> TokenizerGroup:
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
        if self.tokenizer is None:
            raise ValueError("You cannot pass text prompts when "
                             "`skip_tokenizer_init` is True")

        return self.tokenizer

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def get_decoder_start_token_id(self) -> Optional[int]:
73
        """
74
75
76
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
77
        """
78

79
        if not self.model_config.is_encoder_decoder:
80
81
82
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
83
84
            return None

85
        if self.model_config is None or self.model_config.hf_config is None:
86
87
88
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
89
90
91
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
92
                                     "decoder_start_token_id", None)
93
        if dec_start_token_id is None:
94
95
96
97
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
98
99
100
101
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

102
    def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
103
        """
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
124
        other models may have different or more
125
126
127
128
129
130
131
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
132
        """
133
134
135
136
137
138
139

        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
140
141
        decoder_input_ids: Optional[list[int]],
    ) -> list[int]:
142
143
144
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

145
146
147
148
        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

167
168
        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
169
170
171
172
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    def _get_tokenization_kw(
        self,
        overrides: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        kwargs = dict[str, Any]()

        if self.model_config.hf_config.model_type == "whisper":
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if overrides:
            kwargs.update(overrides)

        return kwargs

190
191
192
193
    def _tokenize_prompt(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
194
        tokenization_kwargs: Optional[dict[str, Any]] = None,
195
    ) -> list[int]:
196
197
198
199
200
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
        tokenizer = self.get_tokenizer_group()
201
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
202

203
        encoder_config = self.model_config.encoder_config
204

205
        if encoder_config and encoder_config.get("do_lower_case", False):
206
207
            prompt = prompt.lower()

208
        return tokenizer.encode(prompt=prompt,
209
                                lora_request=lora_request,
210
                                **tokenization_kwargs)
211
212
213
214
215

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
216
        tokenization_kwargs: Optional[dict[str, Any]] = None,
217
    ) -> list[int]:
218
219
220
221
        """
        Async version of
        [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
        """
222
        tokenizer = self.get_tokenizer_group()
223
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
224
225
226
227

        return await tokenizer.encode_async(prompt=prompt,
                                            lora_request=lora_request,
                                            **tokenization_kwargs)
228

229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
    def _get_mm_tokenizer(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while using also multi-modal input
        if not self.tokenizer:
            return cast(AnyTokenizer, object())  # Dummy

        tokenizer_group = self.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)

    async def _get_mm_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while using also multi-modal input
        if not self.tokenizer:
            return cast(AnyTokenizer, object())  # Dummy

        tokenizer_group = self.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)

253
254
    def _process_multimodal(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalInputs:
        """
        Run a multi-modal prompt through the model's multi-modal processor.

        Returns the processor output (token IDs plus multi-modal metadata).

        Raises:
            ValueError: If the processor's ``mm_hashes`` contain anything
                other than strings.
        """
        mm_tokenizer = self._get_mm_tokenizer(lora_request)
        processor = self.mm_registry.create_processor(
            self.model_config,
            tokenizer=mm_tokenizer,
            cache=self.mm_processor_cache,
        )

        hf_kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        mm_input = processor.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs=hf_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_hash_overrides=mm_hash_overrides,
        )

        # Every multi-modal item must be identified by a string hash.
        # NOTE(review): `contains_only_strings` is not among the imports
        # visible here; confirm it is defined or imported elsewhere in the file.
        mm_hashes = mm_input["mm_hashes"]
        if contains_only_strings(mm_hashes):
            return mm_input

        raise ValueError(
            f"mm_hashes must contain only strings, got: {mm_hashes}. "
            "This is likely due to an incorrect custom implementation of "
            "MultiModalProcessor.apply method.")

    async def _process_multimodal_async(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalInputs:
        """
        Async version of
        [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].

        Only the tokenizer lookup is awaited; the processor itself runs
        synchronously.
        """
        mm_tokenizer = await self._get_mm_tokenizer_async(lora_request)
        processor = self.mm_registry.create_processor(
            self.model_config,
            tokenizer=mm_tokenizer,
            cache=self.mm_processor_cache,
        )

        hf_kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        mm_input = processor.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs=hf_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_hash_overrides=mm_hash_overrides,
        )

        # Every multi-modal item must be identified by a string hash.
        # NOTE(review): `contains_only_strings` is not among the imports
        # visible here; confirm it is defined or imported elsewhere in the file.
        mm_hashes = mm_input["mm_hashes"]
        if contains_only_strings(mm_hashes):
            return mm_input

        raise ValueError(
            f"mm_hashes must contain only strings, got: {mm_hashes}. "
            "This is likely due to an incorrect custom implementation of "
            "MultiModalProcessor.apply method.")

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
345
346
347
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")
348
349

        prompt_embeds = parsed_content["prompt_embeds"]
350

351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        return embeds_inputs(prompt_embeds=prompt_embeds,
                             cache_salt=parsed_content.get("cache_salt"))

    async def _process_embeds_async(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        return self._process_embeds(parsed_content)

371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
    def _truncate_inputs(
            self,
            inputs: list[int],
            tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:

        if not tokenization_kwargs or "truncation" not in \
                tokenization_kwargs or self.tokenizer is None:
            return inputs

        max_length = tokenization_kwargs["max_length"]

        if self.tokenizer.truncation_side == "left":
            return inputs[-max_length:]
        else:
            return inputs[:max_length]

387
388
389
    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Convert a token-id prompt into model inputs.

        The ids are first truncated per ``tokenization_kwargs``. If the
        prompt carries multi-modal data, it goes through the multi-modal
        processor; otherwise it becomes plain token inputs. A truthy
        ``cache_salt`` is copied onto the result.
        """
        token_ids = self._truncate_inputs(parsed_content["prompt_token_ids"],
                                          tokenization_kwargs)

        mm_data = parsed_content.get("multi_modal_data")
        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            result = token_inputs(prompt_token_ids=token_ids)
        else:
            result = self._process_multimodal(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt
        return result

    async def _process_tokens_async(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of `_process_tokens`. The multi-modal path awaits
        `_process_multimodal_async`.
        """
        token_ids = self._truncate_inputs(parsed_content["prompt_token_ids"],
                                          tokenization_kwargs)

        mm_data = parsed_content.get("multi_modal_data")
        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            result = token_inputs(prompt_token_ids=token_ids)
        else:
            result = await self._process_multimodal_async(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt
        return result

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Convert a text prompt into model inputs.

        Multi-modal prompts go to the multi-modal processor, which handles
        tokenization itself. Plain text is tokenized here and kept alongside
        its token ids. A truthy ``cache_salt`` is copied onto the result.
        """
        text = parsed_content["prompt"]

        mm_data = parsed_content.get("multi_modal_data")
        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = self._tokenize_prompt(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            result = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            result = self._process_multimodal(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt
        return result

    async def _process_text_async(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of `_process_text`. Tokenization and multi-modal
        processing use their async counterparts.
        """
        text = parsed_content["prompt"]

        mm_data = parsed_content.get("multi_modal_data")
        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = await self._tokenize_prompt_async(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            result = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            result = await self._process_multimodal_async(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt
        return result

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: extra tokenizer arguments; for token prompts
          they also drive truncation of the supplied token ids
        * lora_request: this is only valid for decoder prompts
        * mm_hash_overrides: forwarded to the multi-modal processor

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return self._process_embeds(parsed["content"])
        if parsed["type"] == "tokens":
            # Forward tokenization_kwargs. _process_tokens only truncates
            # (and the multi-modal processor only sees these kwargs) when
            # they are passed; previously they were silently dropped here.
            return self._process_tokens(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )
        if parsed["type"] == "text":
            return self._process_text(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )
        if parsed["type"] == "str":
            return self._process_text(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )

        assert_never(parsed)

    async def _prompt_to_llm_inputs_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> SingletonInputs:
        """
        Async version of
        [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].

        As in the sync version, ``tokenization_kwargs`` are forwarded for
        every prompt type that is tokenized or truncated.
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return await self._process_embeds_async(parsed["content"])
        if parsed["type"] == "tokens":
            # Forward tokenization_kwargs so that truncation settings reach
            # _process_tokens_async; previously they were dropped here.
            return await self._process_tokens_async(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )
        if parsed["type"] == "text":
            return await self._process_text_async(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )
        if parsed["type"] == "str":
            return await self._process_text_async(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                mm_hash_overrides=mm_hash_overrides,
            )

        assert_never(parsed)

    def _build_enc_dec_llm_inputs(
        self,
611
612
        encoder_inputs: SingletonInputs,
        decoder_inputs: Optional[SingletonInputs],
613
    ) -> EncoderDecoderInputs:
614
615
616
617
        if (encoder_inputs["type"] == "embeds"
                or decoder_inputs and decoder_inputs["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")
618

619
620
621
622
623
        # Needed for mypy
        encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
                              encoder_inputs)
        decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
                              decoder_inputs)
624

625
        if decoder_inputs is None:
626
627
628
629
630
631
632
633
634
            if self.model_config.hf_config.model_type == "whisper":
                # For Whisper models, the text prompt should go to the decoder.
                # If no explicit encoder/decoder inputs, then copy the prompt
                # from the encoder to the decoder. The encoder tokens are later
                # overridden by the audio features.
                dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
            else:
                dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                    None)
635
            decoder_inputs = token_inputs(dec_token_ids)
636
        else:
637
638
639
            if "multi_modal_data" in decoder_inputs:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")
640
641
642
643

            dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                decoder_inputs["prompt_token_ids"])
            decoder_inputs["prompt_token_ids"] = dec_token_ids
644

645
        return EncoderDecoderInputs(
646
647
            encoder=encoder_inputs,
            decoder=decoder_inputs,
648
649
        )

650
    def _split_enc_dec_mm_inputs(
651
        self,
652
653
        inputs: Union[SingletonInputs, MultiModalEncDecInputs],
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
654
    ) -> tuple[SingletonInputs, SingletonInputs]:
655
656
657
658
        """
        For encoder/decoder models only:
        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
        """
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
        if (inputs["type"] == "embeds" or decoder_inputs_to_override
                and decoder_inputs_to_override["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        inputs = cast(
            Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
            inputs,
        )
        decoder_inputs_to_override = cast(
            Optional[Union[TokenInputs, MultiModalInputs]],
            decoder_inputs_to_override,
        )

674
675
        encoder_inputs: SingletonInputs
        decoder_inputs: SingletonInputs
676
677
678
679
680
681
682

        if inputs["type"] == "multimodal":  # Multimodal data inputs
            if not ("encoder_prompt" in inputs
                    and "encoder_prompt_token_ids" in inputs):
                raise RuntimeError("You should register an encoder-decoder "
                                   "multi-modal processor for encoder-decoder "
                                   "models.")
683
            inputs = cast(MultiModalEncDecInputs, inputs)
684

685
686
687
688
            encoder_inputs = token_inputs(
                prompt=inputs["encoder_prompt"],
                prompt_token_ids=inputs["encoder_prompt_token_ids"],
            )
689

690
691
692
693
694
695
696
697
698
699
            decoder_prompt_inputs = decoder_inputs_to_override or inputs
            decoder_inputs = MultiModalInputs(
                type="multimodal",
                prompt=decoder_prompt_inputs.get("prompt", ""),
                prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
                mm_kwargs=inputs["mm_kwargs"],
                mm_hashes=inputs["mm_hashes"],
                mm_placeholders=inputs["mm_placeholders"],
            )
            if cache_salt := inputs.get("cache_salt"):
700
701
                decoder_inputs["cache_salt"] = cache_salt

702
        elif inputs["type"] == "token":  # Text-only inputs
703
704
705
706
            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
            decoder_inputs = decoder_inputs_to_override or inputs
        else:
            assert_never(inputs)  # type: ignore[arg-type]
707

708
709
        return encoder_inputs, decoder_inputs

710
711
    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Turn an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        Two prompt shapes are accepted:
        * Singleton prompt: it is treated as the encoder prompt, and the
          decoder prompt is inferred (or split out by the multi-modal
          processor for multi-modal models).
        * Explicit encoder/decoder prompt: the encoder and decoder
          sub-prompts, each of which may be any singleton prompt type, are
          processed separately.

        For explicit prompts, only the encoder sub-prompt receives
        ``tokenization_kwargs`` and ``mm_hash_overrides``.

        Arguments:

        * prompt: an input prompt

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_inputs = self._prompt_to_llm_inputs(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
                mm_hash_overrides=mm_hash_overrides,
            )
            decoder_prompt = prompt["decoder_prompt"]
            decoder_inputs = (None if decoder_prompt is None else
                              self._prompt_to_llm_inputs(decoder_prompt))

            # Multi-modal models: the explicit decoder prompt overrides the
            # decoder prompt produced by the processor.
            if self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(
                    encoder_inputs, decoder_inputs)
        else:
            singleton_inputs = self._prompt_to_llm_inputs(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_hash_overrides=mm_hash_overrides,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-decoder multi-modal model.
                encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(
                    singleton_inputs)
            else:
                encoder_inputs, decoder_inputs = singleton_inputs, None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
788
        prompt: PromptType,
789
        tokenization_kwargs: Optional[dict[str, Any]] = None,
790
        *,
791
792
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
793
    ) -> EncoderDecoderInputs:
794
795
796
797
        """
        Async version of
        [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
        """
798
799
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]
800

801
        if is_explicit_encoder_decoder_prompt(prompt):
802
            encoder_task = self._prompt_to_llm_inputs_async(
803
804
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
805
                mm_hash_overrides=mm_hash_overrides,
806
            )
807

808
            if (decoder_input := prompt["decoder_prompt"]) is None:
809
810
                encoder_inputs = await encoder_task
                decoder_inputs = None
811
            else:
812
813
814
                decoder_task = self._prompt_to_llm_inputs_async(
                    decoder_input,
                    tokenization_kwargs=tokenization_kwargs,
815
                    mm_hash_overrides=mm_hash_overrides,
816
                )
817

818
                encoder_inputs, decoder_inputs = await asyncio.gather(
819
                    encoder_task, decoder_task)
820
821
822

            # For multimodal model, override decoder prompt from processor
            # with explicit decoder prompt.
823
            if self.model_config.is_multimodal_model:
824
                encoder_inputs, decoder_inputs = (
825
826
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))
827
        else:
828
829
830
            inputs = await self._prompt_to_llm_inputs_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
831
                mm_hash_overrides=mm_hash_overrides,
832
            )
833
            if self.model_config.is_multimodal_model:
834
835
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
836
                    self._split_enc_dec_mm_inputs(inputs))
837
838
839
            else:
                encoder_inputs = inputs
                decoder_inputs = None
840
841

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
842
843
844

    def _build_decoder_only_llm_inputs(
        self,
845
        prompt_inputs: DecoderOnlyInputs,
846
    ) -> DecoderOnlyInputs:
847
848
849
        if "prompt_token_ids" in prompt_inputs:
            prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
                                 prompt_inputs)  # Needed for mypy
850

851
        return prompt_inputs
852
853
854

    def _process_decoder_only_prompt(
        self,
855
        prompt: SingletonPrompt,
856
        tokenization_kwargs: Optional[dict[str, Any]] = None,
857
        lora_request: Optional[LoRARequest] = None,
858
        *,
859
860
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
861
    ) -> DecoderOnlyInputs:
862
        """
863
        For decoder-only models:
864
865
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
866
867
868

        Arguments:

869
        * prompt: input prompt
870
871
872
873
        * lora_request

        Returns:

874
        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
875
        """
876

877
        prompt_comps = self._prompt_to_llm_inputs(
878
            prompt,
879
            tokenization_kwargs=tokenization_kwargs,
880
            lora_request=lora_request,
881
            mm_hash_overrides=mm_hash_overrides,
882
883
        )

884
        return self._build_decoder_only_llm_inputs(prompt_comps)
885
886
887

    async def _process_decoder_only_prompt_async(
        self,
888
        prompt: SingletonPrompt,
889
        tokenization_kwargs: Optional[dict[str, Any]] = None,
890
        lora_request: Optional[LoRARequest] = None,
891
        *,
892
893
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
894
    ) -> DecoderOnlyInputs:
895
896
897
898
        """
        Async version of
        [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
        """
899
        prompt_comps = await self._prompt_to_llm_inputs_async(
900
            prompt,
901
            tokenization_kwargs=tokenization_kwargs,
902
            lora_request=lora_request,
903
            mm_hash_overrides=mm_hash_overrides,
904
905
        )

906
        return self._build_decoder_only_llm_inputs(prompt_comps)
907
908
909

    def preprocess(
        self,
910
        prompt: PromptType,
911
        tokenization_kwargs: Optional[dict[str, Any]] = None,
912
        lora_request: Optional[LoRARequest] = None,
913
        *,
914
915
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
916
    ) -> ProcessorInputs:
917
        """Preprocess the input prompt."""
918
        if self.model_config.is_encoder_decoder:
919
            # Encoder-decoder model requires special mapping of
920
            # input prompts to encoder & decoder.
921
            return self._process_encoder_decoder_prompt(
922
923
                prompt,
                tokenization_kwargs,
924
                mm_hash_overrides=mm_hash_overrides,
925
            )
926

927
        if is_explicit_encoder_decoder_prompt(prompt):
928
929
930
931
932
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return self._process_decoder_only_prompt(
933
            prompt,
934
            tokenization_kwargs=tokenization_kwargs,
935
            lora_request=lora_request,
936
            mm_hash_overrides=mm_hash_overrides,
937
938
939
940
        )

    async def preprocess_async(
        self,
941
        prompt: PromptType,
942
        tokenization_kwargs: Optional[dict[str, Any]] = None,
943
        lora_request: Optional[LoRARequest] = None,
944
        *,
945
946
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
947
    ) -> ProcessorInputs:
948
949
950
951
        """
        Async version of
        [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
        """
952
        if self.model_config.is_encoder_decoder:
953
            # Encoder-decoder model requires special mapping of
954
955
956
957
            # input prompts to encoder & decoder.
            return await self._process_encoder_decoder_prompt_async(
                prompt,
                tokenization_kwargs,
958
                mm_hash_overrides=mm_hash_overrides,
959
            )
960

961
        if is_explicit_encoder_decoder_prompt(prompt):
962
963
964
965
966
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
967
            prompt,
968
            tokenization_kwargs=tokenization_kwargs,
969
            lora_request=lora_request,
970
            mm_hash_overrides=mm_hash_overrides,
971
        )
972
973
974
975

    def clear_cache(self) -> None:
        if self.mm_processor_cache is not None:
            self.mm_processor_cache.clear_cache()
976
977
978
979
980
981
982
983
984
985
986
987


# Helper function to validate that a nested dictionary contains
# only strings or list of strings as the leaf values.
def contains_only_strings(obj: object):
    if isinstance(obj, str):
        return True
    if isinstance(obj, list):
        return all(isinstance(x, str) for x in obj)
    if isinstance(obj, dict):
        return all(contains_only_strings(v) for v in obj.values())
    return False