# SPDX-License-Identifier: Apache-2.0
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from itertools import groupby
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union, cast)

from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby

from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)

if TYPE_CHECKING:
    # Only needed for the string annotation on
    # ``BaseMultiModalProcessor.__init__`` (presumably kept out of the
    # runtime imports to avoid an import cycle with ``.profiling``).
    from .profiling import BaseDummyInputsBuilder

logger = init_logger(__name__)

# Constrained (not bound) to exactly ``str`` or ``list[int]`` so that helpers
# such as ``_apply_matches`` return pieces of the same type as their input.
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""

@dataclass
class PromptUpdateDetails:
    """Details about the token sequence or text that are part of the update."""

    full: PromptSeq
    """The full content."""

    features: PromptSeq
    """
    The part of the content that corresponds to feature placeholders;
    this will be replaced by the output of the vision encoder during model
    inference.
    """

    @staticmethod
    def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
        """Build details in which the whole of ``seq`` is feature content."""
        return PromptUpdateDetails(full=seq, features=seq)


PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that are part of the update.

If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""

PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """How a prompt update is applied at each matched target."""

    # Keep the matched target and put the content right after it.
    INSERT = "insert"
    # Drop the matched target and put the content in its place.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.

    Subclasses must implement :attr:`content` and :attr:`mode`. Deriving from
    :class:`ABC` is what makes the ``@abstractmethod`` markers below take
    effect; without it the incomplete base class could be instantiated.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptSeq
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """
        Bind this update to ``tokenizer`` so that its target and content can
        be converted between text and token IDs on demand.
        """
        return BoundPromptUpdate(
            _origin=self,
            tokenizer=tokenizer,
        )

@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

        For each image, insert a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder at the start of the
        prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="",
                insertion="<image>" * image_feature_size,
            )

        As above, but insert after the ``<s>`` token:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="<s>",
                insertion="<image>" * image_feature_size,
            )
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to insert right after :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptUpdateDetails(
                    full="".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    features="<image>" * image_feature_size,
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails(
                    full=([image_bos_id] + [image_token_id] * image_feature_size
                          + [image_eos_id]),
                    features=[image_token_id] * image_feature_size,
                ),
            )
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to replace :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        return UpdateMode.REPLACE


@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Repeated calls with the same arguments return the same cached list
    object, so callers must not mutate the result.
    """
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    ``token_ids`` is a tuple rather than a list because :func:`lru_cache`
    requires every argument to be hashable; it is converted back to a list
    before decoding.
    """
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    modality: str

264

265
class _HasModalityProp(Protocol):
266

267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
    @property
    def modality(self) -> str:
        ...


# Any object that exposes a ``modality`` string, either as a plain attribute
# or as a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group ``values`` by their ``modality`` via :func:`full_groupby`.

    Returns the ``(modality, items)`` pairs produced by :func:`full_groupby`.
    """

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.

    Only one of the two representations needs to be supplied; the other is
    computed on first access via the memoized encode/decode helpers and then
    stored on the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Wrap ``seq``, filling whichever representation it already is."""
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded lazily from the token IDs if needed."""
        if self._text is None:
            assert self._token_ids is not None
            # Tuple conversion: the cached decoder needs hashable arguments.
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token IDs, encoded lazily from the text if needed."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer, self._text)

        return self._token_ids


@dataclass
class _BoundPromptContent:
    """The ``full`` and ``features`` parts of an update, each bound to a
    tokenizer."""

    full: _BoundPromptSequence  # the complete content that is inserted
    features: _BoundPromptSequence  # the feature-placeholder part of it


@dataclass
class BoundPromptUpdate:
    """
    A :class:`PromptUpdate` bound to a tokenizer to automatically convert
    :attr:`target` and the result of :meth:`get_content` between
    token sequence and text representations.
    """
    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Bound content per item index; only populated for callable content
        # (see `get_content`).
        self._content_cache = dict[int, _BoundPromptContent]()

    @property
    def modality(self) -> str:
        return self._origin.modality

    @property
    def target(self) -> _BoundPromptSequence:
        """The token sequence (or text) to update."""
        return _BoundPromptSequence.from_seq(self.tokenizer,
                                             self._origin.target)

    @property
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Given the index of the processed item within :attr:`modality`,
        output the token sequence (or text) to update.

        When the content is a callable, its bound result is cached per
        ``item_idx`` so the callable runs at most once per item.
        """
        content = self.content
        if callable(content):
            cache_key = item_idx
            if cache_key in self._content_cache:
                return self._content_cache[cache_key]

            content = content(item_idx)
        else:
            cache_key = None

        if not isinstance(content, PromptUpdateDetails):
            content = PromptUpdateDetails.from_seq(content)

        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
                                                   content.full)
        bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
                                                       content.features)
        bound_content = _BoundPromptContent(full=bound_full,
                                            features=bound_features)

        if cache_key is not None:
            self._content_cache[cache_key] = bound_content

        return bound_content


394
395
396
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
397
398


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are found left to right and never overlap: after a match, the
    search resumes right after its end.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    # The last position where a full-length match can still start.
    while start_idx < prompt_len - match_len + 1:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1


@dataclass(repr=False)
class _PromptTargetMatch(ABC):
    """
    A location in the prompt where the target of a :class:`BoundPromptUpdate`
    was found. Subclasses supply :attr:`start_idx` and :attr:`end_idx` from
    an underlying token-level or text-level match.
    """

    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptTargetTokenMatch(_PromptTargetMatch):
    """A target match whose indices refer to positions in a token-ID list."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(_PromptTargetMatch):
    """A target match whose indices are character offsets in the prompt
    text."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

@dataclass
class PlaceholderFeaturesInfo:
    """Where the feature placeholder tokens of one item sit in a prompt."""

    modality: str
    item_idx: int  # index of the item within its modality
    start_idx: int  # prompt offset of the first feature placeholder token
    tokens: list[int]  # the feature placeholder token IDs

    @property
    def length(self) -> int:
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
    """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
    return [
        _PromptTargetTokenMatch(update, match) for update in prompt_updates
        for match in iter_token_matches(prompt, update.target.token_ids)
    ]


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
    """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
    return [
        # The target is escaped so it is matched literally, not as a regex.
        _PromptTargetTextMatch(update, match) for update in prompt_updates
        for match in re.finditer(re.escape(update.target.text), prompt)
    ]


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
) -> list[_PromptTargetMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same index of :code:`prompt`.
    """
    matches = [
        m for modality_matches in mm_matches.values()
        for m in modality_matches
    ]

    # For each prompt index, the match (if any) that already covers it.
    seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    Returns the pieces of the updated prompt in order; the caller joins them.

    Each modality consumes item indices ``0, 1, ...`` in prompt order, and at
    most ``mm_item_counts[modality]`` items are emitted per modality. A match
    whose modality has no items left is kept unchanged in the prompt.
    """
    out_seqs = list[Union[str, list[int]]]()
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

    for (start_idx, end_idx), group in groupby(
            _resolve_matches(prompt, mm_matches),
            key=lambda x: (x.start_idx, x.end_idx),
    ):
        matches = tuple(group)
        assert len(matches) == 1

        for match in matches:
            modality = match.modality
            item_start_idx = next_idx_by_modality[modality]
            max_item_count = mm_item_counts.get(modality, 0)
            if item_start_idx >= max_item_count:
                # No items left for this modality: leave the target as-is.
                continue

            origin = match._origin
            mode = origin.mode

            if mode == UpdateMode.INSERT:
                # Keep the target; the content goes right after it.
                out_seqs.append(prompt[prev_end_idx:end_idx])
                num_inserts = max_item_count
            elif mode == UpdateMode.REPLACE:
                # Drop the target; the content takes its place.
                out_seqs.append(prompt[prev_end_idx:start_idx])
                num_inserts = 1
            else:
                assert_never(mode)

            # Never emit more items than remain for this modality.
            item_end_idx = min(item_start_idx + num_inserts, max_item_count)

            for item_idx in range(item_start_idx, item_end_idx):
                # Fetch per item: callable content may depend on the index.
                content = origin.get_content(item_idx)
                if isinstance(prompt, str):
                    out_seqs.append(content.full.text)
                else:
                    out_seqs.append(content.full.token_ids)

            next_idx_by_modality[modality] = item_end_idx
            prev_end_idx = end_idx

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs)


def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
    if not mm_matches:
        # Nothing to apply; the input list itself is returned (not a copy).
        return prompt

    token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)


def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
    if not mm_matches:
        return prompt

    texts = _apply_matches(prompt, mm_matches, mm_item_counts)

    return "".join(texts)


def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.

    Raises:
        AssertionError: If the feature tokens of an update are not a
            subsequence of its full content.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = item_idx_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                # Every item of this modality has already been located.
                continue

            for update_info in modality_updates:
                content = update_info.get_content(item_idx)
                content_tokens_full = content.full.token_ids
                content_len_full = len(content_tokens_full)
                end_idx_full = start_idx + content_len_full

                if content_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] == content_tokens_full:
                    content_tokens_feat = content.features.token_ids

                    # The offset of the feature tokens inside the full content
                    # gives their absolute position in the prompt.
                    match = next(
                        iter_token_matches(content_tokens_full,
                                           content_tokens_feat),
                        None,
                    )
                    if match is None:
                        raise AssertionError(
                            f"{content_tokens_feat=} should be a "
                            f"subsequence of {content_tokens_full=}")

                    yield PlaceholderFeaturesInfo(
                        modality=modality,
                        item_idx=item_idx,
                        start_idx=start_idx + match.start_idx,
                        tokens=content_tokens_feat,
                    )

                    # Exclude overlapping matches
                    start_idx = end_idx_full
                    item_idx_by_modality[modality] += 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1


def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the feature placeholders of each multi-modal item in
    :code:`prompt` and group them by modality.
    """
    it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)
    return dict(full_groupby_modality(it))


class ProcessingCache:
    """
    LRU cache of processed multi-modal items, keyed by a hash of everything
    the processed output depends on (see :meth:`get`).
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    @staticmethod
    def _compute_cache_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Hash the dependencies of a processed item into a cache key."""
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        """
        If ``debug_cache_hit_ratio_steps`` is set, log the hit ratio whenever
        the cache's ``total`` stat is a multiple of it.
        """
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        cache_stats = self._cache.stat()
        if cache_stats.total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         cache_stats.hit_ratio)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = self._compute_cache_key(model_id, modality, input_item,
                                            input_kwargs)
        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._compute_cache_key(model_id, modality, input_item,
                                            input_kwargs)
        self._cache.put(cache_key, output_kwargs)


class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    # NOTE(review): this class does not derive from ABC, so the
    # @abstractmethod markers below are not enforced at instantiation time.

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` field of the model config held by the context."""
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# The processing-info type used by a given multi-modal processor.
_I = TypeVar("_I", bound=BaseProcessingInfo)

class BaseMultiModalProcessor(ABC, Generic[_I]):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """

    def __init_subclass__(cls, **kwargs: object) -> None:
        """
        Keep subclasses that only implement the deprecated
        `_get_prompt_replacements` instantiable.

        `_get_prompt_updates` is abstract, so ABC refuses to instantiate such
        a subclass before :meth:`__init__` (which also aliases the old name
        and emits the deprecation warning) ever runs. This hook runs during
        class creation, before ABC computes the set of abstract methods, so
        aliasing here removes `_get_prompt_updates` from that set.
        """
        super().__init_subclass__(**kwargs)

        legacy = cls.__dict__.get("_get_prompt_replacements")
        if legacy is not None and "_get_prompt_updates" not in cls.__dict__:
            cls._get_prompt_updates = legacy  # type: ignore[method-assign]

    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        # Backward compatibility for subclasses using the old method name.
        if get_repls := getattr(self, "_get_prompt_replacements", None):
            logger.warning_once("`_get_prompt_replacements` has been renamed "
                                "to `_get_prompt_updates`. The old name will "
                                "be removed in an upcoming release.")
            self._get_prompt_updates = get_repls  # type: ignore[method-assign]

        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        # When `None`, `_cached_apply_hf_processor` skips caching entirely.
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()

    def __call__(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Shorthand for :meth:`apply`."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.
        """
        return MultiModalDataParser()

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than allowed by
                `limit_per_prompt` (modalities not listed there allow 1).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Given the HF-processed data, output the metadata of each field."""
        raise NotImplementedError

    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptUpdate]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              updates: in :meth:`_apply_hf_processor_missing`, this method
              is called on text-only and multimodal-only inputs separately,
              instead of passing them in the same call.
            - The update information returned by this method is also used to
              determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError

    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholder tokens of each multi-modal item in
        `new_token_ids`; delegates to :func:`find_mm_placeholders` so that
        subclasses can customize the search.
        """
        return find_mm_placeholders(mm_prompt_updates, new_token_ids,
                                    mm_item_counts)

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the items of every modality into the data passed to the HF
        processor and the data passed through to the outputs unchanged.

        Returns:
            A ``(processor_data, passthrough_data)`` tuple.
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Return whether the HF processor applies prompt updates.

        For most HF processors, this should be :code:`True` when multi-modal
        data items are passed, but :code:`False` when multi-modal embeddings
        are passed.
        """
        return not any(
            isinstance(items, (EmbeddingItems, DictEmbeddingItems))
            for items in mm_items.values())

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        In addition, return whether prompt updates have been applied.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        processed_data.update(passthrough_data)

        # A single prompt is processed, so `input_ids` holds exactly one row.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_update_applied

    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, it is called here with an empty set of
        multi-modal items; the prompt updates for the actual items are
        applied afterwards by vLLM (see :meth:`_apply_prompt_updates`).
        """
        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        ``self.dummy_inputs`` (a :class:`BaseDummyInputsBuilder`)
        to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_counts,
        )

        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        In addition, return whether prompt updates have been applied
        (for most HF processors, this should be :code:`True`).

        Note:
            If :code:`enable_hf_prompt_update=True` and the prompt is text,
            we use HF processor to perform prompt updates if available;
            HF processor requires that the prompt corresponds to
            multi-modal items. Otherwise, the text/tokens and the
            multi-modal data are processed separately and no prompt
            updates are applied (the third return value is :code:`False`).
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, False

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Caching is bypassed when no cache is configured or when any item
        has passthrough data.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        # `None` marks an item that is not in the cache yet.
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we can't apply prompt updates until the new multimodal
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            mm_missing_kwargs,
            is_update_applied,
        ) = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_update=False,
        )

        # Next unused position in `mm_missing_kwargs` for each modality.
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs, is_update_applied

    def _bind_and_group_updates(
        self,
        prompt_updates: list[PromptUpdate],
    ) -> dict[str, list[BoundPromptUpdate]]:
        """Bind each update to the tokenizer and group them by modality."""
        tokenizer = self.info.get_tokenizer()

        it = (update.bind(tokenizer) for update in prompt_updates)
        return dict(full_groupby_modality(it))

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to `token_ids`.

        Returns:
            The updated token IDs, their decoded text, and the placeholder
            information of each multi-modal item.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in mm_token_matches.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = apply_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
            matched_updates = {
                modality: [match._origin for match in token_matches]
                for modality, token_matches in mm_token_matches.items()
            }
        else:
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = apply_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            matched_updates = {
                modality: [match._origin for match in text_matches]
                for modality, text_matches in mm_text_matches.items()
            }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that `mm_kwargs` holds exactly one entry per data item.

        Raises:
            RuntimeError: If the number of items of any modality differs.
        """
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                # The mismatch can go either way, so avoid "only found".
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that exactly one placeholder was found per data item.

        Raises:
            RuntimeError: If the number of placeholders of any modality
                differs from its number of items.
        """
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt updates "
                    f"corresponding to {item_count} {modality} items, but "
                    f"instead found {len(placeholders)} prompt updates! "
                    "Either the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_updates`).")

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.

        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        (
            prompt_ids,
            mm_kwargs,
            is_update_applied,
        ) = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_updates = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if is_update_applied:
            # The HF processor already inserted the placeholders; only
            # locate them.
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_updates,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt,
                mm_placeholders,
            ) = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multi-modal processor variant for encoder-decoder models."""

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Build the encoder-side input prompt. The HF processor is applied
        to this prompt during both profiling and generation.
        """
        raise NotImplementedError

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Build the decoder-side input prompt (by default, the prompt itself)."""
        return prompt

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs for an encoder-decoder model.

        Steps:
        1. Build the encoder prompt from the input prompt.
        2. Run the base processing pipeline (including the HF processor)
           on that encoder prompt.
        3. Build the decoder prompt and express it both as text and as
           token IDs; it replaces the top-level prompt fields.
        """
        enc_prompt = self.create_encoder_prompt(prompt, mm_data)
        enc_inputs = super().apply(
            enc_prompt,
            mm_data,
            hf_processor_mm_kwargs,
        )

        tokenizer = self.info.get_tokenizer()
        dec_prompt = self.create_decoder_prompt(prompt, mm_data)

        # Keep both representations of the decoder prompt in sync.
        if not isinstance(dec_prompt, str):
            dec_prompt_ids = dec_prompt
            dec_prompt_text = decode_tokens(tokenizer, dec_prompt)
        else:
            dec_prompt_text = dec_prompt
            dec_prompt_ids = encode_tokens(tokenizer,
                                           dec_prompt,
                                           add_special_tokens=False)

        result = MultiModalEncDecInputs(
            encoder_prompt=enc_inputs["prompt"],
            encoder_prompt_token_ids=enc_inputs["prompt_token_ids"],
            **enc_inputs,
        )
        # The top-level prompt fields describe the decoder input.
        result["prompt"] = dec_prompt_text
        result["prompt_token_ids"] = dec_prompt_ids
        return result