processing.py 54.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import re
3
import sys
4
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
7
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
8
from dataclasses import dataclass, field
9
from enum import Enum
10
from functools import lru_cache
11
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
12
                    TypeVar, Union, cast)
13

14
import torch
15
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
16
from typing_extensions import assert_never
17

18
from vllm.inputs import InputProcessingContext
19
from vllm.jsontree import json_map_leaves, json_reduce_leaves
20
from vllm.logger import init_logger
21
22
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
23
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
24

25
from .hasher import MultiModalHasher
26
27
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
28
                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
29
30
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
31
32
33

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
34

35
# Logger shared by everything defined in this module.
logger = init_logger(__name__)

# Either representation of a prompt: raw text or a list of token IDs.
# Used to keep the input and output representation of a function in sync.
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""Text, or a token sequence given as a list of token IDs."""
@dataclass
class PromptIndex:
    """
    A position in the prompt that is only resolved once the prompt and
    tokenizer are known.

    ``get_match_index(tokenizer, prompt)`` returns the resolved index, or
    ``None`` when the position cannot be located in that prompt.
    """
    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]


class PromptIndexTargets:
    """Factories for commonly used :class:`PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: 0)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        Before comparing, the prefix is converted to the same representation
        as the prompt (text or token IDs). If the prompt does not start with
        the prefix, the index resolves to ``None``.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prompt_is_text = isinstance(prompt, str)
            prefix_is_text = isinstance(seq, str)

            if prompt_is_text and not prefix_is_text:
                prefix = decode_tokens(tokenizer, seq)
            elif prefix_is_text and not prompt_is_text:
                prefix = encode_tokens(tokenizer,
                                       seq,
                                       add_special_tokens=False)
            else:
                prefix = seq

            prefix_end = len(prefix)
            if prompt[:prefix_end] != prefix:
                return None

            return prefix_end

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: len(prompt))


# A literal sequence to search for in the prompt, or a PromptIndex that
# resolves to a position in the prompt.
PromptTarget = Union[PromptSeq, PromptIndex]
"""The token sequence or text to update."""
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
    """
    Given :attr:`full`, return a boolean mask of shape `(len(full),)`
    indicating which positions of `full` to assign embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    :class:`SupportsMultiModal.get_multimodal_embeddings`.
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap ``seq`` so that every position of it receives embeddings."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap ``seq`` so that only the tokens of ``embed_text`` receive
        embeddings.

        A position of :attr:`full` is selected if its token ID equals any of
        the token IDs obtained by tokenizing ``embed_text``.
        """

        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
            # Encode without special tokens, the same way
            # `_BoundPromptSequence.token_ids` encodes `full`. Otherwise
            # special tokens the tokenizer adds by default (e.g. BOS) could
            # end up in the selection and mark unrelated positions of
            # `full` as embedding positions.
            embed_token_ids = encode_tokens(full.tokenizer,
                                            embed_text,
                                            add_special_tokens=False)

            return torch.isin(
                torch.tensor(full.token_ids),
                torch.tensor(embed_token_ids),
            )

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap ``seq`` so that only positions whose token ID equals
        ``embed_token_id`` receive embeddings.
        """
        return PromptUpdateDetails(
            full=seq,
            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
        )
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that makes up the update.

When only part of that content corresponds to feature placeholders, wrap it
in :class:`PromptUpdateDetails` to specify which part.
"""
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Either a function that maps the index of the processed item within
:attr:`modality` to its token sequence (or text), or that token sequence
(or text) itself when it is the same for every item.
"""


class UpdateMode(str, Enum):
    """How a prompt update modifies the prompt at its target."""

    # Add content right after the target, keeping the target itself.
    INSERT = "insert"
    # Swap the target out for the content.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.

    Concrete subclasses provide :attr:`content` and :attr:`mode`.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """Attach ``tokenizer`` to this update."""
        bound = BoundPromptUpdate(_origin=self, tokenizer=tokenizer)
        return bound
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    The content is placed immediately after :attr:`target`; the target
    itself stays in the prompt.

    Example:

        For each image, insert a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder after the ``<s>`` token:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="<s>",
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the start of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens after a prefix ``Images:``:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix("Images:"),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the end of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.end(),
                insertion="<image>" * image_feature_size,
            )
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    The token sequence (or text) to insert right after :attr:`target`, or a
    function mapping the index of the processed item within
    :attr:`modality` to it when it depends on the item.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """Same as :attr:`insertion`."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptUpdateDetails.select_text(
                    "".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    embed_text="<image>",
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails.select_token_id(
                    [image_bos_id] + [image_token_id] * image_feature_size
                    + [image_eos_id],
                    embed_token_id=image_token_id,
                ),
            )
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to replace :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """Same as :attr:`replacement`."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized :func:`encode_tokens`.

    Results are cached per ``(tokenizer, text, add_special_tokens)``, so the
    tokenizer must be hashable. A cache hit returns the same list object as
    the original call, so callers must not mutate the result.
    """
    token_ids = encode_tokens(tokenizer,
                              text,
                              add_special_tokens=add_special_tokens)
    return token_ids
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized :func:`decode_tokens`.

    ``token_ids`` is a tuple rather than a list so that it can be part of
    the cache key.
    """
    text = decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)
    return text


class _HasModalityAttr(Protocol):
    """Structural type: objects that store ``modality`` as an attribute."""

    modality: str


class _HasModalityProp(Protocol):
    """Structural type: objects that expose ``modality`` as a property."""

    @property
    def modality(self) -> str:
        ...


# Anything carrying a `modality` string, whether stored or computed.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group ``values`` by their ``modality`` using :func:`full_groupby`."""

    def modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer, so that both its text and its
    token IDs are available.

    Whichever representation was not supplied is computed on first access,
    using the memoized encode/decode helpers, and stored on the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        text: Optional[str] = None
        token_ids: Optional[list[int]] = None

        if isinstance(seq, str):
            text = seq
        elif isinstance(seq, list):
            token_ids = seq

        return _BoundPromptSequence(tokenizer=tokenizer,
                                    _text=text,
                                    _token_ids=token_ids)

    def __post_init__(self) -> None:
        missing_both = self._text is None and self._token_ids is None
        if missing_both:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token IDs, encoded without special tokens on first access."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer,
                                         self._text,
                                         add_special_tokens=False)
        return self._token_ids
@dataclass
class _BoundPromptContent:
    """Update content whose full sequence is bound to a tokenizer."""

    full: _BoundPromptSequence
    # Optional callback that builds the embedding mask for `full`
    # (see `PromptUpdateDetails.is_embed`); `None` selects every position.
    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
@dataclass
class BoundPromptUpdate:
    """
    A :class:`PromptUpdate` bound to a tokenizer, so that :attr:`target` and
    the result of :meth:`get_content` can be used either as text or as
    token IDs.
    """

    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Bound results of callable content, keyed by item index.
        self._content_cache = dict[int, _BoundPromptContent]()

    @property
    def modality(self) -> str:
        return self._origin.modality

    @property
    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """The token sequence (or text) to update."""
        raw_target = self._origin.target

        if isinstance(raw_target, PromptIndex):
            # Index targets receive the tokenizer when they are resolved.
            return raw_target

        return _BoundPromptSequence.from_seq(self.tokenizer, raw_target)

    @property
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Given the index of the processed item within :attr:`modality`,
        output the token sequence (or text) to update.

        Results of callable content are memoized per ``item_idx``; static
        content is bound again on every call.
        """
        raw = self.content
        is_dynamic = callable(raw)

        if is_dynamic:
            cached = self._content_cache.get(item_idx)
            if cached is not None:
                return cached

            raw = raw(item_idx)

        details = (raw if isinstance(raw, PromptUpdateDetails) else
                   PromptUpdateDetails.from_seq(raw))

        bound_content = _BoundPromptContent(
            full=_BoundPromptSequence.from_seq(self.tokenizer, details.full),
            is_embed=details.is_embed,
        )

        if is_dynamic:
            self._content_cache[item_idx] = bound_content

        return bound_content
class _TokenMatch(NamedTuple):
    """Half-open span ``[start_idx, end_idx)`` of a match in a token list."""

    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are reported from left to right and never overlap: after a
    match, scanning resumes at its end. Empty matches are ignored.
    """
    match_len = len(match_ids)
    if not match_len:
        return

    last_start = len(token_ids) - match_len
    pos = 0
    while pos <= last_start:
        end = pos + match_len

        if token_ids[pos:end] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=end)

        # Jump past this match so that matches never overlap
        pos = end
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Replace each occurrence of :code:`match_ids` in :code:`token_ids`
    with :code:`new_ids`, returning a new list.

    Occurrences come from :func:`iter_token_matches`, so they never overlap.
    Empty matches are ignored.
    """
    result: list[int] = []
    cursor = 0

    for match in iter_token_matches(token_ids, match_ids):
        result.extend(token_ids[cursor:match.start_idx])
        result.extend(new_ids)
        cursor = match.end_idx

    result.extend(token_ids[cursor:])

    return result
@dataclass(repr=False)
class PromptTargetMatch(ABC):
    """
    A location in the prompt where the target of a
    :class:`BoundPromptUpdate` was found.

    Subclasses define how :attr:`start_idx` and :attr:`end_idx` are derived.
    """

    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
    """A zero-width match at a resolved :class:`PromptIndex` position."""

    match_idx: int

    @property
    def start_idx(self) -> int:
        return self.match_idx

    @property
    def end_idx(self) -> int:
        # Zero width: starts and ends at the same position.
        return self.match_idx
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
    """A match of the target's token IDs inside a token-ID prompt."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        token_match = self.match
        return token_match.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
    """A regex match of the target's text inside a text prompt."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        text_match = self.match
        return text_match.start()

    @property
    def end_idx(self) -> int:
        text_match = self.match
        return text_match.end()
@dataclass
class PlaceholderFeaturesInfo:
    """The placeholder tokens of one multi-modal item found in a prompt."""

    modality: str
    item_idx: int
    start_idx: int
    tokens: list[int]
    is_embed: Optional[torch.Tensor]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Describe the placeholder as a :class:`PlaceholderRange`."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        offset = self.start_idx
        length = self.length
        return PlaceholderRange(offset=offset,
                                length=length,
                                is_embed=self.is_embed)


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of :code:`prompt_updates` found in :code:`prompt`.

    An index target contributes at most one zero-width match. A sequence
    target contributes every non-overlapping occurrence of its token IDs.
    Matches are listed in the order of ``prompt_updates``.
    """
    found: list[PromptTargetMatch] = []

    for update in prompt_updates:
        target = update.target

        if isinstance(target, PromptIndex):
            idx = target.get_match_index(update.tokenizer, prompt)
            if idx is not None:
                found.append(_PromptTargetIndexMatch(update, idx))
            continue

        found.extend(
            _PromptTargetTokenMatch(update, token_match)
            for token_match in iter_token_matches(prompt, target.token_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of :code:`prompt_updates` found in :code:`prompt`.

    An index target contributes at most one zero-width match. A sequence
    target contributes every non-overlapping literal occurrence of its text.
    Matches are listed in the order of ``prompt_updates``.
    """
    found: list[PromptTargetMatch] = []

    for update in prompt_updates:
        target = update.target

        if isinstance(target, PromptIndex):
            idx = target.get_match_index(update.tokenizer, prompt)
            if idx is not None:
                found.append(_PromptTargetIndexMatch(update, idx))
            continue

        # Escape so that the target text is matched literally
        pattern = re.escape(target.text)
        found.extend(
            _PromptTargetTextMatch(update, text_match)
            for text_match in re.finditer(pattern, prompt))

    return found


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
    """
    Flatten :code:`mm_matches`, make sure no two matches overlap, and return
    them sorted by start index so that earlier matches are applied first.

    Raises:
        ValueError: If two matches cover the same position of ``prompt``.
            Zero-width matches cover no position and never conflict.
    """
    all_matches = [
        match for modality_matches in mm_matches.values()
        for match in modality_matches
    ]

    owner_by_pos: list[Optional[PromptTargetMatch]] = [None] * len(prompt)

    for match in all_matches:
        for pos in range(match.start_idx, match.end_idx):
            prev = owner_by_pos[pos]
            if prev is not None:
                raise ValueError("Found overlapping matches "
                                 f"({prev} and {match}) "
                                 f"at index={pos} of prompt={prompt}")

            owner_by_pos[pos] = match

    return sorted(all_matches, key=lambda m: m.start_idx)
def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    Returns the pieces of the updated prompt in order; concatenating them
    gives the final prompt. Each modality consumes its item indices
    (``0 .. mm_item_counts[modality] - 1``) in match order:

    - an insertion keeps its target and consumes all remaining items;
    - a replacement of a non-empty target consumes one item;
    - a replacement of a zero-width target consumes all remaining items.

    Matches of a modality whose items are already exhausted are left as-is.
    """
    pieces = list[Union[str, list[int]]]()
    consumed_until = 0
    next_item_idx = defaultdict[str, int](lambda: 0)
    is_text = isinstance(prompt, str)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        first_item = next_item_idx[modality]
        num_items = mm_item_counts.get(modality, 0)
        if first_item >= num_items:
            continue

        start_idx, end_idx = match.start_idx, match.end_idx
        origin = match._origin
        mode = origin.mode

        if mode == UpdateMode.INSERT:
            # Keep the target; the new content follows it.
            pieces.append(prompt[consumed_until:end_idx])
            budget = num_items
        elif mode == UpdateMode.REPLACE:
            # Drop the target; the new content takes its place.
            pieces.append(prompt[consumed_until:start_idx])
            budget = num_items if start_idx == end_idx else 1
        else:
            assert_never(mode)

        stop_item = min(first_item + budget, num_items)

        for item_idx in range(first_item, stop_item):
            full = origin.get_content(item_idx).full
            pieces.append(full.text if is_text else full.token_ids)

        consumed_until = end_idx
        next_item_idx[modality] += stop_item - first_item

    pieces.append(prompt[consumed_until:])

    return cast(list[_S], pieces)
def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When there are no matches, ``prompt`` itself (not a copy) is returned.
    """
    if mm_matches:
        pieces = _apply_matches(prompt, mm_matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt
def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When there are no matches, ``prompt`` is returned unchanged.
    """
    if mm_matches:
        pieces = _apply_matches(prompt, mm_matches, mm_item_counts)
        return "".join(pieces)

    return prompt
def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    The prompt is scanned left to right. At each position, the content of
    the next unmatched item of every modality is tried in order. Matches
    are exclusive even when multiple modalities share the same placeholder
    tokens; in that case, the modality that appears earlier in
    `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    def match_at(pos: int) -> Optional[PlaceholderFeaturesInfo]:
        # Return the first (modality, update) whose content starts at `pos`.
        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = item_idx_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue  # Every item of this modality is already found

            for update in modality_updates:
                content = update.get_content(item_idx)
                tokens = content.full.token_ids
                end = pos + len(tokens)

                if not tokens or end > prompt_len:
                    continue
                if prompt[pos:end] != tokens:
                    continue

                embed_fn = content.is_embed
                embed_mask = (None if embed_fn is None else
                              embed_fn(content.full))

                return PlaceholderFeaturesInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=pos,
                    tokens=tokens,
                    is_embed=embed_mask,
                )

        return None

    pos = 0
    while pos < prompt_len:
        info = match_at(pos)
        if info is None:
            pos += 1
            continue

        yield info

        # Exclude overlapping matches
        pos += info.length
        item_idx_by_modality[info.modality] += 1
def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the placeholder tokens of every multi-modal item in ``prompt``
    and group them by modality.
    """
    placeholders = _iter_placeholders(mm_prompt_updates, prompt,
                                      mm_item_counts)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)
# Value type stored in a processing cache. The bound is a string forward
# reference, so it is not evaluated when this module is imported.
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
class ProcessingCache:
    """
    Size-bounded LRU cache of processed multi-modal items.

    Entries are keyed by a hash of the model ID, the modality, the raw input
    item, and the HF processor kwargs (see :meth:`get`). When
    ``debug_cache_hit_ratio_steps`` is set, hit statistics are collected and
    logged periodically.
    """

    @staticmethod
    def get_lru_cache(
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """
        Build an :class:`LRUCache` holding up to ``capacity_gb`` GiB, where
        each entry is weighed by its estimated size in bytes.

        ``value_type`` is not used in the body; it only parameterizes the
        return type. If ``debug`` is true, every computed size is logged.
        """

        def get_leaf_size(leaf: object) -> int:
            # Neither MultiModalKwargs nor MultiModalKwargsItem subclasses
            # dict, so unwrap them into plain mappings before measuring.
            if isinstance(leaf, MultiModalKwargs):
                return get_item_size(leaf.data)

            if isinstance(leaf, MultiModalKwargsItem):
                return get_item_size({k: v.data for k, v in leaf.items()})

            # sys.getsizeof doesn't work for tensors
            if isinstance(leaf, torch.Tensor):
                return leaf.nbytes

            return sys.getsizeof(leaf)

        def get_item_size(
            value: Union[MultiModalKwargs, MultiModalKwargsItem,
                         Mapping[str, NestedTensors]]
        ) -> int:
            leaf_sizes = json_map_leaves(get_leaf_size, value)
            size = json_reduce_leaves(lambda a, b: a + b, leaf_sizes)

            if debug:
                logger.debug("Calculated size of %s to be %.2f GiB",
                             type(value), size / GiB_bytes)

            return size

        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)

    def __init__(
        self,
        capacity_gb: float,
        *,
        debug_cache_hit_ratio_steps: Optional[int] = None,
    ) -> None:
        super().__init__()

        # Hit statistics are only collected when this is truthy.
        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
        self.debug_cache_hits = 0
        self.debug_cache_total = 0

        self._cache = self.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
            debug=bool(debug_cache_hit_ratio_steps),
        )

    def _maybe_log_cache_stats(self) -> None:
        steps = self.debug_cache_hit_ratio_steps
        total = self.debug_cache_total

        # Log only when debugging is on, once every `steps` lookups.
        if not steps or total <= 0 or total % steps != 0:
            return

        logger.debug("ProcessingCache: hit_ratio = %.2f",
                     self.debug_cache_hits / total)
        logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
                     self._cache.currsize / GiB_bytes,
                     self._cache.maxsize / GiB_bytes)

    @staticmethod
    def _make_cache_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ):
        # The modality name is used as the keyword for the input item.
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)

        if self.debug_cache_hit_ratio_steps:
            self.debug_cache_hits += int(cache_key in self._cache)
            self.debug_cache_total += 1

        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)
        self._cache[cache_key] = output_kwargs
class BaseProcessingInfo:
1001
    """Base class to provide the information necessary for data processing."""
1002

1003
1004
    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()
1005

1006
1007
1008
1009
1010
1011
1012
        self.ctx = ctx

    @property
    def model_id(self) -> str:
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
1013
1014
        return self.ctx.tokenizer

1015
    def get_hf_config(self) -> PretrainedConfig:
1016
1017
        return self.ctx.get_hf_config()

1018
    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
1019
1020
1021
1022
1023
1024
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError


# Type of the processing-info object that a multi-modal processor is
# parameterized by; must be a subclass of `BaseProcessingInfo`.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseMultiModalProcessor(ABC, Generic[_I]):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """

    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Args:
            info: Provides the model/processor information (tokenizer,
                HF processor, context) used during processing.
            dummy_inputs: Builds dummy processor inputs; used by
                :meth:`_apply_hf_processor_mm_only` to generate text that
                accompanies the multi-modal data.
            cache: Optional cache of processed multi-modal items. When
                :code:`None`, every item is processed by the HF processor.
            enable_sanity_checks: Whether to run the internal consistency
                assertion in :meth:`_cached_apply_hf_processor`.
        """
        # Backward compatibility: a subclass that still defines the old
        # method name has it aliased to the new name.
        if get_repls := getattr(self, "_get_prompt_replacements", None):
            logger.warning_once("`_get_prompt_replacements` has been renamed "
                                "to `_get_prompt_updates`. The old name will "
                                "be removed in an upcoming release.")
            self._get_prompt_updates = get_repls  # type: ignore[method-assign]

        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Shorthand for :meth:`apply` with :code:`return_mm_hashes=False`."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.
        """
        return MultiModalDataParser()

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If the number of items of some modality exceeds the
                per-prompt limit configured for that modality.
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)
        mm_config = self.info.ctx.get_mm_config()

        for modality, items in mm_items.items():
            limit = mm_config.get_limit_per_prompt(modality)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Given the HF-processed data, output the metadata of each field."""
        raise NotImplementedError

    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF processor does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.
        """
        raise NotImplementedError

    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholder tokens of each multi-modal item in
        :code:`new_token_ids`, grouped by modality.

        Delegates to :func:`find_mm_placeholders`; subclasses may override.
        """
        return find_mm_placeholders(mm_prompt_updates, new_token_ids,
                                    mm_item_counts)

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the parsed items into data to be passed to the HF processor
        and data that bypasses it (passthrough), merged across modalities.

        Returns:
            A tuple :code:`(processor_data, passthrough_data)`.
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Return whether the HF processor applies prompt updates.

        For most HF processors, this should be :code:`True` when multi-modal
        data items are passed, but :code:`False` when multi-modal embeddings
        are passed.
        """
        return not any(
            isinstance(items, (EmbeddingItems, DictEmbeddingItems))
            for items in mm_items.values())

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        In addition, return whether prompt updates have been applied.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Passthrough data did not go through the HF processor; merge it in
        # so that it is included in the multi-modal kwargs below.
        processed_data.update(passthrough_data)

        # A single prompt was processed, so `input_ids` holds one row.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_update_applied

    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The HF processor is called with no multi-modal items and no extra
        processor kwargs, and only the resulting token IDs are kept.
        """
        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`BaseDummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_counts,
        )

        # Only the multi-modal kwargs are kept; the token IDs of the dummy
        # text are discarded.
        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        In addition, return whether prompt updates have been applied
        (for most HF processors, this should be :code:`True`).

        Note:
            If :code:`enable_hf_prompt_update=True` and the prompt is text,
            we use HF processor to perform prompt updates if available;
            HF processor requires that the prompt corresponds to multi-modal
            items. Otherwise, the prompt and the multi-modal items are
            processed separately and no prompt updates are applied.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, False

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.
        """
        cache = self.cache
        model_id = self.info.model_id

        # Without a cache, or when some items carry passthrough data,
        # process everything directly without consulting the cache.
        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        # Per modality, the cached kwargs item for each data item
        # (None for a cache miss).
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we can't apply prompt updates until the new multimodal
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            mm_missing_kwargs,
            is_update_applied,
        ) = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_update=False,
        )

        # Next unconsumed index into the freshly processed items, per modality.
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        # Merge cached and freshly processed items back into the original
        # order, storing each fresh item in the cache as it is consumed.
        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly processed item must have been consumed.
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs, is_update_applied

    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
    ) -> dict[str, Sequence[BoundPromptUpdate]]:
        """Bind each update to the tokenizer and group them by modality."""
        tokenizer = self.info.get_tokenizer()

        it = (update.bind(tokenizer) for update in prompt_updates)
        return dict(full_groupby_modality(it))

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """Apply token-level matches; delegates to :func:`apply_token_matches`."""
        return apply_token_matches(prompt, mm_matches, mm_item_counts)

    def _apply_text_matches(
        self,
        prompt: str,
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> str:
        """Apply text-level matches; delegates to :func:`apply_text_matches`."""
        return apply_text_matches(prompt, mm_matches, mm_item_counts)

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to :code:`token_ids`.

        Token-level matching is used when every modality has at least as
        many matches as items; otherwise the updates are applied to the
        decoded text, which is then re-encoded.

        Returns:
            A tuple :code:`(token_ids, text, placeholders)` with the updated
            token IDs, their text, and the placeholders found per modality.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in mm_token_matches.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = self._apply_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
            matched_updates = {
                modality: [match._origin for match in token_matches]
                for modality, token_matches in mm_token_matches.items()
            }
        else:
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = self._apply_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            matched_updates = {
                modality: [match._origin for match in token_matches]
                for modality, token_matches in mm_text_matches.items()
            }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that the processed kwargs contain one item per input data item.

        Raises:
            RuntimeError: If the item count of some modality differs.
        """
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that one placeholder was found per input data item.

        Raises:
            RuntimeError: If the placeholder count of some modality differs.
        """
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt updates "
                    f"corresponding to {item_count} {modality} items, but "
                    f"instead found {len(placeholders)} prompt updates! "
                    "Either the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_updates`).")

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes to be returned (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.

        if return_mm_hashes:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        (
            prompt_ids,
            mm_kwargs,
            is_update_applied,
        ) = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_updates = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if is_update_applied:
            # The HF processor already inserted the placeholders; only
            # locate them and recover the prompt text from the token IDs.
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_updates,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt,
                mm_placeholders,
            ) = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    The base processing pipeline (:meth:`BaseMultiModalProcessor.apply`)
    runs on the encoder prompt; the decoder prompt is built separately by
    :meth:`create_decoder_prompt`.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """
        Flag that defaults to :code:`False`; subclasses may override it.

        NOTE(review): presumably consulted when building dummy encoder
        prompts for profiling — confirm against the profiling code.
        """
        return False

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Create input prompt for the decoder."""
        return prompt

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.
        The main processing steps are modified to fit encoder-decoder model:
        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Create the decoder prompt via :meth:`create_decoder_prompt`,
           tokenizing it if it is text or decoding it if it is token IDs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            return_mm_hashes,
        )

        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        # Provide both the text and the token IDs of the decoder prompt.
        if isinstance(decoder_prompt, str):
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               decoder_prompt,
                                               add_special_tokens=False)
        else:
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

        # The encoder outputs are copied under `encoder_*` keys, then the
        # top-level prompt fields are overwritten with the decoder prompt.
        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        mm_inputs.update({
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids
        })
        return mm_inputs