processing.py 54.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import re
3
import sys
4
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
7
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
8
from dataclasses import dataclass, field
9
from enum import Enum
10
from functools import lru_cache
11
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
12
                    TypeVar, Union, cast)
13

14
import torch
15
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
16
from typing_extensions import assert_never
17

18
from vllm.inputs import InputProcessingContext
19
from vllm.jsontree import json_map_leaves, json_reduce_leaves
20
from vllm.logger import init_logger
21
22
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
23
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
24

25
from .hasher import MultiModalHasher
26
27
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
28
                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
29
30
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
31
32
33

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
34

35
logger = init_logger(__name__)

# Constrained type variable: a prompt is handled either as text (`str`)
# or as a list of token IDs (`list[int]`), never a mix of the two.
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
41

42

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@dataclass
class PromptIndex:
    """Resolves to an index in the prompt."""
    # Callback that receives the tokenizer and the prompt (text or token IDs)
    # and returns the resolved index, or `None` when the target does not
    # match the prompt.
    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]


class PromptIndexTargets:
    """Factory functions that build commonly used :class:`PromptIndex`es."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: 0)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        The prefix is converted (via the tokenizer) to the same representation
        as the prompt before comparison. The resulting index is the length of
        the converted prefix, or `None` if the prompt does not start with it.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prefix = seq

            if isinstance(prompt, str):
                if not isinstance(prefix, str):
                    # Make both `str`
                    prefix = decode_tokens(tokenizer, prefix)
            else:
                if isinstance(prefix, str):
                    # Make both `list[int]`
                    prefix = encode_tokens(tokenizer,
                                           prefix,
                                           add_special_tokens=False)

            # The match position is right after the prefix, so it equals
            # the prefix length in the prompt's own representation.
            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: len(prompt))


# A target is either a literal sequence (text or token IDs) to look for in
# the prompt, or a `PromptIndex` that resolves directly to a position.
PromptTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update.
"""


104
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
    """
    Given :attr:`full`, return a boolean mask of shape `(len(full),)`
    indicating which positions of `full` to assign embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    :class:`SupportsMultiModal.get_multimodal_embeddings`.
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap :code:`seq` without an embedding mask (`is_embed=None`)."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap :code:`seq`, marking as embedding positions every token of
        :code:`seq` whose ID appears in the tokenization of
        :code:`embed_text`.
        """

        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
            # Tokenization is deferred until the tokenizer bound to `full`
            # is available.
            embed_token_ids = encode_tokens(full.tokenizer, embed_text)

            return torch.isin(
                torch.tensor(full.token_ids),
                torch.tensor(embed_token_ids),
            )

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap :code:`seq`, marking as embedding positions every token of
        :code:`seq` that equals :code:`embed_token_id`.
        """
        return PromptUpdateDetails(
            full=seq,
            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
        )
151
152


153
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that are part of the update.

If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""
160

161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Either a callable taking the item index and returning the update info,
# or the update info itself when it is the same for every item.
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """How a prompt update modifies the prompt at its matched target."""
    # The content is added right after the matched target; the target is kept.
    INSERT = "insert"
    # The content takes the place of the matched target.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """Return this update wrapped together with :code:`tokenizer`."""
        return BoundPromptUpdate(
            _origin=self,
            tokenizer=tokenizer,
        )

207

208
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

        For each image, insert a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder after the ``<s>`` token:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="<s>",
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the start of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens after a prefix ``Images:``:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix("Images:"),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the end of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.end(),
                insertion="<image>" * image_feature_size,
            )
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to insert right after :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to insert, i.e. :attr:`insertion`."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptUpdateDetails.select_text(
                    "".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    embed_text="<image>",
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails.select_token_id(
                    ([image_bos_id] + [image_token_id] * image_feature_size
                     + [image_eos_id]),
                    embed_token_id=image_token_id,
                ),
            )
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to replace :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to substitute, i.e. :attr:`replacement`."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
345
346


347
348
349
350
351
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Results are cached per `(tokenizer, text, add_special_tokens)`, so the
    tokenizer must be hashable. The returned list is the cached object
    itself and is shared between callers, so it must not be mutated.
    """
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)
357
358


359
360
361
362
363
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    The token IDs are taken as a tuple so that they are hashable for the
    cache; they are converted back to a list before decoding.
    """
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)
369
370
371
372
373


class _HasModalityAttr(Protocol):
    """Structural type for objects exposing `modality` as an attribute."""
    modality: str

374

375
class _HasModalityProp(Protocol):
    """Structural type for objects exposing `modality` as a property."""

    @property
    def modality(self) -> str:
        ...


# Any object that exposes `modality`, as an attribute or as a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group :code:`values` by their `modality` using :func:`full_groupby`."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)

    # At least one of the two representations is always set; the other one
    # is filled in lazily on first access.
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Bind :code:`seq` (text or token IDs) to :code:`tokenizer`."""
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is None:
            assert self._token_ids is not None
            # A tuple is passed because the cache needs a hashable key.
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token IDs, encoded from the text on first access."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer,
                                             self._text,
                                             add_special_tokens=False)

        return self._token_ids


436
@dataclass
class _BoundPromptContent:
    """Update content whose full sequence is bound to a tokenizer."""
    # The full content, convertible between text and token IDs.
    full: _BoundPromptSequence
    # Optional callable producing the embedding mask for `full`;
    # `None` when no mask was specified.
    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
440
441


442
@dataclass
class BoundPromptUpdate:
    """
    A :class:`PromptUpdate` bound to a tokenizer to automatically convert
    :attr:`target` and the result of :meth:`get_content` between
    token sequence and text representations.
    """
    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Bound content per item index; only used when the content is
        # a callable (see `get_content`).
        self._content_cache = dict[int, _BoundPromptContent]()

    @property
    def modality(self) -> str:
        """The modality of the underlying update."""
        return self._origin.modality

    @property
    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """The token sequence (or text) to update."""
        target = self._origin.target

        # Index targets resolve to a position by themselves and
        # therefore do not need to be bound to the tokenizer.
        if isinstance(target, PromptIndex):
            return target

        return _BoundPromptSequence.from_seq(self.tokenizer, target)

    @property
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Given the index of the processed item within :attr:`modality`,
        output the token sequence (or text) to update.
        """
        content = self.content
        if callable(content):
            cache_key = item_idx
            if cache_key in self._content_cache:
                return self._content_cache[cache_key]

            content = content(item_idx)
        else:
            # Content that does not depend on the item is not cached.
            cache_key = None

        # Normalize a bare sequence into `PromptUpdateDetails`.
        if not isinstance(content, PromptUpdateDetails):
            content = PromptUpdateDetails.from_seq(content)

        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
                                                   content.full)
        bound_content = _BoundPromptContent(full=bound_full,
                                            is_embed=content.is_embed)

        if cache_key is not None:
            self._content_cache[cache_key] = bound_content

        return bound_content
506
507


508
509
510
class _TokenMatch(NamedTuple):
    """A half-open index range `[start_idx, end_idx)` of a token match."""
    start_idx: int
    end_idx: int
511
512


513
514
515
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are reported left to right and never overlap: after a match,
    the search resumes at the end of that match.

    Note that empty matches are ignored.
    """
    match_len = len(match_ids)
    if match_len == 0:
        return

    # Last index at which a full-length match can still start.
    last_start_idx = len(token_ids) - match_len

    start_idx = 0
    while start_idx <= last_start_idx:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
539
540


541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Return a copy of :code:`token_ids` in which every occurrence of
    :code:`match_ids` has been substituted by :code:`new_ids`.

    Note that empty matches are ignored.
    """
    pieces: list[list[int]] = []
    cursor = 0

    for found in iter_token_matches(token_ids, match_ids):
        # Keep the untouched tokens before the match, then the substitute.
        pieces.append(token_ids[cursor:found.start_idx])
        pieces.append(new_ids)
        cursor = found.end_idx

    # Whatever follows the last match is kept as-is.
    pieces.append(token_ids[cursor:])

    return flatten_2d_lists(pieces)


568
@dataclass(repr=False)
class PromptTargetMatch(ABC):
    """A location in the prompt where the target of an update was found."""
    # The bound update whose target produced this match.
    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        """The modality of the update that produced this match."""
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt at which the match starts."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt at which the match ends (exclusive)."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


591
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
    """A zero-length match located at a single resolved prompt index."""
    match_idx: int

    @property
    def start_idx(self) -> int:
        return self.match_idx

    @property
    def end_idx(self) -> int:
        # Same as `start_idx`: the match covers no prompt positions.
        return self.match_idx


604
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
    """A match found in a token ID prompt; indices count tokens."""
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
    """A match found in a text prompt; indices count characters."""
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

629

630
@dataclass
class PlaceholderFeaturesInfo:
    """Describes one run of placeholder tokens found in a prompt."""
    # Modality of the item the placeholder belongs to.
    modality: str
    # Index of the item within its modality.
    item_idx: int
    # Position in the prompt at which the placeholder tokens start.
    start_idx: int
    # The placeholder token IDs.
    tokens: list[int]
    # Optional mask over `tokens`; `None` when no mask was specified.
    is_embed: Optional[torch.Tensor]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` covering all tokens."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
            is_embed=self.is_embed,
        )
650
651
652
653


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """Return each target of :code:`prompt_updates` found in :code:`prompt`."""

    def get_matches(update: BoundPromptUpdate):
        target = update.target

        # An index target yields at most one zero-length match.
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        # A sequence target yields one match per non-overlapping occurrence.
        return [
            _PromptTargetTokenMatch(update, match)
            for match in iter_token_matches(prompt, target.token_ids)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """Return each target of :code:`prompt_updates` found in :code:`prompt`."""

    def get_matches(update: BoundPromptUpdate):
        target = update.target

        # An index target yields at most one zero-length match.
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        # The target text is escaped so that it is matched literally.
        return [
            _PromptTargetTextMatch(update, match)
            for match in re.finditer(re.escape(target.text), prompt)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same position of the prompt.
    """
    matches = [m for matches in mm_matches.values() for m in matches]

    # For each prompt position, the match that has claimed it (if any).
    seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)

    for match in matches:
        # Zero-length matches cover no positions, so they never conflict.
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    # `sorted` is stable, so matches that start at the same index
    # keep their relative order from `mm_matches`.
    return sorted(matches, key=lambda x: x.start_idx)


728
def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    Returns the pieces of the updated prompt in order (untouched prompt
    slices interleaved with update contents); the caller joins them.
    Matches of a modality are skipped once all of its items are used up.
    """
    out_seqs = list[Union[str, list[int]]]()
    prev_end_idx = 0
    # Index of the next unused item, per modality.
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        item_start_idx = next_idx_by_modality[modality]
        max_item_count = mm_item_counts.get(modality, 0)
        if item_start_idx >= max_item_count:
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx
        origin = match._origin
        mode = origin.mode

        if mode == UpdateMode.INSERT:
            # Keep the matched target and add all remaining items after it.
            out_seqs.append(prompt[prev_end_idx:end_idx])
            num_inserts = max_item_count
        elif mode == UpdateMode.REPLACE:
            # Drop the matched target. A zero-length match consumes all
            # remaining items; otherwise one item per match.
            out_seqs.append(prompt[prev_end_idx:start_idx])
            num_inserts = max_item_count if start_idx == end_idx else 1
        else:
            assert_never(mode)

        item_end_idx = min(item_start_idx + num_inserts, max_item_count)

        for item_idx in range(item_start_idx, item_end_idx):
            content = origin.get_content(item_idx)
            # Use the representation that matches the prompt's type.
            insert_seq = (content.full.text if isinstance(prompt, str) else
                          content.full.token_ids)

            out_seqs.append(insert_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] += item_end_idx - item_start_idx

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs)
775
776


777
def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, the input list itself is returned
    (not a copy).
    """
    if not mm_matches:
        return prompt

    token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)
789
790


791
def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, the prompt is returned unchanged.
    """
    if not mm_matches:
        return prompt

    texts = _apply_matches(prompt, mm_matches, mm_item_counts)

    return "".join(texts)
803
804


805
def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    prompt_len = len(prompt)
    # Index of the next item to look for, per modality.
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = item_idx_by_modality[modality]
            # All items of this modality have already been located.
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for update_info in modality_updates:
                content = update_info.get_content(item_idx)
                content_tokens_full = content.full.token_ids
                content_len_full = len(content_tokens_full)
                end_idx_full = start_idx + content_len_full

                # Skip empty content and content that cannot fit
                # in the remainder of the prompt.
                if content_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] == content_tokens_full:
                    # Evaluate the mask callable, if one was given.
                    content_is_embed = content.is_embed
                    if content_is_embed is not None:
                        content_is_embed = content_is_embed(content.full)

                    yield PlaceholderFeaturesInfo(
                        modality=modality,
                        item_idx=item_idx,
                        start_idx=start_idx,
                        tokens=content_tokens_full,
                        is_embed=content_is_embed,
                    )

                    # Exclude overlapping matches
                    start_idx = end_idx_full
                    item_idx_by_modality[modality] += 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1
864
865


866
def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Find the placeholder tokens in :code:`prompt` and group them by modality.
    """
    it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)
    return dict(full_groupby_modality(it))


875
876
877
# Value type stored in the LRU cache built by `ProcessingCache.get_lru_cache`.
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")


878
879
class ProcessingCache:
    """
    A size-bounded LRU cache of processed multi-modal items, keyed by a hash
    of the model ID, the modality, the input item and the processor kwargs.
    """

    @staticmethod
    def get_lru_cache(
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """
        Create an LRU cache whose capacity is :code:`capacity_gb` GiB and
        whose entries are weighed by the summed size of their leaves.

        :code:`value_type` is not used at runtime; it only determines the
        value type of the returned cache for type checkers. When
        :code:`debug` is set, each computed size is logged at debug level.
        """

        def get_leaf_size(leaf: object) -> int:
            # MultiModalKwargs is not a subclass of dict
            if isinstance(leaf, MultiModalKwargs):
                return get_item_size(leaf.data)

            # MultiModalKwargsItem is not a subclass of dict
            if isinstance(leaf, MultiModalKwargsItem):
                leaf_data = {k: v.data for k, v in leaf.items()}
                return get_item_size(leaf_data)

            # sys.getsizeof doesn't work for tensors
            if isinstance(leaf, torch.Tensor):
                return leaf.nbytes

            return sys.getsizeof(leaf)

        def get_item_size(
            value: Union[MultiModalKwargs, MultiModalKwargsItem,
                         Mapping[str, NestedTensors]]
        ) -> int:
            # Map every leaf to its size, then sum the sizes.
            size = json_reduce_leaves(
                lambda a, b: a + b,
                json_map_leaves(get_leaf_size, value),
            )

            if debug:
                logger.debug("Calculated size of %s to be %.2f GiB",
                             type(value), size / GiB_bytes)

            return size

        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)

    def __init__(
        self,
        capacity_gb: float,
        *,
        debug_cache_hit_ratio_steps: Optional[int] = None,
    ) -> None:
        """
        Args:
            capacity_gb: Capacity of the cache in GiB.
            debug_cache_hit_ratio_steps: If set to a non-zero value, hit
                statistics are collected and logged at debug level every
                this many lookups.
        """
        super().__init__()

        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
        self.debug_cache_hits = 0
        self.debug_cache_total = 0

        self._cache = self.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
            debug=bool(debug_cache_hit_ratio_steps),
        )

    def _maybe_log_cache_stats(self) -> None:
        """Log hit ratio and cache size once every configured lookups."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        total = self.debug_cache_total
        if total > 0 and total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         self.debug_cache_hits / total)
            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
                         self._cache.currsize / GiB_bytes,
                         self._cache.maxsize / GiB_bytes)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)

        # Statistics are only tracked when debug logging is configured.
        if self.debug_cache_hit_ratio_steps:
            if cache_key in self._cache:
                self.debug_cache_hits += 1

            self.debug_cache_total += 1

        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        # The key must be built exactly as in `get`.
        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)
        self._cache[cache_key] = output_kwargs
998
999


1000
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        # Processing context that every accessor below delegates to.
        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` field of the context's model config."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config as provided by the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HF processor, forwarding ``kwargs`` to the context.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        - A value of `None` means an unlimited number of items.
        - A modality that is absent from the returned mapping
          is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """
        Return, for each modality, the maximum possible number of tokens
        taken up by a single data item.

        The returned mapping should have the same keys as the one
        returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# Type variable for the :class:`BaseProcessingInfo` subclass
# that a generic class in this module is parameterized by.
_I = TypeVar("_I", bound=BaseProcessingInfo)
1054

1055
1056

class BaseMultiModalProcessor(ABC, Generic[_I]):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """

1063
    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Store the collaborators of this processor and build its data parser.

        If the instance still defines the legacy hook
        ``_get_prompt_replacements``, a one-time warning is logged and the
        hook is installed as ``_get_prompt_updates``.
        """
        legacy_hook = getattr(self, "_get_prompt_replacements", None)
        if legacy_hook:
            logger.warning_once("`_get_prompt_replacements` has been renamed "
                                "to `_get_prompt_updates`. The old name will "
                                "be removed in an upcoming release.")
            self._get_prompt_updates = legacy_hook  # type: ignore[method-assign]

        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        # Built last so that `_get_data_parser` overrides can already
        # read the attributes assigned above.
        self.data_parser = self._get_data_parser()

1084
    def __call__(
1085
        self,
1086
1087
        prompt: str,
        mm_data: MultiModalDataDict,
1088
        hf_processor_mm_kwargs: Mapping[str, object],
1089
    ) -> MultiModalInputs:
1090
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
1091

1092
1093
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that preprocesses multi-modal data items
        before they reach :meth:`_get_hf_mm_data`.

        To support additional modalities, override this method and return
        a subclass of :class:`MultiModalDataParser` with extra subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
1103
1104
1105
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
1106
1107
1108
1109
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.
        """
1110
        mm_items = self.data_parser.parse_mm_data(mm_data)
1111
        mm_config = self.info.ctx.get_mm_config()
1112
1113

        for modality, items in mm_items.items():
1114
            limit = mm_config.get_limit_per_prompt(modality)
1115
1116
1117
1118
1119
1120
1121
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items
1122

1123
1124
1125
1126
1127
1128
1129
1130
1131
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Given the HF-processed data, output the metadata of each field."""
        raise NotImplementedError

1132
    @abstractmethod
1133
    def _get_prompt_updates(
1134
        self,
1135
        mm_items: MultiModalDataItems,
1136
1137
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
1138
    ) -> Sequence[PromptUpdate]:
1139
1140
        """
        Given the original multi-modal items for this modality
1141
        and HF-processed data, output the updates to perform.
1142

1143
1144
1145
1146
1147
1148
1149
1150
        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct  :class:`~vllm-multimodal.input.PlaceholderRange`
        for each multi-modal item.
1151
1152
        """
        raise NotImplementedError
1153

1154
    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Overridable hook around the module-level
        ``find_mm_placeholders``; the arguments are forwarded unchanged.
        """
        placeholders = find_mm_placeholders(mm_prompt_updates, new_token_ids,
                                            mm_item_counts)
        return placeholders
1162

1163
    def _get_hf_mm_data(
1164
        self,
1165
        mm_items: MultiModalDataItems,
1166
1167
1168
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()
1169

1170
1171
1172
        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())
1173

1174
1175
        return processor_data, passthrough_data

1176
1177
1178
    def _call_hf_processor(
        self,
        prompt: str,
1179
1180
1181
1182
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
1183
    ) -> BatchFeature:
1184
1185
1186
1187
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
1188
1189
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
1190
1191
            dict(text=prompt, **mm_data),
            mm_kwargs,
1192
1193
        )

1194
    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Tell whether the HF processor applies prompt updates.

        This default returns :code:`False` as soon as any modality is
        given as embeddings (:class:`EmbeddingItems` or
        :class:`DictEmbeddingItems`), and :code:`True` otherwise, which
        matches most HF processors: they update the prompt for raw data
        items but not for pre-computed embeddings.
        """
        embedding_types = (EmbeddingItems, DictEmbeddingItems)

        for modality_items in mm_items.values():
            if isinstance(modality_items, embedding_types):
                return False

        return True

1211
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Run the HF processor on the prompt text and
        the multi-modal data in a single call.

        Returns:
            The prompt token IDs, the processed multi-modal keyword
            arguments, and whether prompt updates have been applied
            (as reported by :meth:`_hf_processor_applies_updates`).
        """
        hf_data, passthrough = self._get_hf_mm_data(mm_items)

        hf_outputs = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=hf_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Passthrough data skips the HF processor
        # and is merged into its output as-is.
        hf_outputs.update(passthrough)

        # `input_ids` must hold exactly one sequence;
        # the unpacking fails otherwise.
        [prompt_ids] = hf_outputs.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(hf_outputs,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargs.from_hf_inputs(hf_outputs, fields_config)

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_update_applied
1246

1247
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Run the HF processor on the prompt text only.

        This goes through :meth:`_apply_hf_processor_text_mm` with an
        empty set of multi-modal items and no processor kwargs, and keeps
        only the resulting token IDs.
        """
        token_ids, _mm_kwargs, _is_update_applied = (
            self._apply_hf_processor_text_mm(
                prompt_text=prompt_text,
                mm_items=MultiModalDataItems({}),
                hf_processor_mm_kwargs={},
            ))

        return token_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`DummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

1292
1293
        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
1294
            mm_counts,
1295
        )
1296

1297
        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
1298
            prompt_text=dummy_inputs.prompt_text,
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
1311
        enable_hf_prompt_update: bool,
1312
    ) -> tuple[list[int], MultiModalKwargs, bool]:
1313
1314
1315
        """
        Apply the HF processor on the prompt text and multi-modal data.

1316
        In addition, return whether prompt updates have been applied
1317
1318
        (for most HF processors, this should be :code:`True`).

1319
        Note:
1320
1321
            If :code:`enable_hf_prompt_update=False`, we use HF processor
            to perform prompt updates if available; HF processor requires
1322
            that the prompt corresponds to multi-modal items.
1323
1324
        """
        if isinstance(prompt, str):
1325
            if enable_hf_prompt_update:
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

1336
        mm_kwargs = self._apply_hf_processor_mm_only(
1337
            mm_items=mm_items,
1338
1339
1340
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

1341
        return prompt_ids, mm_kwargs, False
1342
1343
1344

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the full prompt, reusing cached
        per-item results and caching the newly processed ones.

        The cache is bypassed entirely (a single, uncached call to
        :meth:`_apply_hf_processor_main`) when no cache is configured
        or when any item provides passthrough data.

        Returns:
            The prompt token IDs, the multi-modal keyword arguments
            (cached and fresh items merged in their original order),
            and whether prompt updates have been applied.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        # One cache lookup per item; `None` marks a cache miss.
        cached_by_modality = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        # Collect the original data of every item that missed the cache,
        # keeping their relative order within each modality.
        missing_data = {
            modality: [
                mm_data_items[modality][idx]
                for idx, cached in enumerate(cached_items) if cached is None
            ]
            for modality, cached_items in cached_by_modality.items()
        }
        missing_items = self._to_mm_items(missing_data)

        # NOTE: `prompt` does not correspond to `missing_items`,
        # so we can't apply prompt updates until the new multimodal
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            missing_kwargs,
            is_update_applied,
        ) = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=missing_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_update=False,
        )

        # Per modality: index of the next freshly processed item to consume.
        next_missing = dict.fromkeys(missing_items, 0)

        merged: list[MultiModalKwargsItem] = []
        for modality, cached_items in cached_by_modality.items():
            for idx, kw_item in enumerate(cached_items):
                if kw_item is None:
                    position = next_missing[modality]
                    kw_item = missing_kwargs.get_item(modality, position)

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    next_missing[modality] = position + 1

                merged.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly processed item must have been consumed.
            missing_counts = missing_items.get_all_counts()
            assert all(
                consumed == missing_counts[modality]
                for modality, consumed in next_missing.items()), dict(
                    mm_missing_next_idx=next_missing,
                    mm_missing_counts=missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged)

        return prompt_ids, mm_kwargs, is_update_applied
1435

1436
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
    ) -> dict[str, Sequence[BoundPromptUpdate]]:
        """
        Bind each prompt update to the tokenizer, then group the bound
        updates via ``full_groupby_modality`` into a dictionary.
        """
        tokenizer = self.info.get_tokenizer()

        bound_updates = (item.bind(tokenizer) for item in prompt_updates)
        grouped = full_groupby_modality(bound_updates)

        return dict(grouped)
1444

1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """
        Overridable hook around the module-level ``apply_token_matches``;
        the arguments are forwarded unchanged.
        """
        updated_ids = apply_token_matches(prompt, mm_matches, mm_item_counts)
        return updated_ids

    def _apply_text_matches(
        self,
        prompt: str,
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> str:
        """
        Overridable hook around the module-level ``apply_text_matches``;
        the arguments are forwarded unchanged.
        """
        updated_text = apply_text_matches(prompt, mm_matches, mm_item_counts)
        return updated_text

1461
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to the token IDs.

        Token-level matching is used when it finds at least as many
        matches as there are items for every modality; otherwise the
        updates are applied on the decoded text, which is then re-encoded.

        Returns:
            The updated token IDs, the corresponding text, and the
            placeholders found in the updated token IDs.
        """
        tokenizer = self.info.get_tokenizer()

        token_matches_by_modality = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        enough_token_matches = all(
            len(token_matches_by_modality.get(modality, ())) >= item_count
            for modality, item_count in mm_item_counts.items())

        if enough_token_matches:
            applied_matches = token_matches_by_modality

            token_ids = self._apply_token_matches(
                token_ids,
                applied_matches,
                mm_item_counts,
            )
            text = decode_tokens(tokenizer, token_ids)
        else:
            text = decode_tokens(tokenizer, token_ids)

            applied_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = self._apply_text_matches(
                text,
                applied_matches,
                mm_item_counts,
            )
            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)

        # Placeholders are only searched for
        # the updates that produced a match.
        matched_updates = {
            modality: [match._origin for match in matches]
            for modality, matches in applied_matches.items()
        }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders
1531

1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
1555
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
1556
        mm_item_counts: Mapping[str, int],
1557
    ) -> None:
1558
1559
1560
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

1561
            if len(placeholders) != item_count:
1562
                raise RuntimeError(
1563
                    f"Expected there to be {item_count} prompt updates "
1564
                    f"corresponding to {item_count} {modality} items, but "
1565
                    f"instead found {len(placeholders)} prompt updates! "
1566
                    "Either the prompt text has missing/incorrect tokens for "
1567
1568
1569
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
1570
                    "`_call_hf_processor` and `_get_prompt_updates`).")
1571

1572
1573
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        When ``return_mm_hashes`` is set, a hash of each multi-modal item
        is also computed and returned; otherwise ``mm_hashes`` is ``None``.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes to be returned (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        mm_hashes = None
        if return_mm_hashes:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }

        (
            prompt_ids,
            mm_kwargs,
            is_update_applied,
        ) = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_updates = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        bound_updates = self._bind_and_group_updates(unbound_updates)

        item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, item_counts)

        if is_update_applied:
            # The HF processor already updated the prompt,
            # so only the placeholders need to be located.
            mm_placeholders = self._find_mm_placeholders(
                bound_updates,
                prompt_ids,
                item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, item_counts)

            prompt = decode_tokens(self.info.get_tokenizer(), prompt_ids)
        else:
            # Apply the updates ourselves; this also
            # yields the updated text and the placeholders.
            (
                prompt_ids,
                prompt,
                mm_placeholders,
            ) = self._apply_prompt_updates(
                prompt_ids,
                bound_updates,
                item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, item_counts)

        placeholder_ranges = {
            modality: [info.to_range() for info in infos]
            for modality, infos in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models: the HF processor
    is applied on the encoder prompt, while the decoder prompt is only
    tokenized (or decoded) as needed.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """Flag that defaults to ``False``; subclasses may override it."""
        return False

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the decoder.

        By default, the input prompt is returned unchanged.
        """
        return prompt

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are modified to fit encoder-decoder model:

        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Copy the input prompt text as decoder prompt inputs.
        """
        encoder_inputs = super().apply(
            self.create_encoder_prompt(prompt, mm_data),
            mm_data,
            hf_processor_mm_kwargs,
            return_mm_hashes,
        )

        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)

        # Make sure both the text and the token form of the decoder
        # prompt are available, whichever one was provided.
        if isinstance(decoder_prompt, str):
            decoder_text = decoder_prompt
            decoder_ids = encode_tokens(tokenizer,
                                        decoder_prompt,
                                        add_special_tokens=False)
        else:
            decoder_ids = decoder_prompt
            decoder_text = decode_tokens(tokenizer, decoder_prompt)

        # Start from the encoder results, exposing the encoder prompt under
        # dedicated keys, then overwrite the prompt fields with the
        # decoder prompt.
        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        mm_inputs.update({
            "prompt": decoder_text,
            "prompt_token_ids": decoder_ids
        })
        return mm_inputs