preprocess.py 32.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
from collections.abc import Mapping
5
from typing import Any, Optional, Union, cast
6
7
8
9
10
11

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
12
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
13
14
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs)
15
from vllm.prompt_adapter.request import PromptAdapterRequest
16
from vllm.transformers_utils.tokenizer import AnyTokenizer
17
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
18

19
20
21
22
23
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                   EncoderDecoderInputs, ProcessorInputs, PromptType,
                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
                   TokensPrompt, embeds_inputs, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
24
25
26
27
28
29
30
31
32

# Module-level logger used by InputPreprocessor for warnings (created via
# vLLM's `init_logger`).
logger = init_logger(__name__)


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[TokenizerGroup],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Store the collaborators used to turn prompts into model inputs.

        Arguments:

        * model_config: configuration of the model being served
        * tokenizer: tokenizer group, or None when tokenizer
          initialization was skipped
        * mm_registry: registry used to create multi-modal processors
        """
        super().__init__()

        self.mm_registry = mm_registry
        self.tokenizer = tokenizer
        self.model_config = model_config

    def get_tokenizer_group(self) -> TokenizerGroup:
        """
        Return the tokenizer group; text prompts cannot be handled without it.

        Raises:

        * ValueError: if no tokenizer is available
          (e.g. ``skip_tokenizer_init`` was set)
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            return tokenizer_group

        raise ValueError("You cannot pass text prompts when "
                         "`skip_tokenizer_init` is True")

    def get_bos_token_id(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> Optional[int]:
        """
        Return the BOS token id of the tokenizer selected for
        ``lora_request``, or None (after logging a warning) when no
        tokenizer is initialized.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            lora_tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return lora_tokenizer.bos_token_id

        logger.warning("Using None for BOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_eos_token_id(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> Optional[int]:
        """
        Return the EOS token id of the tokenizer selected for
        ``lora_request``, or None (after logging a warning) when no
        tokenizer is initialized.
        """
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            lora_tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
            return lora_tokenizer.eos_token_id

        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_decoder_start_token_id(self) -> Optional[int]:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model.

        Returns:

        * None (with a one-time warning) if the model config is
          unavailable or the model is not an encoder/decoder model
        * otherwise ``hf_config.decoder_start_token_id``, falling back to
          the tokenizer's BOS token id (which may itself be None) when the
          HF config does not define it
        """
        model_config = self.model_config

        # Guard against a missing config before reading any attribute
        # from it; otherwise the "not available" fallback is unreachable
        # and an AttributeError escapes instead.
        if model_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        if not model_config.is_encoder_decoder:
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
            return None

        if model_config.hf_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        dec_start_token_id = getattr(model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
        """
        Specifically for encoder/decoder models: build the decoder prompt
        used when the user supplies only an encoder prompt.

        Without a special case, this mirrors the HuggingFace (HF)
        GenerationMixin handling of a None decoder prompt, which uses a
        logits processor to force <BOS> as the first decoded token. That
        behavior is approximated here by making <BOS> the entire default
        decoder prompt.

        The logic lives in a dedicated helper because future models may
        need different or more complex default decoder prompts.

        Returns:

        * prompt_token_ids: a single-element list holding the BOS token id
        """
        bos_id = self.get_bos_token_id()
        assert bos_id is not None
        default_prompt = [bos_id]
        return default_prompt

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[list[int]],
    ) -> list[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

        Arguments:

        * decoder_input_ids: input token ids to preprocess, or None to use
          the default decoder prompt

        Returns:

        * Token list whose first element is the decoder start token id
          (the input list itself if it already starts with it, otherwise a
          new list with the start token prepended)
        """
        start_id = self.get_decoder_start_token_id()
        assert start_id is not None

        # No decoder prompt given -> fall back on the default decoder prompt.
        token_ids = (self._get_default_enc_dec_decoder_prompt()
                     if decoder_input_ids is None else decoder_input_ids)

        # An empty list never matches, so it always gets the start token.
        if token_ids[:1] == [start_id]:
            return token_ids
        return [start_id, *token_ids]

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: list[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> list[int]:
        """
        Make room for prompt-adapter virtual tokens.

        With a (truthy) prompt adapter request, the prompt is left-padded
        with ``prompt_adapter_num_virtual_tokens`` zeros; otherwise the
        token ids are returned unchanged.
        """
        if not prompt_adapter_request:
            return prompt_token_ids

        num_virtual = prompt_adapter_request.prompt_adapter_num_virtual_tokens
        return [0] * num_virtual + prompt_token_ids

    def _get_tokenization_kw(
        self,
        overrides: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Build the keyword arguments forwarded to the tokenizer's encode call.

        Model-specific defaults are applied first and entries in
        ``overrides`` take precedence. A fresh dict is returned each call.
        """
        is_whisper = self.model_config.hf_config.model_type == "whisper"

        # For Whisper, special tokens should be provided by the user based
        # on the task and language of their request. Also needed to avoid
        # appending an EOS token to the prompt which disrupts generation.
        kwargs: dict[str, Any] = ({"add_special_tokens": False}
                                  if is_whisper else {})
        kwargs.update(overrides or {})
        return kwargs

    def _tokenize_prompt(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.

        The prompt is lower-cased first when the model's encoder config
        sets ``do_lower_case``.
        """
        tokenizer_group = self.get_tokenizer_group()
        encode_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        encoder_config = self.model_config.encoder_config
        do_lower_case = False
        if encoder_config:
            do_lower_case = encoder_config.get("do_lower_case", False)

        text = prompt.lower() if do_lower_case else prompt
        return tokenizer_group.encode(
            prompt=text,
            lora_request=lora_request,
            **encode_kwargs,
        )

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        lora_request: Optional[LoRARequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[int]:
        """
        Async version of {meth}`_tokenize_prompt`.

        Applies the same preprocessing as the sync method, including
        lower-casing the prompt when the encoder config sets
        ``do_lower_case``, so both paths yield identical token IDs.
        """
        tokenizer = self.get_tokenizer_group()
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        # Keep parity with `_tokenize_prompt`: without this, a model whose
        # encoder config requests lower-casing would be tokenized
        # differently on the async path than on the sync path.
        encoder_config = self.model_config.encoder_config
        if encoder_config and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        return await tokenizer.encode_async(prompt=prompt,
                                            lora_request=lora_request,
                                            **tokenization_kwargs)

    def _get_mm_tokenizer(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        """
        Return the tokenizer handed to the multi-modal processor.

        PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        while still using multi-modal input, so when no tokenizer is set a
        dummy placeholder object is returned instead.
        """
        if self.tokenizer:
            tokenizer_group = self.get_tokenizer_group()
            return tokenizer_group.get_lora_tokenizer(lora_request)

        return cast(AnyTokenizer, object())  # Dummy

    async def _get_mm_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest],
    ) -> AnyTokenizer:
        """
        Async version of {meth}`_get_mm_tokenizer`.

        PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        while still using multi-modal input, so when no tokenizer is set a
        dummy placeholder object is returned instead.
        """
        if self.tokenizer:
            tokenizer_group = self.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)

        return cast(AnyTokenizer, object())  # Dummy

    def _process_multimodal(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        Arguments:

        * prompt: text prompt or already-tokenized prompt token IDs
        * mm_data: multi-modal data attached to the prompt
        * mm_processor_kwargs: extra processor options (None means none)
        * lora_request: selects the LoRA-specific tokenizer, if any
        * return_mm_hashes: forwarded to the processor's ``apply``
        """
        mm_tokenizer = self._get_mm_tokenizer(lora_request)
        processor = self.mm_registry.create_processor(
            self.model_config,
            tokenizer=mm_tokenizer,
        )

        processor_kwargs = ({} if mm_processor_kwargs is None else
                            mm_processor_kwargs)
        return processor.apply(prompt, mm_data, processor_kwargs,
                               return_mm_hashes)

    async def _process_multimodal_async(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Async version of {meth}`_process_multimodal`.

        Only the tokenizer lookup is awaited; the processor itself is
        applied synchronously.
        """
        mm_tokenizer = await self._get_mm_tokenizer_async(lora_request)
        processor = self.mm_registry.create_processor(
            self.model_config,
            tokenizer=mm_tokenizer,
        )

        processor_kwargs = ({} if mm_processor_kwargs is None else
                            mm_processor_kwargs)
        return processor.apply(prompt, mm_data, processor_kwargs,
                               return_mm_hashes)

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Validate a prompt-embeddings prompt and wrap it as EmbedsInputs.

        A 3-D tensor has its leading dimension squeezed, which removes it
        only when its size is 1. Anything that is not 2-D afterwards is
        rejected.

        Raises:

        * ValueError: if prompt embeddings are disabled in the model config,
          or the embeddings are not of shape (seq_len, hidden_size)
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        embeds = parsed_content["prompt_embeds"]

        # A batch of one, i.e. (1, seq_len, hidden_size), is unambiguous:
        # drop the batch dimension to get (seq_len, hidden_size).
        if embeds.ndim == 3:
            embeds = embeds.squeeze(dim=0)

        if embeds.ndim == 2:
            return embeds_inputs(prompt_embeds=embeds,
                                 cache_salt=parsed_content.get("cache_salt"))

        raise ValueError(
            "prompt_embeds must be of shape (seq_len, hidden_size).")

    async def _process_embeds_async(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Async version of {meth}`_process_embeds`.

        Nothing here needs awaiting, so this delegates to the sync method.
        """
        embeds = self._process_embeds(parsed_content)
        return embeds

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Turn an already-tokenized prompt into token or multi-modal inputs.

        Prompts with ``multi_modal_data`` go through the multi-modal
        processor (``token_type_ids`` is not forwarded on that path).
        Otherwise the token ids and ``token_type_ids`` are wrapped directly.
        A truthy ``cache_salt`` is copied onto the result.
        """
        token_ids = parsed_content["prompt_token_ids"]
        mm_data = parsed_content.get("multi_modal_data")

        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            result = token_inputs(
                prompt_token_ids=token_ids,
                token_type_ids=parsed_content.get("token_type_ids"),
            )
        else:
            result = self._process_multimodal(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt

        return result

    async def _process_tokens_async(
        self,
        parsed_content: TokensPrompt,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of {meth}`_process_tokens`.
        """
        token_ids = parsed_content["prompt_token_ids"]
        mm_data = parsed_content.get("multi_modal_data")

        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            result = token_inputs(
                prompt_token_ids=token_ids,
                token_type_ids=parsed_content.get("token_type_ids"),
            )
        else:
            result = await self._process_multimodal_async(
                token_ids,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt

        return result

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Turn a text prompt into token or multi-modal inputs.

        Prompts with ``multi_modal_data`` go through the multi-modal
        processor. Otherwise the text is tokenized; ``tokenization_kwargs``
        is only used on that plain-text path. A truthy ``cache_salt`` is
        copied onto the result.
        """
        text = parsed_content["prompt"]
        mm_data = parsed_content.get("multi_modal_data")

        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = self._tokenize_prompt(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            result = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            result = self._process_multimodal(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt

        return result

    async def _process_text_async(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """
        Async version of {meth}`_process_text`.
        """
        text = parsed_content["prompt"]
        mm_data = parsed_content.get("multi_modal_data")

        result: Union[TokenInputs, MultiModalInputs]
        if not mm_data:
            token_ids = await self._tokenize_prompt_async(
                text,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )
            result = token_inputs(prompt=text, prompt_token_ids=token_ids)
        else:
            result = await self._process_multimodal_async(
                text,
                mm_data,
                parsed_content.get("mm_processor_kwargs"),
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )

        salt = parsed_content.get("cache_salt")
        if salt:
            result["cache_salt"] = salt

        return result

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        The prompt is classified by `parse_singleton_prompt` and dispatched
        to the matching ``_process_*`` helper; bare strings are wrapped in a
        TextPrompt first.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: extra tokenizer options for text prompts
        * lora_request: this is only valid for decoder prompts
        * return_mm_hashes: whether to return multimodal hashes

        Returns:

        * {class}`SingletonInputs` instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "str":
            return self._process_text(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        elif parsed["type"] == "text":
            return self._process_text(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        elif parsed["type"] == "tokens":
            return self._process_tokens(
                parsed["content"],
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        elif parsed["type"] == "embeds":
            return self._process_embeds(parsed["content"])
        else:
            assert_never(parsed)

    async def _prompt_to_llm_inputs_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        return_mm_hashes: bool = False,
    ) -> SingletonInputs:
        """
        Async version of {meth}`_prompt_to_llm_inputs`.
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "str":
            return await self._process_text_async(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        elif parsed["type"] == "text":
            return await self._process_text_async(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        elif parsed["type"] == "tokens":
            return await self._process_tokens_async(
                parsed["content"],
                lora_request=lora_request,
                return_mm_hashes=return_mm_hashes,
            )
        elif parsed["type"] == "embeds":
            return await self._process_embeds_async(parsed["content"])
        else:
            assert_never(parsed)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: Optional[SingletonInputs],
    ) -> EncoderDecoderInputs:
        """
        Combine processed encoder and (optional) decoder inputs into an
        {class}`EncoderDecoderInputs` instance.

        * Decoder inputs given: their ``prompt_token_ids`` are rewritten in
          place so they start with the decoder start token.
        * No decoder inputs: Whisper copies the encoder prompt token ids to
          the decoder; other models use the default decoder prompt.

        Raises:

        * ValueError: for embedding inputs on either side, or multi-modal
          decoder inputs
        """
        if encoder_inputs["type"] == "embeds" or (
                decoder_inputs and decoder_inputs["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        enc = cast(Union[TokenInputs, MultiModalInputs], encoder_inputs)
        dec = cast(Optional[Union[TokenInputs, MultiModalInputs]],
                   decoder_inputs)

        if dec is not None:
            if "multi_modal_data" in dec:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")

            dec["prompt_token_ids"] = (
                self._prepare_decoder_input_ids_for_generation(
                    dec["prompt_token_ids"]))
            return EncoderDecoderInputs(encoder=enc, decoder=dec)

        if self.model_config.hf_config.model_type == "whisper":
            # For Whisper models, the text prompt should go to the decoder.
            # If no explicit encoder/decoder inputs, then copy the prompt
            # from the encoder to the decoder. The encoder tokens are later
            # overridden by the audio features.
            default_ids = enc["prompt_token_ids"].copy()
        else:
            default_ids = self._prepare_decoder_input_ids_for_generation(None)

        return EncoderDecoderInputs(
            encoder=enc,
            decoder=token_inputs(default_ids),
        )

    def _split_enc_dec_mm_inputs(
        self,
        inputs: Union[SingletonInputs, MultiModalEncDecInputs],
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
    ) -> tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only:
        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs

        Arguments:

        * inputs: processed singleton inputs; multi-modal ones must carry
          ``encoder_prompt`` and ``encoder_prompt_token_ids``
        * decoder_inputs_to_override: explicit decoder inputs whose prompt
          replaces the one produced by the processor

        Returns:

        * (encoder_inputs, decoder_inputs)

        Raises:

        * ValueError: for embedding inputs
        * RuntimeError: for multi-modal inputs lacking encoder prompt fields
        """
        if inputs["type"] == "embeds" or (
                decoder_inputs_to_override
                and decoder_inputs_to_override["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Needed for mypy
        inputs = cast(
            Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
            inputs,
        )
        override = cast(
            Optional[Union[TokenInputs, MultiModalInputs]],
            decoder_inputs_to_override,
        )

        if inputs["type"] == "token":  # Text-only inputs
            # Nothing to split: the encoder gets an empty prompt and the
            # whole input (or the override) goes to the decoder.
            empty_encoder = token_inputs(prompt="", prompt_token_ids=[])
            return empty_encoder, override or inputs

        if inputs["type"] != "multimodal":
            assert_never(inputs)  # type: ignore[arg-type]

        # Multimodal data inputs
        has_encoder_prompt = ("encoder_prompt" in inputs
                              and "encoder_prompt_token_ids" in inputs)
        if not has_encoder_prompt:
            raise RuntimeError("You should register an encoder-decoder "
                               "multi-modal processor for encoder-decoder "
                               "models.")
        mm_inputs = cast(MultiModalEncDecInputs, inputs)

        encoder_inputs = token_inputs(
            prompt=mm_inputs["encoder_prompt"],
            prompt_token_ids=mm_inputs["encoder_prompt_token_ids"],
        )

        # The decoder keeps the processor's multi-modal payload but takes
        # its prompt from the override when one is given.
        prompt_source = override or mm_inputs
        decoder_inputs = MultiModalInputs(
            type="multimodal",
            prompt=prompt_source.get("prompt", ""),
            prompt_token_ids=prompt_source["prompt_token_ids"],
            mm_kwargs=mm_inputs["mm_kwargs"],
            mm_hashes=mm_inputs["mm_hashes"],
            mm_placeholders=mm_inputs["mm_placeholders"],
        )
        salt = mm_inputs.get("cache_salt")
        if salt:
            decoder_inputs["cache_salt"] = salt

        return encoder_inputs, decoder_inputs

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an {class}`EncoderDecoderInputs` instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * prompt: an input prompt
        * tokenization_kwargs: extra tokenizer options, applied to both the
          encoder and (if present) the explicit decoder prompt

        Returns:

        * {class}`EncoderDecoderInputs` instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_inputs = self._prompt_to_llm_inputs(
                prompt["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
            )
            if (decoder_input := prompt["decoder_prompt"]) is None:
                decoder_inputs = None
            else:
                # Tokenize the decoder prompt with the same options as the
                # encoder prompt so this matches
                # `_process_encoder_decoder_prompt_async`.
                decoder_inputs = self._prompt_to_llm_inputs(
                    decoder_input,
                    tokenization_kwargs=tokenization_kwargs,
                )

            # For multimodal model, override decoder prompt from processor
            # with explicit decoder prompt.
            if self.model_config.is_multimodal_model:
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))
        else:
            inputs = self._prompt_to_llm_inputs(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(inputs))
            else:
                encoder_inputs = inputs
                decoder_inputs = None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> EncoderDecoderInputs:
        """
        Async version of {meth}`_process_encoder_decoder_prompt`.

        For explicit encoder/decoder prompts, the encoder and decoder
        sub-prompts are processed concurrently via `asyncio.gather`.
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if not is_explicit_encoder_decoder_prompt(prompt):
            singleton = await self._prompt_to_llm_inputs_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(singleton))
            else:
                encoder_inputs, decoder_inputs = singleton, None

            return self._build_enc_dec_llm_inputs(encoder_inputs,
                                                  decoder_inputs)

        encoder_coro = self._prompt_to_llm_inputs_async(
            prompt["encoder_prompt"],
            tokenization_kwargs=tokenization_kwargs,
        )

        decoder_prompt = prompt["decoder_prompt"]
        if decoder_prompt is not None:
            decoder_coro = self._prompt_to_llm_inputs_async(
                decoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
            encoder_inputs, decoder_inputs = await asyncio.gather(
                encoder_coro, decoder_coro)
        else:
            encoder_inputs = await encoder_coro
            decoder_inputs = None

        # For multimodal model, override decoder prompt from processor
        # with explicit decoder prompt.
        if self.model_config.is_multimodal_model:
            encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(
                encoder_inputs, decoder_inputs)

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_inputs: DecoderOnlyInputs,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> DecoderOnlyInputs:
        """
        Finalize decoder-only inputs by applying the prompt adapter.

        Inputs carrying ``prompt_token_ids`` have them rewritten in place via
        {meth}`_apply_prompt_adapter`; inputs without that key are returned
        untouched.
        """
        if "prompt_token_ids" not in prompt_inputs:
            return prompt_inputs

        token_based = cast(Union[TokenInputs, MultiModalInputs],
                           prompt_inputs)  # Needed for mypy
        token_based["prompt_token_ids"] = self._apply_prompt_adapter(
            token_based["prompt_token_ids"],
            prompt_adapter_request=prompt_adapter_request,
        )
        return token_based

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Convert a single prompt into a {class}`DecoderOnlyInputs` instance.

        Arguments:

        * prompt: the input prompt
        * tokenization_kwargs: forwarded unchanged to
          {meth}`_prompt_to_llm_inputs`
        * lora_request: forwarded to {meth}`_prompt_to_llm_inputs`
        * prompt_adapter_request: applied afterwards via
          {meth}`_build_decoder_only_llm_inputs`
        * return_mm_hashes: forwarded to {meth}`_prompt_to_llm_inputs`

        Returns:

        * {class}`DecoderOnlyInputs` instance
        """
        llm_inputs = self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            return_mm_hashes=return_mm_hashes,
        )
        finalized = self._build_decoder_only_llm_inputs(
            llm_inputs,
            prompt_adapter_request=prompt_adapter_request,
        )
        return finalized

    async def _process_decoder_only_prompt_async(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> DecoderOnlyInputs:
        """Async version of {meth}`_process_decoder_only_prompt`.

        Awaits {meth}`_prompt_to_llm_inputs_async` for the prompt conversion,
        then applies the prompt adapter synchronously through
        {meth}`_build_decoder_only_llm_inputs`.
        """
        llm_inputs = await self._prompt_to_llm_inputs_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            return_mm_hashes=return_mm_hashes,
        )
        finalized = self._build_decoder_only_llm_inputs(
            llm_inputs,
            prompt_adapter_request=prompt_adapter_request,
        )
        return finalized

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: forwarded to the encoder-decoder or
          decoder-only processing path
        * lora_request: forwarded on the decoder-only path
        * prompt_adapter_request: forwarded on the decoder-only path
        * return_mm_hashes: must be False for encoder-decoder models

        Returns:

        * {class}`ProcessorInputs` instance

        Raises:

        * ValueError: if an explicit encoder-decoder prompt is passed to a
          decoder-only model
        """
        if self.model_config.is_encoder_decoder:
            # The two literals are implicitly concatenated (no comma between
            # them) so the assertion message is one string, not a tuple.
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            # NOTE(review): lora_request / prompt_adapter_request are not
            # forwarded on this path.
            return self._process_encoder_decoder_prompt(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return self._process_decoder_only_prompt(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )

    async def preprocess_async(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        return_mm_hashes: bool = False,
    ) -> ProcessorInputs:
        """Async version of {meth}`preprocess`.

        Dispatches to {meth}`_process_encoder_decoder_prompt_async` for
        encoder-decoder models and to
        {meth}`_process_decoder_only_prompt_async` otherwise.

        Raises:

        * ValueError: if an explicit encoder-decoder prompt is passed to a
          decoder-only model
        """
        if self.model_config.is_encoder_decoder:
            # The two literals are implicitly concatenated (no comma between
            # them) so the assertion message is one string, not a tuple.
            assert not return_mm_hashes, (
                "Multimodal hashes for encoder-decoder models should not be "
                "returned until they are supported on vLLM V1.")
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            # NOTE(review): lora_request / prompt_adapter_request are not
            # forwarded on this path.
            return await self._process_encoder_decoder_prompt_async(
                prompt,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=return_mm_hashes,
        )