"vscode:/vscode.git/clone" did not exist on "3ecabd06eee69e60c2239a6ca7159b978b26d6ce"
preprocess.py 28.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
from typing import List, Mapping, Optional, Tuple, Union, cast
5
6
7
8
9
10

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
11
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
12
13
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs)
14
15
16
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

17
18
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
                   PromptType, SingletonInputs, SingletonPrompt, token_inputs)
19
20
21
22
23
24
25
26
27
28
29
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

logger = init_logger(__name__)


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[BaseTokenizerGroup],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """
        Arguments:

        * model_config: configuration of the model whose prompts are
          preprocessed
        * tokenizer: tokenizer group used for text prompts; may be None, in
          which case text prompts cannot be tokenized
        * mm_registry: registry used to look up and create multi-modal
          processors (defaults to the global ``MULTIMODAL_REGISTRY``)
        """
        super().__init__()

        # Plain attribute storage; all processing happens in the methods.
        self.model_config = model_config
        self.tokenizer = tokenizer
        self.mm_registry = mm_registry

    def get_tokenizer_group(self) -> BaseTokenizerGroup:
        """
        Return the tokenizer group.

        Raises:
            ValueError: if no tokenizer group was provided (the error text
                refers to the ``skip_tokenizer_init`` setting).
        """
        group = self.tokenizer
        if group is not None:
            return group

        raise ValueError("You cannot pass text prompts when "
                         "`skip_tokenizer_init` is True")

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Return ``bos_token_id`` of the tokenizer selected for
        ``lora_request``.

        Logs a warning and returns None when no tokenizer is initialized.
        """
        group = self.tokenizer
        if group is not None:
            return group.get_lora_tokenizer(lora_request).bos_token_id

        logger.warning("Using None for BOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """
        Return ``eos_token_id`` of the tokenizer selected for
        ``lora_request``.

        Logs a warning and returns None when no tokenizer is initialized.
        """
        group = self.tokenizer
        if group is not None:
            return group.get_lora_tokenizer(lora_request).eos_token_id

        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_decoder_start_token_id(self) -> Optional[int]:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model.

        Returns None (logging a one-time warning) when the model config is
        unavailable, when the model is not an encoder/decoder model, or when
        its HF config is unavailable. If the HF config does not define
        ``decoder_start_token_id``, falls back to the BOS token id.
        """
        model_config = self.model_config

        # The availability check must run before any attribute access on
        # the config; otherwise a missing config raises AttributeError on
        # `is_encoder_decoder` instead of returning None.
        if model_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        if not model_config.is_encoder_decoder:
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
            return None

        if model_config.hf_config is None:
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
            return None

        dec_start_token_id = getattr(model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
        """
        Specifically for encoder/decoder models: build the decoder prompt
        used when the user supplies only an encoder prompt.

        By default this mirrors the HuggingFace (HF) GenerationMixin
        behaviour for a None decoder prompt, which forces the first decoded
        token to be <BOS>. Here that is approximated by a decoder prompt
        consisting of the single <BOS> token.

        This is a separate helper because other encoder/decoder models may
        need different (possibly more complex) default decoder prompts in
        the future.

        Returns:

        * prompt_token_ids
        """
        bos = self.get_bos_token_id()
        assert bos is not None
        return [bos]

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[List[int]],
    ) -> List[int]:
        """
        Prepare ``decoder_input_ids`` for generation with encoder-decoder
        models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess; None selects the
          default decoder prompt

        Returns:

        * Processed token list, guaranteed to start with the decoder start
          token (the input list is returned unchanged if it already does)
        """
        start_id = self.get_decoder_start_token_id()
        assert start_id is not None

        if decoder_input_ids is None:
            ids = self._get_default_enc_dec_decoder_prompt()
        else:
            ids = decoder_input_ids

        # Already begins with the start token: nothing to prepend.
        if ids and ids[0] == start_id:
            return ids

        return [start_id] + ids

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        """
        Prefix the prompt with one token id ``0`` per virtual token of the
        prompt adapter, if a (truthy) adapter request is given; otherwise
        return the token ids unchanged.
        """
        if not prompt_adapter_request:
            return prompt_token_ids

        num_virtual = prompt_adapter_request.prompt_adapter_num_virtual_tokens
        return [0] * num_virtual + prompt_token_ids

    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """
        Tokenize a text prompt with the model's tokenizer and return the
        resulting token IDs.

        Whisper prompts are encoded with ``add_special_tokens=False``. The
        prompt is lower-cased first when the encoder config sets
        ``do_lower_case``.
        """
        tokenizer = self.get_tokenizer_group()
        model_config = self.model_config

        # For Whisper, special tokens should be provided by the user based
        # on the task and language of their request. Also needed to avoid
        # appending an EOS token to the prompt which disrupts generation.
        is_whisper = model_config.hf_config.model_type == "whisper"
        add_special_tokens = False if is_whisper else None

        encoder_config = model_config.encoder_config
        if encoder_config is not None and encoder_config.get(
                "do_lower_case", False):
            prompt = prompt.lower()

        return tokenizer.encode(request_id=request_id,
                                prompt=prompt,
                                lora_request=lora_request,
                                add_special_tokens=add_special_tokens)

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """
        Async version of :meth:`_tokenize_prompt`.

        Applies the same prompt normalization as the sync path:
        - Whisper prompts are encoded with ``add_special_tokens=False``.
        - The prompt is lower-cased when the encoder config sets
          ``do_lower_case``.
        """
        tokenizer = self.get_tokenizer_group()

        add_special_tokens = None
        if self.model_config.hf_config.model_type == "whisper":
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            add_special_tokens = False

        # Must match the sync `_tokenize_prompt`; otherwise the same prompt
        # would yield different token IDs depending on the entry point.
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        return await tokenizer.encode_async(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request,
            add_special_tokens=add_special_tokens)

    def _can_process_multimodal(self) -> bool:
        """
        Check whether multi-modal inputs can use the new multi-modal
        processor.

        Raises:
            ValueError: if the model is not a multi-modal model.

        Returns:
            The registry's ``has_processor`` result for this model. When it
            is falsy, the model still relies on the legacy input pipeline
            and a one-time info message is logged.
        """
        config = self.model_config
        if not config.is_multimodal_model:
            raise ValueError("Your model does not support multi-modal inputs")

        # Interim measure so we can handle models that have yet to be
        # updated to use the new multi-modal processor
        has_processor = self.mm_registry.has_processor(config)
        if has_processor:
            return has_processor

        logger.info_once(
            "Your model uses the legacy input pipeline instead of the new "
            "multi-modal processor. Please note that the legacy pipeline "
            "will be removed in a future release. For more details, see: "
            "https://github.com/vllm-project/vllm/issues/10114")
        return has_processor

    def _process_multimodal(
        self,
        prompt: Union[str, List[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
    ) -> MultiModalInputs:
        """
        Run the model's multi-modal processor on a prompt (text or token
        IDs) together with its multi-modal data. Returns the resulting
        token IDs and multi-modal metadata.
        """
        # At the moment one model (PrithviGeoSpatialMAE) must be initialized
        # without a tokenizer while still accepting multi-modal input, so
        # the processor may receive tokenizer=None.
        if self.tokenizer:
            tokenizer = self.get_tokenizer_group().get_lora_tokenizer(
                lora_request)
        else:
            tokenizer = None

        processor = self.mm_registry.create_processor(self.model_config,
                                                      tokenizer)
        kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        return processor.apply(prompt, mm_data, kwargs)

    async def _process_multimodal_async(
        self,
        prompt: Union[str, List[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        lora_request: Optional[LoRARequest],
    ) -> MultiModalInputs:
        """
        Async version of :meth:`_process_multimodal`.

        Only the LoRA tokenizer lookup is awaited. The multi-modal processor
        itself is applied synchronously.
        """
        # At the moment one model (PrithviGeoSpatialMAE) must be initialized
        # without a tokenizer while still accepting multi-modal input, so
        # the processor may receive tokenizer=None.
        if self.tokenizer:
            tokenizer = await self.get_tokenizer_group(
            ).get_lora_tokenizer_async(lora_request)
        else:
            tokenizer = None

        processor = self.mm_registry.create_processor(self.model_config,
                                                      tokenizer)
        kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        return processor.apply(prompt, mm_data, kwargs)

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> SingletonInputs:
        """
        Turn a single encoder or decoder prompt into
        :class:`SingletonInputs`.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * request_id
        * lora_request: this is only valid for decoder prompts

        Returns:

        * :class:`SingletonInputs` instance

        Bare strings and text prompts are tokenized; token prompts are used
        as given. When a text or token prompt carries multi-modal data and
        the new multi-modal processor supports the model, the processor's
        output is returned instead.
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "str":
            text = parsed["content"]
            return token_inputs(
                prompt=text,
                prompt_token_ids=self._tokenize_prompt(
                    text,
                    request_id=request_id,
                    lora_request=lora_request,
                ),
            )

        if parsed["type"] == "tokens":
            tokens_prompt = parsed["content"]
            token_ids = tokens_prompt["prompt_token_ids"]
            mm_data = tokens_prompt.get("multi_modal_data")
            mm_kwargs = tokens_prompt.get("mm_processor_kwargs")

            if mm_data is not None and self._can_process_multimodal():
                return self._process_multimodal(
                    token_ids,
                    mm_data,
                    mm_kwargs,
                    lora_request=lora_request,
                )

            return token_inputs(
                prompt_token_ids=token_ids,
                token_type_ids=tokens_prompt.get("token_type_ids"),
                multi_modal_data=mm_data,
                mm_processor_kwargs=mm_kwargs,
            )

        if parsed["type"] == "text":
            text_prompt = parsed["content"]
            text = text_prompt["prompt"]
            mm_data = text_prompt.get("multi_modal_data")
            mm_kwargs = text_prompt.get("mm_processor_kwargs")

            if mm_data is not None and self._can_process_multimodal():
                return self._process_multimodal(
                    text,
                    mm_data,
                    mm_kwargs,
                    lora_request=lora_request,
                )

            return token_inputs(
                prompt=text,
                prompt_token_ids=self._tokenize_prompt(
                    text,
                    request_id=request_id,
                    lora_request=lora_request,
                ),
                multi_modal_data=mm_data,
                mm_processor_kwargs=mm_kwargs,
            )

        assert_never(parsed)

    async def _prompt_to_llm_inputs_async(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> SingletonInputs:
        """
        Async version of :meth:`_prompt_to_llm_inputs`.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * request_id
        * lora_request: this is only valid for decoder prompts

        Returns:

        * :class:`SingletonInputs` instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "str":
            prompt_text = parsed["content"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt_text,
                request_id=request_id,
                lora_request=lora_request,
            )

            return token_inputs(
                prompt=prompt_text,
                prompt_token_ids=prompt_token_ids,
            )

        if parsed["type"] == "tokens":
            tokens_content = parsed["content"]

            prompt_token_ids = tokens_content["prompt_token_ids"]
            # Forward token_type_ids exactly like the sync path; without
            # this, async requests silently lose the field.
            token_type_ids = tokens_content.get("token_type_ids")
            multi_modal_data = tokens_content.get("multi_modal_data")
            mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")

            if multi_modal_data is not None and self._can_process_multimodal():
                return await self._process_multimodal_async(
                    prompt_token_ids,
                    multi_modal_data,
                    mm_processor_kwargs,
                    lora_request=lora_request,
                )

            return token_inputs(
                prompt_token_ids=prompt_token_ids,
                token_type_ids=token_type_ids,
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
            )

        if parsed["type"] == "text":
            text_content = parsed["content"]

            prompt_text = text_content["prompt"]
            multi_modal_data = text_content.get("multi_modal_data")
            mm_processor_kwargs = text_content.get("mm_processor_kwargs")

            if multi_modal_data is not None and self._can_process_multimodal():
                return await self._process_multimodal_async(
                    prompt_text,
                    multi_modal_data,
                    mm_processor_kwargs,
                    lora_request=lora_request,
                )

            prompt_token_ids = await self._tokenize_prompt_async(
                prompt_text,
                request_id=request_id,
                lora_request=lora_request,
            )

            return token_inputs(
                prompt=prompt_text,
                prompt_token_ids=prompt_token_ids,
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
            )

        assert_never(parsed)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: Optional[SingletonInputs],
    ) -> EncoderDecoderInputs:
        """
        Combine processed encoder inputs and optional decoder inputs into an
        :class:`EncoderDecoderInputs` instance.

        * If ``decoder_inputs`` is None, decoder inputs are built:
          - For Whisper, the encoder prompt token IDs are copied to the
            decoder.
          - Otherwise, the default decoder prompt, prefixed with the decoder
            start token, is used.
        * Otherwise the decoder prompt token IDs are prefixed with the
          decoder start token when missing. This updates ``decoder_inputs``
          in place.

        Raises:
            ValueError: if the decoder inputs carry ``multi_modal_data``,
                which encoder/decoder models do not support yet.
        """
        if (encoder_inputs["type"] == "token"
                or encoder_inputs["type"] == "multimodal"):
            pass
        else:
            assert_never(encoder_inputs)  # type: ignore[arg-type]

        if decoder_inputs is None:
            if self.model_config.hf_config.model_type == "whisper":
                # For Whisper models, the text prompt should go to the decoder.
                # If no explicit encoder/decoder inputs, then copy the prompt
                # from the encoder to the decoder. The encoder tokens are later
                # overridden by the audio features.
                dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
            else:
                dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                    None)
            decoder_inputs = token_inputs(dec_token_ids)
        elif (decoder_inputs["type"] == "token"
              or decoder_inputs["type"] == "multimodal"):
            # Reject unsupported inputs before mutating them in place.
            if "multi_modal_data" in decoder_inputs:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")

            dec_token_ids = self._prepare_decoder_input_ids_for_generation(
                decoder_inputs["prompt_token_ids"])
            decoder_inputs["prompt_token_ids"] = dec_token_ids
        else:
            # The value that failed the type check is the decoder input, so
            # that is what must be reported.
            assert_never(decoder_inputs)  # type: ignore[arg-type]

        return EncoderDecoderInputs(
            encoder=encoder_inputs,
            decoder=decoder_inputs,
        )

    def _separate_enc_dec_inputs_from_mm_processor_outputs(
        self,
        inputs: SingletonInputs,
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
    ) -> Tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only: split processor output into
        separate encoder and decoder inputs.

        * ``"token"`` (text-only) inputs: the encoder gets an empty token
          prompt. The decoder gets ``decoder_inputs_to_override`` if that is
          truthy, else ``inputs``.
        * ``"multimodal"`` inputs (expected to be
          :class:`MultiModalEncDecInputs`):
          - The encoder gets the ``encoder_prompt`` /
            ``encoder_prompt_token_ids`` fields.
          - The decoder keeps the processor's ``mm_kwargs`` and
            ``mm_placeholders``.
          - The decoder prompt and token IDs come from
            ``decoder_inputs_to_override`` when given, else from ``inputs``.
        """
        if inputs["type"] == "token":
            # Text-only inputs
            return (token_inputs(prompt="", prompt_token_ids=[]),
                    decoder_inputs_to_override or inputs)

        if inputs["type"] != "multimodal":
            assert_never(inputs)  # type: ignore[arg-type]

        # Multimodal data inputs
        assert ("encoder_prompt" in inputs
                and "encoder_prompt_token_ids" in inputs)
        enc_dec = cast(MultiModalEncDecInputs, inputs)

        encoder_inputs: SingletonInputs = token_inputs(
            prompt=enc_dec["encoder_prompt"],
            prompt_token_ids=enc_dec["encoder_prompt_token_ids"],
        )

        if decoder_inputs_to_override is None:
            dec_prompt = enc_dec["prompt"]
            dec_token_ids = enc_dec["prompt_token_ids"]
        else:
            dec_prompt = decoder_inputs_to_override.get("prompt", "")
            dec_token_ids = decoder_inputs_to_override["prompt_token_ids"]

        decoder_inputs: SingletonInputs = MultiModalInputs(
            type="multimodal",
            prompt=dec_prompt,
            prompt_token_ids=dec_token_ids,
            mm_kwargs=enc_dec["mm_kwargs"],
            mm_placeholders=enc_dec["mm_placeholders"],
        )
        return encoder_inputs, decoder_inputs

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        request_id: str,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an :class:`EncoderDecoderInputs`
        instance.

        Two kinds of prompt are accepted:
        - Singleton prompts, which carry only the encoder prompt. The
          encoder inputs are extracted and default decoder inputs are
          inferred.
        - Explicit encoder/decoder prompts, which carry both sub-prompts.
          Each sub-prompt may be of any singleton type and is processed
          separately.

        For multi-modal models served by the new multi-modal processor, the
        processor output is split into encoder and decoder parts. An
        explicit decoder prompt overrides the processor's decoder prompt.

        Arguments:

        * prompt: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderInputs` instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if not is_explicit_encoder_decoder_prompt(prompt):
            inputs = self._prompt_to_llm_inputs(prompt, request_id=request_id)

            if self.model_config.is_multimodal_model and (
                    self._can_process_multimodal()):
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        inputs))
            else:
                encoder_inputs, decoder_inputs = inputs, None

            return self._build_enc_dec_llm_inputs(encoder_inputs,
                                                  decoder_inputs)

        encoder_inputs = self._prompt_to_llm_inputs(
            prompt["encoder_prompt"],
            request_id=request_id,
        )

        decoder_prompt = prompt["decoder_prompt"]
        if decoder_prompt is not None:
            decoder_inputs = self._prompt_to_llm_inputs(
                decoder_prompt,
                request_id=request_id,
            )
        else:
            decoder_inputs = None

        # For multimodal model, override decoder prompt from processor
        # with explicit decoder prompt.
        if self.model_config.is_multimodal_model and (
                self._can_process_multimodal()):
            encoder_inputs, decoder_inputs = (
                self._separate_enc_dec_inputs_from_mm_processor_outputs(
                    encoder_inputs, decoder_inputs))

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
        prompt: PromptType,
        request_id: str,
    ) -> EncoderDecoderInputs:
        """
        Async version of :meth:`_process_encoder_decoder_prompt`.

        When an explicit prompt has both an encoder and a decoder part, the
        two sub-prompts are processed concurrently via ``asyncio.gather``.
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if not is_explicit_encoder_decoder_prompt(prompt):
            inputs = await self._prompt_to_llm_inputs_async(
                prompt,
                request_id=request_id,
            )

            if self.model_config.is_multimodal_model and (
                    self._can_process_multimodal()):
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        inputs))
            else:
                encoder_inputs, decoder_inputs = inputs, None

            return self._build_enc_dec_llm_inputs(encoder_inputs,
                                                  decoder_inputs)

        encoder_coro = self._prompt_to_llm_inputs_async(
            prompt["encoder_prompt"],
            request_id=request_id,
        )

        decoder_prompt = prompt["decoder_prompt"]
        if decoder_prompt is not None:
            decoder_coro = self._prompt_to_llm_inputs_async(
                decoder_prompt,
                request_id=request_id,
            )
            encoder_inputs, decoder_inputs = await asyncio.gather(
                encoder_coro, decoder_coro)
        else:
            encoder_inputs = await encoder_coro
            decoder_inputs = None

        # For multimodal model, override decoder prompt from processor
        # with explicit decoder prompt.
        if self.model_config.is_multimodal_model and (
                self._can_process_multimodal()):
            encoder_inputs, decoder_inputs = (
                self._separate_enc_dec_inputs_from_mm_processor_outputs(
                    encoder_inputs, decoder_inputs))

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_inputs: DecoderOnlyInputs,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> DecoderOnlyInputs:
        """
        Finalize decoder-only inputs by applying the prompt adapter to
        ``prompt_token_ids``. ``prompt_inputs`` is updated in place and
        returned.
        """
        if prompt_inputs["type"] not in ("token", "multimodal"):
            assert_never(prompt_inputs)  # type: ignore[arg-type]

        adapted_ids = self._apply_prompt_adapter(
            prompt_inputs["prompt_token_ids"],
            prompt_adapter_request=prompt_adapter_request,
        )
        prompt_inputs["prompt_token_ids"] = adapted_ids
        return prompt_inputs

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into an :class:`DecoderOnlyInputs` instance.

        Arguments:

        * prompt: input prompt
        * request_id
        * lora_request
        * prompt_adapter_request

        Returns:

        * :class:`DecoderOnlyInputs` instance
        """
        singleton_inputs = self._prompt_to_llm_inputs(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
        )
        return self._build_decoder_only_llm_inputs(
            singleton_inputs,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _process_decoder_only_prompt_async(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> DecoderOnlyInputs:
        """
        Async version of :meth:`_process_decoder_only_prompt`.

        Only prompt extraction/tokenization is awaited. The prompt adapter
        is applied synchronously.
        """
        singleton_inputs = await self._prompt_to_llm_inputs_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
        )
        return self._build_decoder_only_llm_inputs(
            singleton_inputs,
            prompt_adapter_request=prompt_adapter_request,
        )

    def preprocess(
        self,
        prompt: PromptType,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        Encoder/decoder models go through the encoder/decoder path, which
        ignores ``lora_request`` and ``prompt_adapter_request``. All other
        models take the decoder-only path.

        Raises:
            ValueError: if an explicit encoder/decoder prompt is given to a
                decoder-only model.
        """
        if not self.model_config.is_encoder_decoder:
            if is_explicit_encoder_decoder_prompt(prompt):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            return self._process_decoder_only_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder
        return self._process_encoder_decoder_prompt(
            prompt,
            request_id=request_id,
        )

    async def preprocess_async(
        self,
        prompt: PromptType,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> ProcessorInputs:
        """
        Async version of :meth:`preprocess`.

        Raises:
            ValueError: if an explicit encoder/decoder prompt is given to a
                decoder-only model.
        """
        if not self.model_config.is_encoder_decoder:
            if is_explicit_encoder_decoder_prompt(prompt):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            return await self._process_decoder_only_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder
        return await self._process_encoder_decoder_prompt_async(
            prompt,
            request_id=request_id,
        )