processing.py 52.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import re
3
import sys
4
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
7
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
8
from dataclasses import dataclass, field
9
from enum import Enum
10
from functools import lru_cache
11
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
12
                    TypeVar, Union, cast)
13

14
15
import torch
from cachetools import LRUCache
16
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
17
from typing_extensions import assert_never
18

19
from vllm.inputs import InputProcessingContext
20
from vllm.jsontree import json_map_leaves, json_reduce_leaves
21
from vllm.logger import init_logger
22
23
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
24
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
25

26
from .hasher import MultiModalHasher
27
28
29
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
30
31
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
32
33
34

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
35

36
# Module-level logger created through vLLM's `init_logger` helper.
logger = init_logger(__name__)
37
38

# Constrained type variable: a prompt is either text (`str`) or a list of
# token IDs, and helpers annotated with it return pieces of that same kind.
_S = TypeVar("_S", str, list[int])
39
40
41

# Public alias used throughout this module for "text or token IDs".
PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
42

43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@dataclass
class PromptIndex:
    """Resolves to an index in the prompt."""
    # Callback receiving the tokenizer and the prompt (text or token IDs);
    # it returns the resolved index, or `None` when there is no match.
    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]


class PromptIndexTargets:
    """Factory methods building the common :class:`PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: 0)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        The prefix may be given as text or as token IDs; it is converted
        with the tokenizer to whichever representation the prompt uses
        before the comparison. There is no match (``None``) when the prompt
        does not start with the prefix.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prefix = seq

            if isinstance(prompt, str):
                if not isinstance(prefix, str):
                    # Make both `str`
                    prefix = decode_tokens(tokenizer, prefix)
            else:
                if isinstance(prefix, str):
                    # Make both `list[int]`
                    prefix = encode_tokens(tokenizer,
                                           prefix,
                                           add_special_tokens=False)

            # Slicing works for both `str` and `list[int]`, so one
            # comparison covers both representations.
            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: len(prompt))


# Either a literal sequence to look for in the prompt, or a
# :class:`PromptIndex` that computes the position directly.
PromptTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update.
"""


105
@dataclass
class PromptUpdateDetails:
    """Details about the token sequence or text that are part of the update."""

    full: PromptSeq
    """The full content."""

    features: PromptSeq
    """
    The part of the content that corresponds to feature placeholders;
    this will be replaced by the output of the vision encoder during model
    inference.
    """

    @staticmethod
    def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
        """Build details in which the whole of ``seq`` is feature content."""
        return PromptUpdateDetails(full=seq, features=seq)
122
123


124
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that is part of the update.

If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""
131

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Either a callable mapping an item index to the update content, or the
# content itself when it is the same for every item.
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """How a prompt update is applied at its target."""
    # The `str` mixin makes each member compare equal to its string value.
    INSERT = "insert"
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.

    Subclasses provide :attr:`content` (what to put into the prompt) and
    :attr:`mode` (whether it is inserted or replaces the target).
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """Pair this update with ``tokenizer`` in a :class:`BoundPromptUpdate`."""
        return BoundPromptUpdate(
            _origin=self,
            tokenizer=tokenizer,
        )

178

179
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

        For each image, insert a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder after the ``<s>`` token:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="<s>",
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the start of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens after a prefix ``Images:``:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix("Images:"),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the end of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.end(),
                insertion="<image>" * image_feature_size,
            )
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to insert right after :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to insert; this is :attr:`insertion`."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptUpdateDetails(
                    full="".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    features="<image>" * image_feature_size,
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails(
                    full=([image_bos_id] + [image_token_id] * image_feature_size
                          + [image_eos_id]),
                    features=[image_token_id] * image_feature_size,
                ),
            )
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to replace :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to substitute; this is :attr:`replacement`."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
316
317


318
319
320
321
322
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Results are cached per ``(tokenizer, text, add_special_tokens)``.
    A cache hit returns the same list object again, so callers must not
    mutate the result in place.
    """
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)
328
329


330
331
332
333
334
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    ``token_ids`` is taken as a tuple so that it is hashable for the cache;
    it is converted back to a list before decoding.
    """
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)
340
341
342
343
344


class _HasModalityAttr(Protocol):
    """Structural type for objects exposing ``modality`` as an attribute."""
    modality: str

345

346
class _HasModalityProp(Protocol):
347

348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
    @property
    def modality(self) -> str:
        ...


# Any object exposing a string `modality`, as an attribute or as a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group ``values`` by their ``modality`` using :func:`full_groupby`.
    """

    def modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.

    Only one of the two representations has to be supplied; the other is
    computed on first access and then stored on the instance.
    """
    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Wrap ``seq``, filling whichever representation it already is."""
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is None:
            assert self._token_ids is not None
            # The cached decoder needs a hashable key, hence the tuple.
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token IDs, encoded from the text on first access."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer,
                                             self._text,
                                             add_special_tokens=False)

        return self._token_ids


407
@dataclass
class _BoundPromptContent:
    """Tokenizer-bound counterpart of :class:`PromptUpdateDetails`."""

    # The complete update content.
    full: _BoundPromptSequence
    # The part of the content that corresponds to feature placeholders.
    features: _BoundPromptSequence


413
@dataclass
class BoundPromptUpdate:
    """
    A :class:`PromptUpdate` bound to a tokenizer to automatically convert
    :attr:`target` and the result of :meth:`get_content` between
    token sequence and text representations.
    """
    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Per-item results of callable content, keyed by item index.
        self._content_cache = dict[int, _BoundPromptContent]()

    @property
    def modality(self) -> str:
        """The modality of the underlying update."""
        return self._origin.modality

    @property
    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """The token sequence (or text) to update."""
        target = self._origin.target

        # Index targets resolve a position themselves; they are not bound.
        if isinstance(target, PromptIndex):
            return target

        return _BoundPromptSequence.from_seq(self.tokenizer, target)

    @property
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Given the index of the processed item within :attr:`modality`,
        output the token sequence (or text) to update.

        Results are cached per item index when the content is a callable;
        constant content is wrapped anew on every call.
        """
        content = self.content
        if callable(content):
            cache_key = item_idx
            if cache_key in self._content_cache:
                return self._content_cache[cache_key]

            content = content(item_idx)
        else:
            cache_key = None

        # A bare sequence means the whole content is feature content.
        if not isinstance(content, PromptUpdateDetails):
            content = PromptUpdateDetails.from_seq(content)

        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
                                                   content.full)
        bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
                                                       content.features)
        bound_content = _BoundPromptContent(full=bound_full,
                                            features=bound_features)

        if cache_key is not None:
            self._content_cache[cache_key] = bound_content

        return bound_content
479
480


481
482
483
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
484
485


486
487
488
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
489
) -> Generator[_TokenMatch]:
490
491
492
493
494
495
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
496
    match_len = len(match_ids)
497

498
499
    if match_len == 0:
        return
500

501
502
    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
503
        end_idx = start_idx + match_len
504

505
506
        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
507
508
509
510
511

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
512
513


514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Build the token sequence obtained by substituting :code:`new_ids` for
    every occurrence of :code:`match_ids` in :code:`token_ids`.

    Occurrences are located with :func:`iter_token_matches`, so empty
    matches are ignored.
    """
    pieces: list[list[int]] = []
    cursor = 0

    for found in iter_token_matches(token_ids, match_ids):
        # Untouched tokens before the match, then the substitute.
        pieces += [token_ids[cursor:found.start_idx], new_ids]
        cursor = found.end_idx

    # Whatever follows the last match is carried over unchanged.
    pieces.append(token_ids[cursor:])

    return flatten_2d_lists(pieces)


541
@dataclass(repr=False)
class PromptTargetMatch(ABC):
    """
    One location in a prompt where a prompt update's target was found.

    Subclasses define how :attr:`start_idx` and :attr:`end_idx` are obtained.
    """
    # The bound update whose target produced this match.
    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        """The modality of the update that produced this match."""
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt at which the match starts."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt at which the match ends."""
        raise NotImplementedError

    def __repr__(self) -> str:
        # Custom repr (the dataclass one is disabled) showing the indices
        # rather than the full origin object.
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


564
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
    """A zero-width match at a single index of the prompt."""
    match_idx: int

    @property
    def start_idx(self) -> int:
        return self.match_idx

    @property
    def end_idx(self) -> int:
        # Same as `start_idx`: the match covers no tokens or characters.
        return self.match_idx


577
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
    """A match found in a tokenized prompt; indices are token positions."""
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
    """A match found in a text prompt; indices come from the regex match."""
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

602

603
@dataclass
class PlaceholderFeaturesInfo:
    """Describes one run of feature placeholder tokens found in a prompt."""

    # Modality of the item the placeholders belong to.
    modality: str
    # Index of that item within its modality.
    item_idx: int
    # Position in the prompt where the placeholder tokens start.
    start_idx: int
    # The placeholder token IDs themselves.
    tokens: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` (offset and length)."""
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )
619
620
621
622


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of :code:`prompt_updates` found in :code:`prompt`.

    Matches are grouped by update, in the order of :code:`prompt_updates`;
    they are not sorted by position.
    """

    def get_matches(update: BoundPromptUpdate) -> list[PromptTargetMatch]:
        target = update.target

        # Index targets resolve to at most one zero-width match.
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        return [
            _PromptTargetTokenMatch(update, match)
            for match in iter_token_matches(prompt, target.token_ids)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of :code:`prompt_updates` found in :code:`prompt`.

    Matches are grouped by update, in the order of :code:`prompt_updates`;
    they are not sorted by position.
    """

    def get_matches(update: BoundPromptUpdate) -> list[PromptTargetMatch]:
        target = update.target

        # Index targets resolve to at most one zero-width match.
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        # The target is literal text, so escape it before the regex search.
        pattern = re.escape(target.text)
        return [
            _PromptTargetTextMatch(update, match)
            for match in re.finditer(pattern, prompt)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same prompt index.
    """
    matches = [m for matches in mm_matches.values() for m in matches]

    # One slot per prompt position, recording which match claimed it.
    seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)

    for match in matches:
        # Zero-width matches (start == end) claim no slots, so they can
        # never conflict with anything.
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


697
def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    Returns the updated prompt as a list of pieces (slices of the original
    prompt interleaved with update content) for the caller to join.
    Matches for a modality are skipped once all of its items, as counted in
    :code:`mm_item_counts`, have been consumed.
    """
    out_seqs = list[Union[str, list[int]]]()
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](lambda: 0)
    is_text_prompt = isinstance(prompt, str)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        item_start_idx = next_idx_by_modality[modality]
        max_item_count = mm_item_counts.get(modality, 0)
        if item_start_idx >= max_item_count:
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx
        origin = match._origin
        mode = origin.mode

        if mode == UpdateMode.INSERT:
            # Keep the target itself, then insert all remaining items.
            out_seqs.append(prompt[prev_end_idx:end_idx])
            num_inserts = max_item_count
        elif mode == UpdateMode.REPLACE:
            # Drop the target; a zero-width target takes all remaining
            # items, otherwise one item per matched target.
            out_seqs.append(prompt[prev_end_idx:start_idx])
            num_inserts = max_item_count if start_idx == end_idx else 1
        else:
            assert_never(mode)

        item_end_idx = min(item_start_idx + num_inserts, max_item_count)

        for item_idx in range(item_start_idx, item_end_idx):
            content = origin.get_content(item_idx)
            insert_seq = (content.full.text
                          if is_text_prompt else content.full.token_ids)

            out_seqs.append(insert_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] += item_end_idx - item_start_idx

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs)
744
745


746
def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, the input list itself is returned
    (not a copy).
    """
    if not mm_matches:
        return prompt

    token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)
758
759


760
def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, the prompt is returned unchanged.
    """
    if not mm_matches:
        return prompt

    texts = _apply_matches(prompt, mm_matches, mm_item_counts)

    return "".join(texts)
772
773


774
def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.

    Raises:
        AssertionError: If an update's feature tokens do not occur within
            its full content tokens.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = item_idx_by_modality[modality]
            # All items of this modality have already been located.
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for update_info in modality_updates:
                content = update_info.get_content(item_idx)
                content_tokens_full = content.full.token_ids
                content_len_full = len(content_tokens_full)
                end_idx_full = start_idx + content_len_full

                if content_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] == content_tokens_full:
                    content_tokens_feat = content.features.token_ids

                    # Locate the feature tokens inside the full content.
                    # Using a default keeps the `yield` below outside of
                    # any exception handler.
                    match = next(
                        iter_token_matches(content_tokens_full,
                                           content_tokens_feat),
                        None,
                    )
                    if match is None:
                        raise AssertionError(
                            f"{content_tokens_feat=} should be a "
                            f"subsequence of {content_tokens_full=}")

                    yield PlaceholderFeaturesInfo(
                        modality=modality,
                        item_idx=item_idx,
                        start_idx=start_idx + match.start_idx,
                        tokens=content_tokens_feat,
                    )

                    # Exclude overlapping matches
                    start_idx = end_idx_full
                    item_idx_by_modality[modality] += 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1
838
839


840
def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Find the placeholder tokens in :code:`prompt` and group them by modality.

    The placeholders come from :func:`_iter_placeholders`; the result maps
    each modality to the placeholders found for it.
    """
    it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)

    return dict(full_groupby_modality(it))


849
850
851
# Value type stored in a processing cache; the bound is written as a string
# (forward reference) rather than as the evaluated union.
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")


852
853
class ProcessingCache:
    """
    Size-bounded LRU cache of processed multi-modal items.

    Entries are keyed by a hash of the model ID, the modality, the original
    data item and the processor kwargs (see :meth:`get`).
    """

    @staticmethod
    def get_lru_cache(
        capacity_gb: int,
        value_type: type[_V],
    ) -> LRUCache[str, _V]:
        """
        Create an LRU cache whose capacity is ``capacity_gb`` GiB.

        The size of an entry is the sum of the sizes of its leaves: tensors
        count ``nbytes``, everything else ``sys.getsizeof``.
        """

        def get_size(leaf: object) -> int:
            if isinstance(leaf, torch.Tensor):
                return leaf.nbytes  # sys.getsizeof doesn't work for tensors

            return sys.getsizeof(leaf)

        return LRUCache[str, _V](
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: json_reduce_leaves(
                lambda a, b: a + b,
                json_map_leaves(get_size, x),
            ),
        )

    def __init__(self, capacity_gb: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None
        self.debug_cache_hits = 0
        self.debug_cache_total = 0

        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)

    @staticmethod
    def _make_cache_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ):
        """Hash the dependencies of a processed item into a cache key."""
        # Shared by `get` and `put` so both always derive identical keys.
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        """Log the hit ratio every ``debug_cache_hit_ratio_steps`` lookups."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        total = self.debug_cache_total
        if total > 0 and total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         self.debug_cache_hits / total)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor

        Returns ``None`` when the item is not cached.
        """
        self._maybe_log_cache_stats()

        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)

        # Hit/total counters are only maintained while debugging is enabled.
        if self.debug_cache_hit_ratio_steps:
            if cache_key in self._cache:
                self.debug_cache_hits += 1

            self.debug_cache_total += 1

        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)
        self._cache[cache_key] = output_kwargs
940
941


942
class BaseProcessingInfo:
    """
    Base class that exposes the information needed for data processing.

    Everything is read from the :class:`InputProcessingContext` passed to
    the constructor.

    Note:
        This class does not derive from :class:`abc.ABC`, so the
        `@abstractmethod` markers below are not enforced at instantiation;
        the base implementations raise :exc:`NotImplementedError` instead.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The model identifier stored in the context's model config."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config obtained from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HF processor obtained from the processing context.

        All keyword arguments are forwarded unchanged. Subclasses can
        override this method to handle specific kwargs from model config
        or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means that the number of items is unlimited.

        A modality that is missing from the returned dictionary
        is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """
        Return the maximum possible number of tokens per data item
        for each modality.

        The returned dictionary should have the same keys as the one
        returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# Type variable for the concrete `BaseProcessingInfo` subclass that a
# processor is parameterized with.
_I = TypeVar("_I", bound=BaseProcessingInfo)
996

997
998

class BaseMultiModalProcessor(ABC, Generic[_I]):
999
    """
1000
    Abstract base class to process multi-modal inputs to be used in vLLM.
1001
1002

    Not to be confused with :class:`transformers.ProcessorMixin`.
1003
1004
    """

1005
    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Set up the processor.

        Args:
            info: Processing information for the model.
            dummy_inputs: Builder of dummy inputs; it supplies the dummy
                text used by :meth:`_apply_hf_processor_mm_only`.
            cache: Optional cache of processed multi-modal items.
                :meth:`_cached_apply_hf_processor` bypasses caching
                when this is `None`.
            enable_sanity_checks: Whether to run the consistency assertion
                in :meth:`_cached_apply_hf_processor`.
        """
        # Backward compatibility: if the instance still exposes the legacy
        # hook name, alias it to the new one and warn about the rename.
        old_hook = getattr(self, "_get_prompt_replacements", None)
        if old_hook:
            logger.warning_once("`_get_prompt_replacements` has been renamed "
                                "to `_get_prompt_updates`. The old name will "
                                "be removed in an upcoming release.")
            self._get_prompt_updates = old_hook  # type: ignore[method-assign]

        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()

1026
    def __call__(
1027
        self,
1028
1029
        prompt: str,
        mm_data: MultiModalDataDict,
1030
        hf_processor_mm_kwargs: Mapping[str, object],
1031
    ) -> MultiModalInputs:
1032
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
1033

1034
1035
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that preprocesses multi-modal data items
        before they are handed to :meth:`_get_hf_mm_data`.

        To support additional modalities, override this method and return
        a subclass of :class:`MultiModalDataParser` with extra subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
1045
1046
1047
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
1048
1049
1050
1051
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.
        """
1052
        mm_items = self.data_parser.parse_mm_data(mm_data)
1053
        mm_config = self.info.ctx.get_mm_config()
1054
1055

        for modality, items in mm_items.items():
1056
            limit = mm_config.get_limit_per_prompt(modality)
1057
1058
1059
1060
1061
1062
1063
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items
1064

1065
1066
1067
1068
1069
1070
1071
1072
1073
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Given the HF-processed data, output the metadata of each field."""
        raise NotImplementedError

1074
    @abstractmethod
1075
    def _get_prompt_updates(
1076
        self,
1077
        mm_items: MultiModalDataItems,
1078
1079
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
1080
    ) -> Sequence[PromptUpdate]:
1081
1082
        """
        Given the original multi-modal items for this modality
1083
        and HF-processed data, output the updates to perform.
1084

1085
1086
1087
1088
1089
1090
1091
1092
        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct  :class:`~vllm-multimodal.input.PlaceholderRange`
        for each multi-modal item.
1093
1094
        """
        raise NotImplementedError
1095

1096
    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholders of each modality in `new_token_ids`.

        Delegates to :func:`find_mm_placeholders` with the same arguments.
        """
        placeholders = find_mm_placeholders(
            mm_prompt_updates,
            new_token_ids,
            mm_item_counts,
        )
        return placeholders
1104

1105
    def _get_hf_mm_data(
1106
        self,
1107
        mm_items: MultiModalDataItems,
1108
1109
1110
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()
1111

1112
1113
1114
        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())
1115

1116
1117
        return processor_data, passthrough_data

1118
1119
1120
    def _call_hf_processor(
        self,
        prompt: str,
1121
1122
1123
1124
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
1125
    ) -> BatchFeature:
1126
1127
1128
1129
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
1130
1131
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
1132
1133
            dict(text=prompt, **mm_data),
            mm_kwargs,
1134
1135
        )

1136
    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Tell whether the HF processor applies prompt updates.

        This default implementation answers :code:`False` as soon as any
        modality is given as embeddings (:class:`EmbeddingItems` or
        :class:`DictEmbeddingItems`) and :code:`True` otherwise, which
        matches most HF processors: they update the prompt when multi-modal
        data items are passed, but not when multi-modal embeddings are
        passed.
        """
        for items in mm_items.values():
            if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
                return False

        return True

1153
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Run the HF processor on the prompt text and the multi-modal data
        together.

        Returns:
            A tuple of the prompt token IDs, the processed multi-modal
            keyword arguments, and whether prompt updates have been
            applied (see :meth:`_hf_processor_applies_updates`).
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        hf_outputs = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Passthrough data skips the HF processor and is merged in as-is.
        hf_outputs.update(passthrough_data)

        # `input_ids` must hold exactly one sequence; the single-element
        # unpacking raises otherwise.
        [prompt_ids] = hf_outputs.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(
            hf_outputs,
            hf_processor_mm_kwargs,
        )
        mm_kwargs = MultiModalKwargs.from_hf_inputs(hf_outputs, fields_config)

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_update_applied
1188

1189
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Run the HF processor on the prompt text only.

        This reuses :meth:`_apply_hf_processor_text_mm` with an empty set
        of multi-modal items and no HF processor kwargs, and keeps only
        the resulting prompt token IDs.
        """
        no_mm_items = MultiModalDataItems({})

        token_ids, _mm_kwargs, _is_applied = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=no_mm_items,
            hf_processor_mm_kwargs={},
        )

        return token_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`DummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

1234
1235
        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
1236
            mm_counts,
1237
        )
1238

1239
        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
1240
            prompt_text=dummy_inputs.prompt_text,
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
1253
        enable_hf_prompt_update: bool,
1254
    ) -> tuple[list[int], MultiModalKwargs, bool]:
1255
1256
1257
        """
        Apply the HF processor on the prompt text and multi-modal data.

1258
        In addition, return whether prompt updates have been applied
1259
1260
        (for most HF processors, this should be :code:`True`).

1261
        Note:
1262
1263
            If :code:`enable_hf_prompt_update=False`, we use HF processor
            to perform prompt updates if available; HF processor requires
1264
            that the prompt corresponds to multi-modal items.
1265
1266
        """
        if isinstance(prompt, str):
1267
            if enable_hf_prompt_update:
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

1278
        mm_kwargs = self._apply_hf_processor_mm_only(
1279
            mm_items=mm_items,
1280
1281
1282
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

1283
        return prompt_ids, mm_kwargs, False
1284
1285
1286

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Run the HF processor on the full prompt, reusing cached results
        for multi-modal items that were processed before and caching the
        results of the remaining ones.

        Caching is bypassed when no cache is configured or when any
        passthrough data is present.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        # For each modality, one entry per item: the cached output or `None`.
        maybe_cached = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        # The original data of every item that was not found in the cache.
        mm_missing_data = {
            modality: [
                mm_data_items[modality][idx]
                for idx, cached in enumerate(cached_items) if cached is None
            ]
            for modality, cached_items in maybe_cached.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we can't apply prompt updates until the new multimodal
        # items are combined with the cached multimodal items
        main_outputs = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_update=False,
        )
        prompt_ids, mm_missing_kwargs, is_update_applied = main_outputs

        # Per modality, the position of the next freshly processed item
        # to consume from `mm_missing_kwargs`.
        mm_missing_next_idx = dict.fromkeys(mm_missing_data_items, 0)

        merged_kw_items: list[MultiModalKwargsItem] = []
        for modality, cached_items in maybe_cached.items():
            for idx, kw_item in enumerate(cached_items):
                if kw_item is None:
                    next_idx = mm_missing_next_idx[modality]
                    kw_item = mm_missing_kwargs.get_item(modality, next_idx)

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] = next_idx + 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly processed item must have been consumed.
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs, is_update_applied
1377

1378
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
    ) -> dict[str, Sequence[BoundPromptUpdate]]:
        """
        Bind each prompt update to the tokenizer and group the bound
        updates with :func:`full_groupby_modality`.
        """
        tokenizer = self.info.get_tokenizer()

        bound_updates = (update.bind(tokenizer) for update in prompt_updates)
        grouped = full_groupby_modality(bound_updates)

        return dict(grouped)
1386

1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """
        Apply the matched updates to a token prompt.

        Delegates to :func:`apply_token_matches` with the same arguments.
        """
        updated_ids = apply_token_matches(
            prompt,
            mm_matches,
            mm_item_counts,
        )
        return updated_ids

    def _apply_text_matches(
        self,
        prompt: str,
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> str:
        """
        Apply the matched updates to a text prompt.

        Delegates to :func:`apply_text_matches` with the same arguments.
        """
        updated_text = apply_text_matches(
            prompt,
            mm_matches,
            mm_item_counts,
        )
        return updated_text

1403
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to `token_ids`.

        Token-level matching is tried first; if it does not yield enough
        matches for every modality, the updates are applied on the decoded
        text instead and the result is encoded back into tokens.

        Returns:
            A tuple of the updated token IDs, the corresponding text, and
            the placeholders found in the updated token IDs per modality.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        enough_token_matches = all(
            len(mm_token_matches.get(modality, ())) >= item_count
            for modality, item_count in mm_item_counts.items())

        if enough_token_matches:
            mm_matches = mm_token_matches

            token_ids = self._apply_token_matches(
                token_ids,
                mm_matches,
                mm_item_counts,
            )
            text = decode_tokens(tokenizer, token_ids)
        else:
            text = decode_tokens(tokenizer, token_ids)

            mm_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = self._apply_text_matches(
                text,
                mm_matches,
                mm_item_counts,
            )
            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)

        # Only the updates that actually matched are searched for below.
        matched_updates = {
            modality: [match._origin for match in matches]
            for modality, matches in mm_matches.items()
        }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders
1473

1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
1497
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
1498
        mm_item_counts: Mapping[str, int],
1499
    ) -> None:
1500
1501
1502
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

1503
            if len(placeholders) != item_count:
1504
                raise RuntimeError(
1505
                    f"Expected there to be {item_count} prompt updates "
1506
                    f"corresponding to {item_count} {modality} items, but "
1507
                    f"instead found {len(placeholders)} prompt updates! "
1508
                    "Either the prompt text has missing/incorrect tokens for "
1509
1510
1511
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
1512
                    "`_call_hf_processor` and `_get_prompt_updates`).")
1513

1514
1515
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        If `return_mm_hashes` is set, a hash of every multi-modal item is
        included in the result; otherwise `mm_hashes` is `None`.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes to be returned (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        if not return_mm_hashes:
            mm_hashes = None
        else:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }

        hf_outputs = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )
        prompt_ids, mm_kwargs, is_update_applied = hf_outputs

        unbound_prompt_updates = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if is_update_applied:
            # The HF processor already updated the prompt, so the
            # placeholders only need to be located.
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_updates,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            prompt = decode_tokens(self.info.get_tokenizer(), prompt_ids)
        else:
            # The prompt still has to be updated by us.
            update_outputs = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
                mm_item_counts,
            )
            prompt_ids, prompt, mm_placeholders = update_outputs
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
1620
1621
1622
1623
        """
        Create input prompt for the encoder. HF processor will be applied on 
        this prompt during profiling and generation.
        """
1624
1625
        raise NotImplementedError

1626
1627
1628
1629
1630
1631
1632
1633
    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Create input prompt for the decoder."""
        return prompt

1634
1635
1636
1637
1638
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are adapted to encoder-decoder models:

        1. Create the encoder prompt from the input prompt.
        2. Apply the base-class processing on the encoder prompt.
        3. Create the decoder prompt from the input prompt (see
           :meth:`create_decoder_prompt`) and make sure that both its text
           and its token IDs are available.

        In the result, `prompt` / `prompt_token_ids` describe the decoder
        prompt, while `encoder_prompt` / `encoder_prompt_token_ids`
        describe the processed encoder prompt.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            return_mm_hashes,
        )

        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)

        # Derive whichever representation (text or token IDs) is missing.
        if isinstance(decoder_prompt, str):
            decoder_text = decoder_prompt
            decoder_ids = encode_tokens(tokenizer,
                                        decoder_prompt,
                                        add_special_tokens=False)
        else:
            decoder_ids = decoder_prompt
            decoder_text = decode_tokens(tokenizer, decoder_prompt)

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)

        # Replace the encoder prompt copied over from `encoder_inputs`
        # with the decoder prompt.
        mm_inputs["prompt"] = decoder_text
        mm_inputs["prompt_token_ids"] = decoder_ids

        return mm_inputs