preprocess.py 33.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
from collections.abc import Mapping
5
from typing import Any, Optional, Union, cast
6
7
8
9
10
11

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
12
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
13
14
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs)
15
from vllm.prompt_adapter.request import PromptAdapterRequest
16
from vllm.transformers_utils.tokenizer import AnyTokenizer
17
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
18

19
20
21
22
23
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                   EncoderDecoderInputs, ProcessorInputs, PromptType,
                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
                   TokensPrompt, embeds_inputs, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
24
25
26
27
28
29
30
31
32

# Module-level logger used by InputPreprocessor for its fallback warnings.
logger = init_logger(__name__)


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[TokenizerGroup],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Arguments:

        * model_config: configuration of the model whose inputs are
          preprocessed
        * tokenizer: tokenizer group used for text prompts, or None when
          tokenizer initialization was skipped
        * mm_registry: registry used to create multi-modal processors
        """
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer
        self.mm_registry = mm_registry

    def get_tokenizer_group(self) -> TokenizerGroup:
        """
        Return the tokenizer group used to tokenize text prompts.

        Raises:

        * ValueError: if no tokenizer is available (`skip_tokenizer_init`
          is True), since text prompts cannot be tokenized then
        """
        if self.tokenizer is None:
            raise ValueError("You cannot pass text prompts when "
                             "`skip_tokenizer_init` is True")

        return self.tokenizer

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Look up the BOS token id of the tokenizer selected for
        `lora_request`.

        Returns None (after logging a warning) when no tokenizer has been
        initialized.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            lora_tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return lora_tokenizer.bos_token_id

        logger.warning("Using None for BOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Look up the EOS token id of the tokenizer selected for
        `lora_request`.

        Returns None (after logging a warning) when no tokenizer has been
        initialized.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            lora_tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return lora_tokenizer.eos_token_id

        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_decoder_start_token_id(self) -> Optional[int]:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.

        If the HF config does not define `decoder_start_token_id`, the BOS
        token id is used instead.
        """
        # Guard against a missing config before touching any of its
        # attributes; probing `is_encoder_decoder` on None would raise
        # AttributeError instead of returning None as documented.
        if self.model_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        if not self.model_config.is_encoder_decoder:
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
            return None

        if self.model_config.hf_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
                                     "decoder_start_token_id", None)
        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
        """
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids: a one-element list holding the BOS token id

        Raises:

        * AssertionError: if no BOS token id is available (e.g. the
          tokenizer was not initialized)
        """
        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[list[int]],
    ) -> list[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

        Arguments:

        * decoder_input_ids: input token ids to preprocess, or None to use
          the default decoder prompt

        Returns:

        * Processed token list, guaranteed to start with the decoder start
          token id (it is prepended when absent; the input list itself is
          not modified)
        """
        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: list[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> list[int]:
        """
        Reserve prompt positions for a prompt adapter's virtual tokens.

        When `prompt_adapter_request` is set, prepend
        `prompt_adapter_num_virtual_tokens` copies of token id 0 to
        `prompt_token_ids` (building a new list); otherwise return the ids
        unchanged.
        """
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)

        return prompt_token_ids

    def _get_tokenization_kw(
        self,
        overrides: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Build the keyword arguments passed to the tokenizer's `encode`.

        Model-specific defaults are applied first; entries in `overrides`
        take precedence over them.
        """
        kwargs = dict[str, Any]()

        if self.model_config.hf_config.model_type == "whisper":
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if overrides:
            kwargs.update(overrides)

        return kwargs

    def _tokenize_prompt(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.

        The prompt is lower-cased first when the model's encoder config
        sets `do_lower_case`.
        """
        tokenizer = self.get_tokenizer_group()
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        encoder_config = self.model_config.encoder_config

        if encoder_config and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        return tokenizer.encode(prompt=prompt,
                                lora_request=lora_request,
                                **tokenization_kwargs)

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """
        Async version of
        [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
        """
        tokenizer = self.get_tokenizer_group()
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        # Mirror the synchronous path: without this, models whose encoder
        # config requests lower-casing would receive different token ids
        # depending on which entry point was used.
        encoder_config = self.model_config.encoder_config

        if encoder_config and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        return await tokenizer.encode_async(prompt=prompt,
                                            lora_request=lora_request,
                                            **tokenization_kwargs)

    def _get_mm_tokenizer(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        """
        Return the tokenizer handed to the multi-modal processor.

        Falls back on a dummy placeholder object when no tokenizer is
        available.
        """
        if self.tokenizer:
            return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

        # PrithviGeoSpatialMAE is initialized without a tokenizer even though
        # it consumes multi-modal input, so a stand-in object is returned.
        return cast(AnyTokenizer, object())  # Dummy

    async def _get_mm_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        """
        Async version of
        [`_get_mm_tokenizer`][vllm.inputs.preprocess.InputPreprocessor._get_mm_tokenizer].
        """
        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while using also multi-modal input
        if not self.tokenizer:
            return cast(AnyTokenizer, object())  # Dummy

        tokenizer_group = self.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)

    def _process_multimodal(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        Arguments:

        * prompt: text prompt or already-tokenized prompt ids
        * mm_data: multi-modal items attached to the prompt
        * mm_processor_kwargs: extra processor kwargs (None means none)
        * lora_request: selects the LoRA-specific tokenizer
        * return_mm_hashes: forwarded to the processor's `apply`
        """
        tokenizer = self._get_mm_tokenizer(lora_request)

        mm_processor = self.mm_registry.create_processor(self.model_config,
                                                         tokenizer=tokenizer)

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
                                  return_mm_hashes)

    async def _process_multimodal_async(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Async version of
        [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].

        Only the tokenizer lookup is awaited; the processor itself runs
        synchronously.
        """
        tokenizer = await self._get_mm_tokenizer_async(lora_request)

        mm_processor = self.mm_registry.create_processor(self.model_config,
                                                         tokenizer=tokenizer)
        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
                                  return_mm_hashes)

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Validate an embeddings prompt and wrap it into `EmbedsInputs`.

        Raises:

        * ValueError: if prompt embeddings are disabled in the model config,
          or if the embeddings are not 2-D after dropping a size-1 leading
          batch dimension
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        prompt_embeds = parsed_content["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        return embeds_inputs(prompt_embeds=prompt_embeds,
                             cache_salt=parsed_content.get("cache_salt"))

    async def _process_embeds_async(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Async version of
        [`_process_embeds`][vllm.inputs.preprocess.InputPreprocessor._process_embeds].

        Embedding validation involves nothing awaitable, so this delegates
        directly to the synchronous implementation.
        """
        embeds_result = self._process_embeds(parsed_content)
        return embeds_result

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Convert a tokens prompt into inputs, running the multi-modal
        processor when the prompt carries `multi_modal_data`.

        Note: `token_type_ids` is only forwarded on the text-only path.
        """
        prompt_token_ids = parsed_content["prompt_token_ids"]
        token_type_ids = parsed_content.get("token_type_ids")

        inputs: Union[TokenInputs, MultiModalInputs]
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        else:
            inputs = token_inputs(
                prompt_token_ids=prompt_token_ids,
                token_type_ids=token_type_ids,
            )

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    async def _process_tokens_async(
        self,
        parsed_content: TokensPrompt,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of
        [`_process_tokens`][vllm.inputs.preprocess.InputPreprocessor._process_tokens].
        """
        token_ids = parsed_content["prompt_token_ids"]
        type_ids = parsed_content.get("token_type_ids")
        mm_data = parsed_content.get("multi_modal_data")

        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            result = token_inputs(
                prompt_token_ids=token_ids,
                token_type_ids=type_ids,
            )
        else:
            result = await self._process_multimodal_async(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt

        return result

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Convert a text prompt into inputs, running the multi-modal
        processor when the prompt carries `multi_modal_data` and tokenizing
        the text otherwise.

        Note: `tokenization_kwargs` only affects the text-only path.
        """
        prompt_text = parsed_content["prompt"]

        inputs: Union[TokenInputs, MultiModalInputs]
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_text,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        else:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(
                prompt=prompt_text,
                prompt_token_ids=prompt_token_ids,
            )

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    async def _process_text_async(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of
        [`_process_text`][vllm.inputs.preprocess.InputPreprocessor._process_text].
        """
        prompt_text = parsed_content["prompt"]

        inputs: Union[TokenInputs, MultiModalInputs]
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = await self._process_multimodal_async(
                prompt_text,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        else:
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt_text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(
                prompt=prompt_text,
                prompt_token_ids=prompt_token_ids,
            )

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: extra tokenizer kwargs; only used for text
          (and plain string) prompts
        * lora_request: this is only valid for decoder prompts
        * return_mm_hashes: whether to return multimodal hashes

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return self._process_embeds(parsed["content"])
        if parsed["type"] == "tokens":
            return self._process_tokens(
                parsed["content"],
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        if parsed["type"] == "text":
            return self._process_text(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        if parsed["type"] == "str":
            # A bare string is handled exactly like a TextPrompt.
            return self._process_text(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        assert_never(parsed)

    async def _prompt_to_llm_inputs_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """
        Async version of
        [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return await self._process_embeds_async(parsed["content"])
        if parsed["type"] == "tokens":
            return await self._process_tokens_async(
                parsed["content"],
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        if parsed["type"] == "text":
            return await self._process_text_async(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        if parsed["type"] == "str":
            # A bare string is handled exactly like a TextPrompt.
            return await self._process_text_async(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        assert_never(parsed)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: Optional[SingletonInputs],
    ) -> EncoderDecoderInputs:
        """
        Combine processed encoder and decoder inputs into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs].

        When `decoder_inputs` is None a default decoder prompt is built
        (for Whisper, a copy of the encoder's token ids). Otherwise its
        `prompt_token_ids` are replaced in place with the version that
        starts with the decoder start token.

        Raises:

        * ValueError: for embedding inputs, or for multi-modal decoder
          inputs
        """
        if (encoder_inputs["type"] == "embeds"
                or decoder_inputs and decoder_inputs["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
                              encoder_inputs)
        decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
                              decoder_inputs)

        if decoder_inputs is None:
            if self.model_config.hf_config.model_type == "whisper":
                # For Whisper models, the text prompt should go to the decoder.
                # If no explicit encoder/decoder inputs, then copy the prompt
                # from the encoder to the decoder. The encoder tokens are later
                # overridden by the audio features.
                dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
            else:
                dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                    None)
            decoder_inputs = token_inputs(dec_token_ids)
        else:
            if "multi_modal_data" in decoder_inputs:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")

            dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                decoder_inputs["prompt_token_ids"])
            decoder_inputs["prompt_token_ids"] = dec_token_ids

        return EncoderDecoderInputs(
            encoder=encoder_inputs,
            decoder=decoder_inputs,
        )

    def _split_enc_dec_mm_inputs(
        self,
        inputs: Union[SingletonInputs, MultiModalEncDecInputs],
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
    ) -> tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only:
        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs

        Arguments:

        * inputs: processed inputs of the (encoder) prompt
        * decoder_inputs_to_override: when given, its prompt and token ids
          are used for the decoder instead of those carried by `inputs`

        Returns:

        * (encoder_inputs, decoder_inputs) tuple

        Raises:

        * ValueError: if either input is an embeddings input
        * RuntimeError: if multi-modal `inputs` lack the encoder prompt
          fields, i.e. a non encoder-decoder processor was registered
        """
        if (inputs["type"] == "embeds" or decoder_inputs_to_override
                and decoder_inputs_to_override["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        inputs = cast(
            Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
            inputs,
        )
        decoder_inputs_to_override = cast(
            Optional[Union[TokenInputs, MultiModalInputs]],
            decoder_inputs_to_override,
        )

        encoder_inputs: SingletonInputs
        decoder_inputs: SingletonInputs

        if inputs["type"] == "multimodal":  # Multimodal data inputs
            if not ("encoder_prompt" in inputs
                    and "encoder_prompt_token_ids" in inputs):
                raise RuntimeError("You should register an encoder-decoder "
                                   "multi-modal processor for encoder-decoder "
                                   "models.")
            inputs = cast(MultiModalEncDecInputs, inputs)

            encoder_inputs = token_inputs(
                prompt=inputs["encoder_prompt"],
                prompt_token_ids=inputs["encoder_prompt_token_ids"],
            )

            # The multi-modal payload always comes from the processed
            # inputs; only the decoder prompt / token ids may be overridden.
            decoder_prompt_inputs = decoder_inputs_to_override or inputs
            decoder_inputs = MultiModalInputs(
                type="multimodal",
                prompt=decoder_prompt_inputs.get("prompt", ""),
                prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
                mm_kwargs=inputs["mm_kwargs"],
                mm_hashes=inputs["mm_hashes"],
                mm_placeholders=inputs["mm_placeholders"],
            )
            if cache_salt := inputs.get("cache_salt"):
                decoder_inputs["cache_salt"] = cache_salt

        elif inputs["type"] == "token":  # Text-only inputs
            # The encoder gets an empty prompt; the text goes to the decoder.
            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
            decoder_inputs = decoder_inputs_to_override or inputs
        else:
            assert_never(inputs)  # type: ignore[arg-type]

        return encoder_inputs, decoder_inputs

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * prompt: an input prompt
        * tokenization_kwargs: extra tokenizer kwargs, applied to both the
          encoder and the explicit decoder prompt (as in the async version)

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_inputs = self._prompt_to_llm_inputs(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
            )
            if (decoder_input := prompt["decoder_prompt"]) is None:
                decoder_inputs = None
            else:
                # Tokenize the decoder prompt with the same kwargs as the
                # async path does, so both entry points agree.
                decoder_inputs = self._prompt_to_llm_inputs(
                    decoder_input,
                    tokenization_kwargs=tokenization_kwargs,
                )
            # For multimodal model, override decoder prompt from processor
            # with explicit decoder prompt.
            if self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))
        else:
            inputs = self._prompt_to_llm_inputs(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(inputs))
            else:
                encoder_inputs = inputs
                decoder_inputs = None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """
        Async version of
        [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].

        For explicit encoder/decoder prompts, the encoder and decoder
        sub-prompts are processed concurrently via `asyncio.gather`.
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_task = self._prompt_to_llm_inputs_async(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
            )

            if (decoder_input := prompt["decoder_prompt"]) is None:
                encoder_inputs = await encoder_task
                decoder_inputs = None
            else:
                decoder_task = self._prompt_to_llm_inputs_async(
                    decoder_input,
                    tokenization_kwargs=tokenization_kwargs,
                )

                encoder_inputs, decoder_inputs = await asyncio.gather(
                    encoder_task, decoder_task)

            # For multimodal model, override decoder prompt from processor
            # with explicit decoder prompt.
            if self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))
        else:
            inputs = await self._prompt_to_llm_inputs_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(inputs))
            else:
                encoder_inputs = inputs
                decoder_inputs = None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_inputs: DecoderOnlyInputs,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> DecoderOnlyInputs:
        """
        Finalize decoder-only inputs by running their prompt token IDs
        through `self._apply_prompt_adapter`.

        Inputs that carry no ``prompt_token_ids`` key are returned untouched.
        Otherwise the ``prompt_token_ids`` entry is overwritten in place and
        the very same object is returned.
        """
        if "prompt_token_ids" not in prompt_inputs:
            # Nothing token-based to adapt.
            return prompt_inputs

        # Narrow the union so mypy accepts the "prompt_token_ids" access.
        token_inputs = cast(Union[TokenInputs, MultiModalInputs],
                            prompt_inputs)
        adapted_ids = self._apply_prompt_adapter(
            token_inputs["prompt_token_ids"],
            prompt_adapter_request=prompt_adapter_request,
        )
        token_inputs["prompt_token_ids"] = adapted_ids
        return token_inputs
785
786
787

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Turn a single input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: extra keyword arguments, forwarded unchanged
          to `_prompt_to_llm_inputs`
        * lora_request: forwarded to `_prompt_to_llm_inputs`
        * prompt_adapter_request: forwarded to
          `_build_decoder_only_llm_inputs`
        * return_mm_hashes: forwarded to `_prompt_to_llm_inputs`

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        # Step 1: convert the raw prompt into model inputs.
        llm_inputs = self._prompt_to_llm_inputs(
            prompt,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            return_mm_hashes=return_mm_hashes,
        )
        # Step 2: apply the (optional) prompt adapter to the token IDs.
        return self._build_decoder_only_llm_inputs(llm_inputs,
                                                   prompt_adapter_request)

    async def _process_decoder_only_prompt_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """
        Async version of
        [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
        """
        # Only the prompt -> inputs conversion is awaited.
        llm_inputs = await self._prompt_to_llm_inputs_async(
            prompt,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            return_mm_hashes=return_mm_hashes,
        )
        # The finalization step is a plain (synchronous) method shared with
        # the non-async path.
        return self._build_decoder_only_llm_inputs(llm_inputs,
                                                   prompt_adapter_request)

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        Encoder-decoder models are handled by
        `_process_encoder_decoder_prompt`; all other models are handled by
        `_process_decoder_only_prompt`.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: extra keyword arguments forwarded to the
          selected processing path
        * lora_request: only used on the decoder-only path
        * prompt_adapter_request: only used on the decoder-only path
        * return_mm_hashes: must be False for encoder-decoder models

        Raises:

        * AssertionError: if ``return_mm_hashes`` is requested for an
          encoder-decoder model
        * ValueError: if an explicit encoder/decoder prompt is passed to a
          decoder-only model
        """
        if self.model_config.is_encoder_decoder:
            # The message is built by implicit string concatenation; a
            # trailing comma here would turn it into a tuple.
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder. Forward tokenization_kwargs
            # so callers' tokenization options are not silently dropped.
            return self._process_encoder_decoder_prompt(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return self._process_decoder_only_prompt(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )

    async def preprocess_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """
        Async version of
        [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].

        Encoder-decoder models are handled by
        `_process_encoder_decoder_prompt_async`; all other models are handled
        by `_process_decoder_only_prompt_async`. ``lora_request`` and
        ``prompt_adapter_request`` are only used on the decoder-only path.

        Raises:

        * AssertionError: if ``return_mm_hashes`` is requested for an
          encoder-decoder model
        * ValueError: if an explicit encoder/decoder prompt is passed to a
          decoder-only model
        """
        if self.model_config.is_encoder_decoder:
            # The message is built by implicit string concatenation; a
            # trailing comma here would turn it into a tuple.
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder. Forward tokenization_kwargs
            # so callers' tokenization options are not silently dropped.
            return await self._process_encoder_decoder_prompt_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )