preprocess.py 34.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Mapping
6
from typing import Any, Optional, Union, cast
7
8
9
10
11
12

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
13
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
14
15
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs)
16
from vllm.prompt_adapter.request import PromptAdapterRequest
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
19

20
21
22
23
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                   EncoderDecoderInputs, ProcessorInputs, PromptType,
                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
                   TokensPrompt, embeds_inputs, token_inputs)
24
25
26
27
28
29
30
31
32
33
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

# Module-level logger; InputPreprocessor uses it for tokenizer-missing and
# decoder-start-token warnings.
logger = init_logger(__name__)


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[TokenizerGroup],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Keep references to the collaborators used for prompt preprocessing.

        Args:
            model_config: Configuration of the model being served.
            tokenizer: Tokenizer group, or ``None`` when no tokenizer was
                initialized (see ``get_tokenizer_group``).
            mm_registry: Registry from which multi-modal processors are
                created.
        """
        super().__init__()

        # Nothing is computed eagerly; the helpers below read these lazily.
        self.model_config = model_config
        self.tokenizer = tokenizer
        self.mm_registry = mm_registry

    def get_tokenizer_group(self) -> TokenizerGroup:
        """
        Return the configured tokenizer group.

        Raises:
            ValueError: If the preprocessor has no tokenizer, which is the
                case when ``skip_tokenizer_init`` is enabled.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            return tokenizer_group

        raise ValueError("You cannot pass text prompts when "
                         "`skip_tokenizer_init` is True")

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Return the BOS token id of the tokenizer selected for
        ``lora_request``.

        Logs a warning and returns ``None`` when no tokenizer is available.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            lora_tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return lora_tokenizer.bos_token_id

        logger.warning("Using None for BOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Return the EOS token id of the tokenizer selected for
        ``lora_request``.

        Logs a warning and returns ``None`` when no tokenizer is available.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            lora_tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return lora_tokenizer.eos_token_id

        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_decoder_start_token_id(self) -> Optional[int]:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model.

        Returns:
            The ``decoder_start_token_id`` attribute of the HF config, or
            the BOS token id when that attribute is missing or ``None``.
            Returns ``None`` for non-encoder/decoder models or if the model
            config is unavailable.
        """
        model_config = self.model_config

        # Check for a missing config before dereferencing it: reading
        # ``is_encoder_decoder`` first would raise AttributeError on a
        # ``None`` config instead of emitting this warning.
        if model_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        if not model_config.is_encoder_decoder:
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
            return None

        if model_config.hf_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        dec_start_token_id = getattr(model_config.hf_config,
                                     "decoder_start_token_id", None)
        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
        """
        Produce the decoder prompt an encoder/decoder model uses when the
        caller supplies only an encoder prompt.

        Encoder/decoder models consume the decoder prompt differently, so
        this helper is the extension point for model-specific defaults.
        Absent a special case, it approximates the HuggingFace (HF)
        ``GenerationMixin`` behavior for a ``None`` decoder prompt, which
        forces the first decoded token to be <BOS> via a logits processor.
        Here that becomes a default decoder prompt consisting of <BOS>.

        Returns:
            The default decoder prompt token ids (``[bos_token_id]``).
        """
        bos_id = self.get_bos_token_id()
        # A default decoder prompt is impossible without a BOS token.
        assert bos_id is not None
        default_prompt = [bos_id]
        return default_prompt

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[list[int]],
    ) -> list[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

        Arguments:

        * decoder_input_ids: input token ids to preprocess, or ``None`` to
          use the default decoder prompt

        Returns:

        * Token list guaranteed to begin with the decoder start token
        """
        start_id = self.get_decoder_start_token_id()
        assert start_id is not None

        # No decoder prompt given: fall back to the model's default one.
        token_ids = (self._get_default_enc_dec_decoder_prompt()
                     if decoder_input_ids is None else decoder_input_ids)

        already_prefixed = len(token_ids) > 0 and token_ids[0] == start_id
        if not already_prefixed:
            token_ids = [start_id] + token_ids

        return token_ids

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: list[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> list[int]:
        """
        Prepend one placeholder token id (``0``) per virtual token of the
        prompt adapter, if an adapter request is given.

        Returns ``prompt_token_ids`` unchanged when there is no request.
        """
        if not prompt_adapter_request:
            return prompt_token_ids

        num_virtual = prompt_adapter_request.prompt_adapter_num_virtual_tokens
        return [0] * num_virtual + prompt_token_ids

    def _get_tokenization_kw(
        self,
        overrides: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Build the keyword arguments passed to the tokenizer.

        Model-specific defaults come first; entries in ``overrides``
        replace them.
        """
        tokenizer_kwargs: dict[str, Any] = {}

        is_whisper = self.model_config.hf_config.model_type == "whisper"
        if is_whisper:
            # Whisper users supply the task/language special tokens
            # themselves; letting the tokenizer add them (including an
            # appended EOS) would disrupt generation.
            tokenizer_kwargs["add_special_tokens"] = False

        tokenizer_kwargs.update(overrides or {})
        return tokenizer_kwargs

    def _tokenize_prompt(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.

        The prompt is lower-cased first when the encoder config enables
        ``do_lower_case``. In ``cpm`` tokenizer mode the tokenizer's
        ``bos_id`` is prepended to a plain ``encode(prompt)`` call, and
        ``lora_request`` / ``tokenization_kwargs`` are not forwarded.

        Raises:
            ValueError: If no tokenizer is configured.
        """
        tokenizer = self.get_tokenizer_group()
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        encoder_config = self.model_config.encoder_config
        should_lower = (encoder_config
                        and encoder_config.get("do_lower_case", False))
        if should_lower:
            prompt = prompt.lower()

        if self.model_config.tokenizer_mode != "cpm":
            return tokenizer.encode(prompt=prompt,
                                    lora_request=lora_request,
                                    **tokenization_kwargs)

        # NOTE(review): the CPM path adds BOS manually and bypasses
        # lora_request / tokenization_kwargs — confirm this is intended.
        return [tokenizer.bos_id] + tokenizer.encode(prompt)

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """
        Async version of
        [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].

        Applies the same prompt handling as the synchronous version
        (encoder-config ``do_lower_case`` and the ``cpm`` tokenizer mode),
        so both paths produce identical token IDs for the same prompt.

        Raises:
            ValueError: If no tokenizer is configured.
        """
        tokenizer = self.get_tokenizer_group()
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        # Mirror `_tokenize_prompt`: without this, the async engine skipped
        # lower-casing and the CPM BOS prefix, diverging from the sync path.
        encoder_config = self.model_config.encoder_config
        if encoder_config and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        if self.model_config.tokenizer_mode == "cpm":
            return [tokenizer.bos_id] + tokenizer.encode(prompt)

        return await tokenizer.encode_async(prompt=prompt,
                                            lora_request=lora_request,
                                            **tokenization_kwargs)

    def _get_mm_tokenizer(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        """
        Return the tokenizer to hand to the multi-modal processor.

        PrithviGeoSpatialMAE is initialized without a tokenizer yet still
        takes multi-modal input, so a dummy placeholder object is returned
        when no tokenizer is configured.
        """
        if self.tokenizer:
            return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
        return cast(AnyTokenizer, object())  # Dummy

    async def _get_mm_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        """
        Async version of `_get_mm_tokenizer`.

        Returns a dummy placeholder when no tokenizer is configured (needed
        for PrithviGeoSpatialMAE, which has multi-modal input but no
        tokenizer).
        """
        if self.tokenizer:
            tokenizer_group = self.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)
        return cast(AnyTokenizer, object())  # Dummy

    def _process_multimodal(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        A missing ``mm_processor_kwargs`` is treated as an empty mapping.
        """
        tokenizer = self._get_mm_tokenizer(lora_request)
        processor = self.mm_registry.create_processor(self.model_config,
                                                      tokenizer=tokenizer)

        hf_mm_kwargs = ({} if mm_processor_kwargs is None else
                        mm_processor_kwargs)

        return processor.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs=hf_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            return_mm_hashes=return_mm_hashes,
        )

    async def _process_multimodal_async(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Async version of
        [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].

        Only the tokenizer lookup is awaited; the processor itself runs
        synchronously.
        """
        tokenizer = await self._get_mm_tokenizer_async(lora_request)
        processor = self.mm_registry.create_processor(self.model_config,
                                                      tokenizer=tokenizer)

        hf_mm_kwargs = ({} if mm_processor_kwargs is None else
                        mm_processor_kwargs)

        return processor.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs=hf_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            return_mm_hashes=return_mm_hashes,
        )

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Validate a prompt given as embeddings and wrap it as embeds inputs.

        A leading batch dimension of size 1 is squeezed away; any other
        rank after that is rejected.

        Raises:
            ValueError: If prompt embeddings are disabled in the model
                config, or if the embeddings are not 2-D after squeezing.
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        embeds = parsed_content["prompt_embeds"]

        # Expected shape is (seq_len, hidden_size). A (1, seq_len,
        # hidden_size) batch of one is unambiguous, so drop that dimension.
        if embeds.ndim == 3:
            embeds = embeds.squeeze(dim=0)

        if embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        salt = parsed_content.get("cache_salt")
        return embeds_inputs(prompt_embeds=embeds, cache_salt=salt)

    async def _process_embeds_async(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Async version of `_process_embeds`.

        Embedding validation does no awaitable work, so this delegates to
        the synchronous implementation.
        """
        embeds_result = self._process_embeds(parsed_content)
        return embeds_result

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Turn a pre-tokenized prompt into model inputs.

        With ``multi_modal_data`` present, the prompt goes through the
        multi-modal processor (``token_type_ids`` is not forwarded there).
        Otherwise the token ids and any ``token_type_ids`` become plain
        token inputs. A truthy ``cache_salt`` is copied onto the result.
        """
        token_ids = parsed_content["prompt_token_ids"]
        type_ids = parsed_content.get("token_type_ids")
        mm_data = parsed_content.get("multi_modal_data")

        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            inputs = token_inputs(
                prompt_token_ids=token_ids,
                token_type_ids=type_ids,
            )
        else:
            inputs = self._process_multimodal(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    async def _process_tokens_async(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of `_process_tokens`.

        Only the multi-modal branch awaits; plain token prompts are wrapped
        directly. A truthy ``cache_salt`` is copied onto the result.
        """
        token_ids = parsed_content["prompt_token_ids"]
        type_ids = parsed_content.get("token_type_ids")
        mm_data = parsed_content.get("multi_modal_data")

        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            inputs = token_inputs(
                prompt_token_ids=token_ids,
                token_type_ids=type_ids,
            )
        else:
            inputs = await self._process_multimodal_async(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Turn a text prompt into model inputs.

        With ``multi_modal_data`` present, the text goes through the
        multi-modal processor. Otherwise it is tokenized and wrapped together
        with the original text. A truthy ``cache_salt`` is copied onto the
        result.
        """
        text = parsed_content["prompt"]
        mm_data = parsed_content.get("multi_modal_data")

        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = self._tokenize_prompt(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            inputs = self._process_multimodal(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    async def _process_text_async(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of `_process_text`.

        Tokenization or multi-modal processing is awaited; a truthy
        ``cache_salt`` is copied onto the result.
        """
        text = parsed_content["prompt"]
        mm_data = parsed_content.get("multi_modal_data")

        inputs: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = await self._tokenize_prompt_async(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            inputs = await self._process_multimodal_async(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            inputs["cache_salt"] = salt

        return inputs

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: extra tokenizer arguments; forwarded only for
          text (``"text"`` / ``"str"``) prompts
        * lora_request: this is only valid for decoder prompts
        * return_mm_hashes: whether to return multimodal hashes

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        parsed_prompt = parse_singleton_prompt(prompt)

        if parsed_prompt["type"] == "embeds":
            return self._process_embeds(parsed_prompt["content"])

        if parsed_prompt["type"] == "tokens":
            return self._process_tokens(
                parsed_prompt["content"],
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        if parsed_prompt["type"] == "str":
            # A bare string is handled as a text prompt with no extras.
            return self._process_text(
                TextPrompt(prompt=parsed_prompt["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        if parsed_prompt["type"] == "text":
            return self._process_text(
                parsed_prompt["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        assert_never(parsed_prompt)

    async def _prompt_to_llm_inputs_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """
        Async version of
        [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].

        As in the sync version, ``tokenization_kwargs`` is forwarded only
        for text prompts.
        """
        parsed_prompt = parse_singleton_prompt(prompt)

        if parsed_prompt["type"] == "embeds":
            return await self._process_embeds_async(parsed_prompt["content"])

        if parsed_prompt["type"] == "tokens":
            return await self._process_tokens_async(
                parsed_prompt["content"],
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        if parsed_prompt["type"] == "str":
            # A bare string is handled as a text prompt with no extras.
            return await self._process_text_async(
                TextPrompt(prompt=parsed_prompt["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        if parsed_prompt["type"] == "text":
            return await self._process_text_async(
                parsed_prompt["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        assert_never(parsed_prompt)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: Optional[SingletonInputs],
    ) -> EncoderDecoderInputs:
        """
        Combine processed encoder and decoder inputs into
        ``EncoderDecoderInputs``.

        * Embedding inputs are rejected on either side.
        * Without decoder inputs, Whisper copies the encoder prompt token
          ids to the decoder; other models use the default decoder prompt,
          prefixed with the decoder start token if needed.
        * With decoder inputs, multi-modal data is rejected and the decoder
          token ids are normalized to begin with the decoder start token
          (the decoder inputs dict is updated in place).

        Raises:
            ValueError: For embedding inputs or multi-modal decoder inputs.
        """
        if encoder_inputs["type"] == "embeds" or (
                decoder_inputs and decoder_inputs["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
                              encoder_inputs)
        decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
                              decoder_inputs)

        if decoder_inputs is not None:
            if "multi_modal_data" in decoder_inputs:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")
            decoder_inputs["prompt_token_ids"] = (
                self._prepare_decoder_input_ids_for_generation(
                    decoder_inputs["prompt_token_ids"]))
        else:
            if self.model_config.hf_config.model_type == "whisper":
                # Whisper's text prompt belongs to the decoder: copy it over
                # from the encoder, whose tokens are later replaced by the
                # audio features.
                default_ids = encoder_inputs["prompt_token_ids"].copy()
            else:
                default_ids = (
                    self._prepare_decoder_input_ids_for_generation(None))
            decoder_inputs = token_inputs(default_ids)

        return EncoderDecoderInputs(
            encoder=encoder_inputs,
            decoder=decoder_inputs,
        )

    def _split_enc_dec_mm_inputs(
        self,
        inputs: Union[SingletonInputs, MultiModalEncDecInputs],
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
    ) -> tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only:
        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs

        * Multi-modal inputs: the encoder side is built from the
          ``encoder_prompt`` fields; the decoder side keeps the multi-modal
          metadata, taking its prompt from ``decoder_inputs_to_override``
          when provided.
        * Token inputs: the encoder side is empty and the decoder side is
          ``decoder_inputs_to_override`` or the inputs themselves.

        Raises:
            ValueError: For embedding inputs.
            RuntimeError: If multi-modal inputs lack encoder prompt fields
                (no encoder-decoder multi-modal processor registered).
        """
        if inputs["type"] == "embeds" or (
                decoder_inputs_to_override
                and decoder_inputs_to_override["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        inputs = cast(
            Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
            inputs,
        )
        decoder_inputs_to_override = cast(
            Optional[Union[TokenInputs, MultiModalInputs]],
            decoder_inputs_to_override,
        )

        encoder_inputs: SingletonInputs
        decoder_inputs: SingletonInputs

        if inputs["type"] == "token":  # Text-only inputs
            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
            decoder_inputs = decoder_inputs_to_override or inputs
        elif inputs["type"] == "multimodal":  # Multimodal data inputs
            has_encoder_prompt = ("encoder_prompt" in inputs
                                  and "encoder_prompt_token_ids" in inputs)
            if not has_encoder_prompt:
                raise RuntimeError("You should register an encoder-decoder "
                                   "multi-modal processor for encoder-decoder "
                                   "models.")
            mm_inputs = cast(MultiModalEncDecInputs, inputs)

            encoder_inputs = token_inputs(
                prompt=mm_inputs["encoder_prompt"],
                prompt_token_ids=mm_inputs["encoder_prompt_token_ids"],
            )

            # An explicit decoder prompt replaces the processor's one, but
            # the multi-modal metadata always comes from the processor.
            decoder_source = decoder_inputs_to_override or mm_inputs
            decoder_inputs = MultiModalInputs(
                type="multimodal",
                prompt=decoder_source.get("prompt", ""),
                prompt_token_ids=decoder_source["prompt_token_ids"],
                mm_kwargs=mm_inputs["mm_kwargs"],
                mm_hashes=mm_inputs["mm_hashes"],
                mm_placeholders=mm_inputs["mm_placeholders"],
            )
            salt = mm_inputs.get("cache_salt")
            if salt:
                decoder_inputs["cache_salt"] = salt
        else:
            assert_never(inputs)  # type: ignore[arg-type]

        return encoder_inputs, decoder_inputs

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * prompt: an input prompt
        * tokenization_kwargs: extra tokenizer arguments, applied to the
          encoder prompt and, if present, the explicit decoder prompt

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_inputs = self._prompt_to_llm_inputs(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
            )
            if (decoder_input := prompt["decoder_prompt"]) is None:
                decoder_inputs = None
            else:
                # Forward tokenization_kwargs to the decoder prompt as well,
                # matching `_process_encoder_decoder_prompt_async`; otherwise
                # sync and async callers get different decoder token ids.
                decoder_inputs = self._prompt_to_llm_inputs(
                    decoder_input,
                    tokenization_kwargs=tokenization_kwargs,
                )
            # For multimodal model, override decoder prompt from processor
            # with explicit decoder prompt.
            if self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))
        else:
            inputs = self._prompt_to_llm_inputs(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(inputs))
            else:
                encoder_inputs = inputs
                decoder_inputs = None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """
        Async version of
        [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].

        For explicit encoder/decoder prompts with a decoder prompt, the two
        sub-prompts are processed concurrently with ``asyncio.gather``;
        ``tokenization_kwargs`` is forwarded to both.
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_coro = self._prompt_to_llm_inputs_async(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
            )

            decoder_prompt = prompt["decoder_prompt"]
            if decoder_prompt is not None:
                decoder_coro = self._prompt_to_llm_inputs_async(
                    decoder_prompt,
                    tokenization_kwargs=tokenization_kwargs,
                )
                encoder_inputs, decoder_inputs = await asyncio.gather(
                    encoder_coro, decoder_coro)
            else:
                encoder_inputs = await encoder_coro
                decoder_inputs = None

            # Multimodal models: an explicit decoder prompt takes precedence
            # over the decoder prompt produced by the processor.
            if self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))
        else:
            single_inputs = await self._prompt_to_llm_inputs_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if not self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = single_inputs, None
            else:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(single_inputs))

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_inputs: DecoderOnlyInputs,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> DecoderOnlyInputs:
        """
        Finalize decoder-only inputs by applying the prompt adapter.

        When ``prompt_inputs`` has a ``"prompt_token_ids"`` entry, that entry
        is overwritten in place with the result of
        ``self._apply_prompt_adapter``. Inputs without that key are returned
        unchanged. The same object that was passed in is always returned.
        """
        if "prompt_token_ids" not in prompt_inputs:
            return prompt_inputs

        # cast() is a runtime no-op; it only narrows the type for mypy.
        token_inputs = cast(Union[TokenInputs, MultiModalInputs],
                            prompt_inputs)
        adapted_ids = self._apply_prompt_adapter(
            token_inputs["prompt_token_ids"],
            prompt_adapter_request=prompt_adapter_request,
        )
        token_inputs["prompt_token_ids"] = adapted_ids
        return token_inputs
803
804
805

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        The prompt is first converted with ``_prompt_to_llm_inputs`` and the
        result is then finalized by ``_build_decoder_only_llm_inputs``.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: forwarded to ``_prompt_to_llm_inputs``
        * lora_request: forwarded to ``_prompt_to_llm_inputs``
        * prompt_adapter_request: forwarded to
          ``_build_decoder_only_llm_inputs``
        * return_mm_hashes: forwarded to ``_prompt_to_llm_inputs``

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        converted = self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            return_mm_hashes=return_mm_hashes,
        )
        return self._build_decoder_only_llm_inputs(
            converted, prompt_adapter_request=prompt_adapter_request)

    async def _process_decoder_only_prompt_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """
        Async version of
        [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].

        Awaits ``_prompt_to_llm_inputs_async`` (which receives
        ``tokenization_kwargs``, ``lora_request`` and ``return_mm_hashes``)
        and passes the result, together with ``prompt_adapter_request``, to
        ``_build_decoder_only_llm_inputs``.
        """
        converted = await self._prompt_to_llm_inputs_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            return_mm_hashes=return_mm_hashes,
        )
        return self._build_decoder_only_llm_inputs(
            converted, prompt_adapter_request=prompt_adapter_request)

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """Preprocess the input prompt.

        Encoder-decoder models are routed to
        ``_process_encoder_decoder_prompt`` (only ``prompt`` and
        ``tokenization_kwargs`` are forwarded on that path); all other models
        go through ``_process_decoder_only_prompt``.

        Raises:
            AssertionError: if ``return_mm_hashes`` is requested for an
                encoder-decoder model.
            ValueError: if an explicit encoder-decoder prompt is passed to a
                decoder-only model.
        """
        if self.model_config.is_encoder_decoder:
            # The message must be a single string: a stray comma here would
            # turn it into a tuple and render as "('...', '...')".
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            return self._process_encoder_decoder_prompt(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return self._process_decoder_only_prompt(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )

    async def preprocess_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """
        Async version of
        [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].

        Encoder-decoder models are routed to
        ``_process_encoder_decoder_prompt_async`` (only ``prompt`` and
        ``tokenization_kwargs`` are forwarded on that path); all other models
        go through ``_process_decoder_only_prompt_async``.

        Raises:
            AssertionError: if ``return_mm_hashes`` is requested for an
                encoder-decoder model.
            ValueError: if an explicit encoder-decoder prompt is passed to a
                decoder-only model.
        """
        if self.model_config.is_encoder_decoder:
            # The message must be a single string: a stray comma here would
            # turn it into a tuple and render as "('...', '...')".
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder. tokenization_kwargs must be
            # forwarded so caller-supplied tokenization settings are honored
            # on this path as well.
            return await self._process_encoder_decoder_prompt_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )