# SPDX-License-Identifier: Apache-2.0

import asyncio
from collections.abc import Mapping
from typing import Any, Optional, Union, cast

from typing_extensions import assert_never

from vllm import envs
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
                   ProcessorInputs, PromptType, SingletonInputs,
                   SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
                    ParsedTokensPrompt, is_embeds_inputs,
                    is_explicit_encoder_decoder_prompt, parse_singleton_prompt)

logger = init_logger(__name__)


class InputPreprocessor:
    """
    Convert user-facing prompts into the inputs consumed by the engine.

    This covers:

    * tokenizing text prompts, optionally with a LoRA-specific tokenizer;
    * running the multi-modal processor;
    * validating prompt embeddings;
    * prepending prompt-adapter virtual tokens;
    * building separate encoder/decoder inputs for encoder-decoder models.

    The public entry points are :meth:`preprocess` and
    :meth:`preprocess_async`.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[TokenizerGroup],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        super().__init__()

        self.model_config = model_config
        # May be None; text prompts are then rejected by
        # get_tokenizer_group().
        self.tokenizer = tokenizer
        self.mm_registry = mm_registry

    def get_tokenizer_group(self) -> TokenizerGroup:
        """
        Return the tokenizer group used for text prompts.

        Raises:

        * ValueError: if no tokenizer is available.
        """
        if self.tokenizer is None:
            raise ValueError("You cannot pass text prompts when "
                             "`skip_tokenizer_init` is True")

        return self.tokenizer

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Return the BOS token id of the (LoRA-specific) tokenizer.

        If no tokenizer is available, log a warning and return None.
        """
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Return the EOS token id of the (LoRA-specific) tokenizer.

        If no tokenizer is available, log a warning and return None.
        """
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def get_decoder_start_token_id(self) -> Optional[int]:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.

        Falls back on the <BOS> token id when the HF config does not define
        ``decoder_start_token_id``.
        """
        # The availability check must come first: reading
        # ``is_encoder_decoder`` on a missing model config would raise
        # instead of returning None as documented.
        if self.model_config is None or self.model_config.hf_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        if not self.model_config.is_encoder_decoder:
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
        """
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        """

        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[list[int]],
    ) -> list[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list, guaranteed to start with the decoder start
          token id
        """

        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # No decoder prompt input -> use the default decoder prompt;
            # the decoder start token is prepended below if it is missing.
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: list[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> list[int]:
        """
        Prepend one placeholder token id (0) per prompt-adapter virtual
        token. Returns the token ids unchanged when no adapter is requested.
        """
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)

        return prompt_token_ids

    def _prepare_tokenization(
        self,
        prompt: str,
        tokenization_kwargs: Optional[dict[str, Any]],
    ) -> tuple[str, dict[str, Any]]:
        """
        Apply the model-specific adjustments shared by the sync and async
        tokenization paths.

        Returns the (possibly lower-cased) prompt and a fresh kwargs dict.
        The caller's ``tokenization_kwargs`` is never mutated, so the same
        dict can safely be reused for several prompts.
        """
        kwargs = dict(tokenization_kwargs) if tokenization_kwargs else {}

        if self.model_config.hf_config.model_type == "whisper":
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        return prompt, kwargs

    def _tokenize_prompt(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
        tokenizer = self.get_tokenizer_group()
        prompt, kwargs = self._prepare_tokenization(prompt,
                                                    tokenization_kwargs)

        return tokenizer.encode(prompt=prompt,
                                lora_request=lora_request,
                                **kwargs)

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        tokenizer = self.get_tokenizer_group()
        prompt, kwargs = self._prepare_tokenization(prompt,
                                                    tokenization_kwargs)

        return await tokenizer.encode_async(prompt=prompt,
                                            lora_request=lora_request,
                                            **kwargs)

    def _process_multimodal(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.
        """
        # At the moment one model (PrithviGeoSpatialMAE) needs to be
        # initialized without a tokenizer while also using multi-modal input.
        if not self.tokenizer:
            tokenizer = object()  # Dummy
        else:
            tokenizer_group = self.get_tokenizer_group()
            tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)

        mm_processor = self.mm_registry.create_processor(self.model_config,
                                                         tokenizer=tokenizer)

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
                                  return_mm_hashes)

    async def _process_multimodal_async(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Async version of :meth:`_process_multimodal`."""
        # At the moment one model (PrithviGeoSpatialMAE) needs to be
        # initialized without a tokenizer while also using multi-modal input.
        if not self.tokenizer:
            tokenizer = object()  # Dummy
        else:
            tokenizer_group = self.get_tokenizer_group()
            tokenizer = await tokenizer_group.get_lora_tokenizer_async(
                lora_request)

        mm_processor = self.mm_registry.create_processor(self.model_config,
                                                         tokenizer=tokenizer)

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
                                  return_mm_hashes)

    def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
                                                    ParsedTextPrompt,
                                                    ParsedTokensPrompt]):
        """
        Extract the fields of a parsed (non-embeds) singleton prompt.

        Returns a 4-tuple
        ``(prompt_text, prompt_token_ids, token_type_ids, cache_salt)``.
        Fields that the prompt does not carry are None: a plain string only
        sets ``prompt_text``, and a tokens prompt leaves ``prompt_text`` None.
        """
        prompt_text = None
        prompt_token_ids = None
        token_type_ids = None
        cache_salt = None

        if parsed_prompt["type"] == "str":
            prompt_text = parsed_prompt["content"]
        else:
            cache_salt = parsed_prompt["content"].get("cache_salt")
            if parsed_prompt["type"] == "text":
                prompt_text = parsed_prompt["content"]["prompt"]
            elif parsed_prompt["type"] == "tokens":
                prompt_token_ids = parsed_prompt["content"].get(
                    "prompt_token_ids")
                token_type_ids = parsed_prompt["content"].get("token_type_ids")
            else:
                assert_never(parsed_prompt)

        return prompt_text, prompt_token_ids, token_type_ids, cache_salt

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: extra keyword arguments forwarded to the
          tokenizer when a text prompt has to be tokenized
        * lora_request: this is only valid for decoder prompts
        * return_mm_hashes: whether to return multimodal hashes

        Returns:

        * :class:`SingletonInputs` instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return self._process_prompt_embeds(parsed)

        prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
            self._get_prompt_data(parsed)

        # If multimodal data is present, process and return immediately
        if parsed["type"] != "str" and parsed["content"].get(
                "multi_modal_data") is not None:
            inputs = self._process_multimodal(
                prompt_text if prompt_text is not None else prompt_token_ids,
                parsed["content"]["multi_modal_data"],
                parsed["content"].get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
            if cache_salt is not None:
                inputs["cache_salt"] = cache_salt
            return inputs

        if prompt_token_ids is None:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )

        return token_inputs(
            prompt=prompt_text,
            prompt_token_ids=prompt_token_ids,
            token_type_ids=token_type_ids,
            cache_salt=cache_salt,
        )

    async def _prompt_to_llm_inputs_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """Async version of :meth:`_prompt_to_llm_inputs`."""
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return self._process_prompt_embeds(parsed)

        prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
            self._get_prompt_data(parsed)

        # If multimodal data is present, process and return immediately
        if parsed["type"] != "str" and parsed["content"].get(
                "multi_modal_data") is not None:
            inputs = await self._process_multimodal_async(
                prompt_text if prompt_text is not None else prompt_token_ids,
                parsed["content"]["multi_modal_data"],
                parsed["content"].get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
            if cache_salt is not None:
                inputs["cache_salt"] = cache_salt
            return inputs

        if prompt_token_ids is None:
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt_text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )

        return token_inputs(
            prompt=prompt_text,
            prompt_token_ids=prompt_token_ids,
            token_type_ids=token_type_ids,
            cache_salt=cache_salt,
        )

    def _process_prompt_embeds(self,
                               parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
        """
        Validate a prompt-embeddings prompt and wrap it as
        :class:`EmbedsInputs`.

        Raises:

        * ValueError: on the V1 engine (``envs.VLLM_USE_V1``), or if the
          embeddings are not 2-D after squeezing a leading batch dimension
          of size 1.
        """
        if envs.VLLM_USE_V1:
            raise ValueError("prompt_embeds is only available in V0.")

        prompt_embeds_content = parsed["content"]

        prompt_embeds = prompt_embeds_content["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        return embeds_inputs(prompt_embeds=prompt_embeds)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_inputs: Union[TokenInputs, MultiModalInputs],
        decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
    ) -> EncoderDecoderInputs:
        """
        Combine encoder and decoder inputs into :class:`EncoderDecoderInputs`.

        When ``decoder_inputs`` is None, a default decoder prompt is built.
        Whisper copies the encoder token ids; other models use the default
        prompt with the decoder start token. Explicit decoder inputs have
        the decoder start token prepended if missing, and are modified in
        place.

        Raises:

        * ValueError: if the decoder inputs carry ``multi_modal_data``.
        """
        if (encoder_inputs["type"] == "token"
                or encoder_inputs["type"] == "multimodal"):
            pass
        else:
            assert_never(encoder_inputs)  # type: ignore[arg-type]

        # Mypy does not correctly infer that EmbedsInputs is impossible
        assert "prompt_token_ids" in encoder_inputs

        if decoder_inputs is None:
            if self.model_config.hf_config.model_type == "whisper":
                # For Whisper models, the text prompt should go to the decoder.
                # If no explicit encoder/decoder inputs, then copy the prompt
                # from the encoder to the decoder. The encoder tokens are later
                # overridden by the audio features.
                dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
            else:
                dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                    None)
            decoder_inputs = token_inputs(dec_token_ids)
        elif (decoder_inputs["type"] == "token"
              or decoder_inputs["type"] == "multimodal"):
            # Validate before mutating the caller-provided decoder inputs.
            if "multi_modal_data" in decoder_inputs:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")

            dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                decoder_inputs["prompt_token_ids"])
            decoder_inputs["prompt_token_ids"] = dec_token_ids
        else:
            assert_never(decoder_inputs)  # type: ignore[arg-type]

        return EncoderDecoderInputs(
            encoder=encoder_inputs,
            decoder=decoder_inputs,
        )

    def _separate_enc_dec_inputs_from_mm_processor_outputs(
        self,
        inputs: SingletonInputs,
        decoder_inputs_to_override: Optional[Union[TokenInputs,
                                                   MultiModalInputs]] = None,
    ) -> tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only:
        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs.

        For multimodal ``inputs``, the encoder receives the processor's
        ``encoder_prompt``/``encoder_prompt_token_ids``. The decoder keeps
        the multi-modal kwargs, hashes and placeholders (plus any
        ``cache_salt``). Its prompt and token ids are taken from
        ``decoder_inputs_to_override`` when given.

        For token ``inputs``, the encoder receives an empty prompt. The
        decoder receives ``decoder_inputs_to_override`` if truthy, otherwise
        ``inputs``.
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: SingletonInputs
        if inputs["type"] == "multimodal":
            # Multimodal data inputs
            assert ("encoder_prompt" in inputs
                    and "encoder_prompt_token_ids" in inputs)
            inputs = cast(MultiModalEncDecInputs, inputs)
            encoder_inputs = token_inputs(
                prompt=inputs["encoder_prompt"],
                prompt_token_ids=inputs["encoder_prompt_token_ids"],
            )
            if decoder_inputs_to_override is not None:
                decoder_inputs = MultiModalInputs(
                    type="multimodal",
                    prompt=decoder_inputs_to_override.get("prompt", ""),
                    prompt_token_ids=decoder_inputs_to_override[
                        "prompt_token_ids"],
                    mm_kwargs=inputs["mm_kwargs"],
                    mm_hashes=inputs["mm_hashes"],
                    mm_placeholders=inputs["mm_placeholders"],
                )
            else:
                decoder_inputs = MultiModalInputs(
                    type="multimodal",
                    prompt=inputs["prompt"],
                    prompt_token_ids=inputs["prompt_token_ids"],
                    mm_kwargs=inputs["mm_kwargs"],
                    mm_hashes=inputs["mm_hashes"],
                    mm_placeholders=inputs["mm_placeholders"],
                )

            cache_salt = inputs.get("cache_salt")
            if cache_salt is not None:
                decoder_inputs["cache_salt"] = cache_salt

        elif inputs["type"] == "token":
            # Text-only inputs
            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
            decoder_inputs = decoder_inputs_to_override or inputs
        else:
            assert_never(inputs)  # type: ignore[arg-type]
        return encoder_inputs, decoder_inputs

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an :class:`EncoderDecoderInputs` instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:

        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * prompt: an input prompt
        * tokenization_kwargs: extra keyword arguments forwarded to the
          tokenizer for both the encoder and the decoder prompt

        Returns:

        * :class:`EncoderDecoderInputs` instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_inputs = self._prompt_to_llm_inputs(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
            )
            if (decoder_input := prompt["decoder_prompt"]) is None:
                decoder_inputs = None
            else:
                # Same tokenization settings as the async variant.
                decoder_inputs = self._prompt_to_llm_inputs(
                    decoder_input,
                    tokenization_kwargs=tokenization_kwargs,
                )
            # For multimodal model, override decoder prompt from processor
            # with explicit decoder prompt.
            if self.model_config.is_multimodal_model:
                assert decoder_inputs is None or not is_embeds_inputs(
                    decoder_inputs)
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        encoder_inputs, decoder_inputs))
        else:
            inputs = self._prompt_to_llm_inputs(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        inputs))
            else:
                encoder_inputs = inputs
                decoder_inputs = None

        # Mypy does not do type inference well with TypedDicts with Literal
        # values.
        assert not is_embeds_inputs(encoder_inputs)
        assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`.

        For explicit encoder/decoder prompts, the two sub-prompts are
        processed concurrently with :func:`asyncio.gather`.
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_task = self._prompt_to_llm_inputs_async(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
            )

            if (decoder_input := prompt["decoder_prompt"]) is None:
                encoder_inputs = await encoder_task
                decoder_inputs = None
            else:
                decoder_task = self._prompt_to_llm_inputs_async(
                    decoder_input,
                    tokenization_kwargs=tokenization_kwargs,
                )

                encoder_inputs, decoder_inputs = await asyncio.gather(
                    encoder_task, decoder_task)

            # For multimodal model, override decoder prompt from processor
            # with explicit decoder prompt.
            if self.model_config.is_multimodal_model:
                assert decoder_inputs is None or not is_embeds_inputs(
                    decoder_inputs)
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        encoder_inputs, decoder_inputs))
        else:
            inputs = await self._prompt_to_llm_inputs_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        inputs))
            else:
                encoder_inputs = inputs
                decoder_inputs = None

        # Mypy does not do type inference well with TypedDicts with Literal
        # values.
        assert not is_embeds_inputs(encoder_inputs)
        assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_inputs: DecoderOnlyInputs,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> DecoderOnlyInputs:
        """
        Finalize decoder-only inputs.

        Token and multimodal inputs get prompt-adapter virtual tokens
        prepended in place. Embeds inputs are returned unchanged.
        """
        if (prompt_inputs["type"] == "token"
                or prompt_inputs["type"] == "multimodal"):
            # Mypy does not do type inference well with typedicts and Literal
            # values
            assert not is_embeds_inputs(prompt_inputs)
            prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
                prompt_inputs["prompt_token_ids"],
                prompt_adapter_request=prompt_adapter_request,
            )
        elif prompt_inputs["type"] == "embeds":
            pass
        else:
            assert_never(prompt_inputs)  # type: ignore[arg-type]

        return prompt_inputs

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into an :class:`DecoderOnlyInputs` instance.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: extra keyword arguments for the tokenizer
        * lora_request
        * prompt_adapter_request
        * return_mm_hashes

        Returns:

        * :class:`DecoderOnlyInputs` instance
        """

        prompt_comps = self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            return_mm_hashes=return_mm_hashes,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _process_decoder_only_prompt_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._prompt_to_llm_inputs_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            return_mm_hashes=return_mm_hashes,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        Encoder-decoder models are routed to
        :meth:`_process_encoder_decoder_prompt`, and all others to
        :meth:`_process_decoder_only_prompt`. ``tokenization_kwargs`` is
        forwarded on both paths.

        Raises:

        * ValueError: if an explicit encoder/decoder prompt is given to a
          decoder-only model.
        """
        if self.model_config.is_encoder_decoder:
            # A trailing comma here previously turned the message into a
            # tuple; use implicit string concatenation instead.
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            return self._process_encoder_decoder_prompt(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return self._process_decoder_only_prompt(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )

    async def preprocess_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """Async version of :meth:`preprocess`."""
        if self.model_config.is_encoder_decoder:
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            return await self._process_encoder_decoder_prompt_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )