# processing.py
# SPDX-License-Identifier: Apache-2.0
2
import re
3
import sys
4
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
7
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
8
from dataclasses import dataclass, field
9
from enum import Enum
10
from functools import lru_cache
11
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
12
                    TypeVar, Union, cast)
13

14
import torch
15
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
16
from typing_extensions import assert_never
17

18
from vllm.inputs import InputProcessingContext
19
from vllm.jsontree import json_map_leaves, json_reduce_leaves
20
from vllm.logger import init_logger
21
22
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
23
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
24

25
from .hasher import MultiModalHasher
26
27
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
28
                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
29
30
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
31
32
33

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
34

35
# Module-level logger, created through the `init_logger` helper.
logger = init_logger(__name__)
36
37

# Constrained type variable: a prompt is handled either entirely as text
# (`str`) or entirely as token IDs (`list[int]`).
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
41

42

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@dataclass
class PromptIndex:
    """
    A prompt target that resolves to a single position in the prompt
    rather than to a token sequence or text to search for.
    """

    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]
    """
    Given the tokenizer and the prompt, return the index to update,
    or `None` if this target does not match the prompt.
    """


class PromptIndexTargets:
    """Factory functions for commonly used :class:`PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: 0)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        The prefix is converted to the same representation as the prompt
        (text or token IDs) before comparing. If the prompt does not start
        with the prefix, the target does not match.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prefix = seq

            if isinstance(prompt, str):
                if not isinstance(prefix, str):
                    # Make both `str`
                    prefix = decode_tokens(tokenizer, prefix)
            else:
                if isinstance(prefix, str):
                    # Make both `list[int]`
                    prefix = encode_tokens(tokenizer,
                                           prefix,
                                           add_special_tokens=False)

            # The match position is right after the prefix, and is only
            # valid when the prompt actually begins with that prefix.
            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: len(prompt))


PromptTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update.
"""


104
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
    """
    Given :attr:`full`, return a boolean mask of shape `(len(full),)`
    indicating which positions of `full` to assign embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    :class:`SupportsMultiModal.get_multimodal_embeddings`.
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap `seq` so that every position of it is assigned embeddings."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq` so that only the positions whose token ID appears in the
        tokenization of `embed_text` are assigned embeddings.
        """

        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
            # NOTE(review): `embed_text` is encoded with the defaults of
            # `encode_tokens`, while `_BoundPromptSequence.token_ids` passes
            # `add_special_tokens=False` -- confirm that the tokenizer
            # default cannot add special tokens to the selection here.
            embed_token_ids = encode_tokens(full.tokenizer, embed_text)

            return torch.isin(
                torch.tensor(full.token_ids),
                torch.tensor(embed_token_ids),
            )

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq` so that only the positions whose token ID equals
        `embed_token_id` are assigned embeddings.
        """
        return PromptUpdateDetails(
            full=seq,
            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
        )
151
152


153
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that are part of the update.

If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""
160

161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """The way a :class:`PromptUpdate` modifies the prompt."""

    # Add the content while keeping the target in the prompt.
    INSERT = "insert"

    # Substitute the content for the target.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.

    Subclasses provide :attr:`content` (what to put into the prompt) and
    :attr:`mode` (whether it is inserted or replaces the target).
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """Return this update wrapped together with `tokenizer`."""
        return BoundPromptUpdate(
            _origin=self,
            tokenizer=tokenizer,
        )

207

208
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

        For each image, insert a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder after the ``<s>`` token:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="<s>",
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the start of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens after a prefix ``Images:``:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix("Images:"),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the end of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.end(),
                insertion="<image>" * image_feature_size,
            )
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to insert right after :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptUpdateDetails.select_text(
                    "".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    embed_text="<image>",
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails.select_token_id(
                    ([image_bos_id] + [image_token_id] * image_feature_size
                     + [image_eos_id]),
                    embed_token_id=image_token_id,
                ),
            )
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to replace :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        return UpdateMode.REPLACE
345
346


347
348
349
350
351
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Results are kept in an LRU cache of up to 2048 entries keyed by all
    arguments, so `tokenizer` must be hashable. The cached list object is
    returned as-is on a hit.
    """
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)
357
358


359
360
361
362
363
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    `token_ids` is taken as a tuple so that it is hashable for the LRU cache
    (up to 2048 entries); it is converted back to a list before decoding.
    """
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)
369
370
371
372
373


class _HasModalityAttr(Protocol):
    """Structural type for objects with a plain `modality` attribute."""

    modality: str

374

375
class _HasModalityProp(Protocol):
    """Structural type for objects exposing `modality` as a property."""

    @property
    def modality(self) -> str:
        ...


# Type variable for any object exposing a `modality` string, either as a
# plain attribute or as a read-only property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group `values` by their `modality` using :func:`full_groupby`."""

    def modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.

    Only one representation needs to be given; the other one is computed
    on first access and then stored on the instance.
    """
    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Build an instance from either text or a list of token IDs."""
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs if necessary."""
        if self._text is None:
            assert self._token_ids is not None
            # The token IDs are passed as a tuple to be usable as cache key
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token ID form, encoded from the text if necessary."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer,
                                             self._text,
                                             add_special_tokens=False)

        return self._token_ids


436
@dataclass
class _BoundPromptContent:
    """
    The content of a prompt update after binding to a tokenizer:
    the full sequence plus the optional embedding mask function.
    """

    full: _BoundPromptSequence
    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
440
441


442
@dataclass
class BoundPromptUpdate:
    """
    A :class:`PromptUpdate` bound to a tokenizer to automatically convert
    :attr:`target` and the result of :meth:`get_content` between
    token sequence and text representations.
    """
    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Bound content per item index; only used for callable content
        self._content_cache = dict[int, _BoundPromptContent]()

    @property
    def modality(self) -> str:
        return self._origin.modality

    @property
    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """The token sequence (or text) to update."""
        target = self._origin.target

        if isinstance(target, PromptIndex):
            return target

        return _BoundPromptSequence.from_seq(self.tokenizer, target)

    @property
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Given the index of the processed item within :attr:`modality`,
        output the token sequence (or text) to update.

        When the content is a function of the item index, the result is
        cached per index; constant content is bound again on every call.
        """
        content = self.content
        if callable(content):
            cache_key = item_idx
            if cache_key in self._content_cache:
                return self._content_cache[cache_key]

            content = content(item_idx)
        else:
            cache_key = None

        # Normalize a bare sequence into `PromptUpdateDetails`
        if not isinstance(content, PromptUpdateDetails):
            content = PromptUpdateDetails.from_seq(content)

        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
                                                   content.full)
        bound_content = _BoundPromptContent(full=bound_full,
                                            is_embed=content.is_embed)

        if cache_key is not None:
            self._content_cache[cache_key] = bound_content

        return bound_content
506
507


508
509
510
class _TokenMatch(NamedTuple):
    """Half-open span ``[start_idx, end_idx)`` of a match in a token list."""

    start_idx: int
    end_idx: int
511
512


513
514
515
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are found by scanning from left to right and never overlap.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
539
540


541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Return a copy of :code:`token_ids` in which every occurrence of
    :code:`match_ids` has been substituted by :code:`new_ids`.

    Note that empty matches are ignored.
    """
    segments: list[list[int]] = []
    cursor = 0

    for found in iter_token_matches(token_ids, match_ids):
        # Keep the untouched tokens before the match, then the substitute
        segments.append(token_ids[cursor:found.start_idx])
        segments.append(new_ids)
        cursor = found.end_idx

    # Whatever remains after the last match is kept unchanged
    segments.append(token_ids[cursor:])

    return flatten_2d_lists(segments)


568
@dataclass(repr=False)
class PromptTargetMatch(ABC):
    """
    An occurrence of the target of a prompt update inside a prompt.

    Subclasses define the span ``[start_idx, end_idx)`` of the occurrence.
    """

    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        # Show only the modality and span rather than the whole update
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


591
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
    """A zero-width match located at a single index of the prompt."""

    match_idx: int

    @property
    def start_idx(self) -> int:
        return self.match_idx

    @property
    def end_idx(self) -> int:
        return self.match_idx


604
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
    """A match found in a tokenized prompt; indices refer to tokens."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
    """A match found in a text prompt; indices refer to characters."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

629

630
@dataclass
class PlaceholderFeaturesInfo:
    """A run of placeholder tokens in the prompt for one multi-modal item."""

    modality: str
    """The modality of the item."""

    item_idx: int
    """The index of the item within :attr:`modality`."""

    start_idx: int
    """The index in the prompt where the placeholder tokens start."""

    tokens: list[int]
    """The placeholder tokens."""

    is_embed: Optional[torch.Tensor]
    """
    The mask passed on to :class:`PlaceholderRange`,
    or `None` if no mask is specified.
    """

    @property
    def length(self) -> int:
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` covering :attr:`tokens`."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
            is_embed=self.is_embed,
        )
650
651
652
653


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of :code:`prompt_updates` found in :code:`prompt`.

    The matches are grouped by update, in the order of
    :code:`prompt_updates`; they are not sorted by position.
    """

    def get_matches(update: BoundPromptUpdate):
        target = update.target

        # An index target yields at most one zero-width match
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        return [
            _PromptTargetTokenMatch(update, match)
            for match in iter_token_matches(prompt, target.token_ids)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of :code:`prompt_updates` found in :code:`prompt`.

    The matches are grouped by update, in the order of
    :code:`prompt_updates`; they are not sorted by position.
    """

    def get_matches(update: BoundPromptUpdate):
        target = update.target

        # An index target yields at most one zero-width match
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        # The target text is escaped so that it is matched literally
        return [
            _PromptTargetTextMatch(update, match)
            for match in re.finditer(re.escape(target.text), prompt)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same position of the prompt.
    """
    matches = [m for matches in mm_matches.values() for m in matches]

    # For each prompt position, the match that has claimed it (if any)
    seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)

    for match in matches:
        # Zero-width matches cover no positions and thus never overlap
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


728
def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    The result is returned as a list of segments (in the same representation
    as the prompt) which the caller is expected to join together.

    Matches are skipped once every item of their modality has been consumed.
    """
    out_seqs = list[Union[str, list[int]]]()
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        item_start_idx = next_idx_by_modality[modality]
        max_item_count = mm_item_counts.get(modality, 0)
        if item_start_idx >= max_item_count:
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx
        origin = match._origin
        mode = origin.mode

        if mode == UpdateMode.INSERT:
            # Keep the target itself; the content of all remaining items
            # is placed right after it
            out_seqs.append(prompt[prev_end_idx:end_idx])
            num_inserts = max_item_count
        elif mode == UpdateMode.REPLACE:
            # Drop the target; a zero-width match takes all remaining
            # items, otherwise one item is consumed per match
            out_seqs.append(prompt[prev_end_idx:start_idx])
            num_inserts = max_item_count if start_idx == end_idx else 1
        else:
            assert_never(mode)

        item_end_idx = min(item_start_idx + num_inserts, max_item_count)

        for item_idx in range(item_start_idx, item_end_idx):
            content = origin.get_content(item_idx)
            insert_seq = (content.full.text if isinstance(prompt, str) else
                          content.full.token_ids)

            out_seqs.append(insert_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] += item_end_idx - item_start_idx

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs)
775
776


777
def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, the input prompt is returned as-is
    (the same list object, not a copy).
    """
    if not mm_matches:
        return prompt

    token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)
789
790


791
def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, the input prompt is returned unchanged.
    """
    if not mm_matches:
        return prompt

    texts = _apply_matches(prompt, mm_matches, mm_item_counts)

    return "".join(texts)
803
804


805
def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_updates in mm_prompt_updates.items():
            # Only look for the next item of this modality that has not
            # been found yet
            item_idx = item_idx_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for update_info in modality_updates:
                content = update_info.get_content(item_idx)
                content_tokens_full = content.full.token_ids
                content_len_full = len(content_tokens_full)
                end_idx_full = start_idx + content_len_full

                if content_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] == content_tokens_full:
                    # Evaluate the mask function (if any) on the content
                    content_is_embed = content.is_embed
                    if content_is_embed is not None:
                        content_is_embed = content_is_embed(content.full)

                    yield PlaceholderFeaturesInfo(
                        modality=modality,
                        item_idx=item_idx,
                        start_idx=start_idx,
                        tokens=content_tokens_full,
                        is_embed=content_is_embed,
                    )

                    # Exclude overlapping matches
                    start_idx = end_idx_full
                    item_idx_by_modality[modality] += 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1
864
865


866
def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Find the placeholder tokens of each multi-modal item in :code:`prompt`
    and return them grouped by modality.
    """
    it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)

    return dict(full_groupby_modality(it))


875
876
877
# Type of the values stored in the cache built by
# `ProcessingCache.get_lru_cache`.
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")


878
879
class ProcessingCache:
    """
    A size-bounded LRU cache of processed multi-modal items, keyed by a hash
    of the model ID, the modality, the input item and the processor kwargs.
    """

    @staticmethod
    def get_lru_cache(
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """
        Create an LRU cache with a capacity of `capacity_gb` GiB, where the
        size of each value is the sum of the sizes of its leaves.

        `value_type` is not used at runtime. If `debug` is set, the
        calculated size of each value is logged at debug level.
        """

        def get_leaf_size(leaf: object) -> int:
            # MultiModalKwargs is not a subclass of dict
            if isinstance(leaf, MultiModalKwargs):
                return get_item_size(leaf.data)

            # MultiModalKwargsItem is not a subclass of dict
            if isinstance(leaf, MultiModalKwargsItem):
                leaf_data = {k: v.data for k, v in leaf.items()}
                return get_item_size(leaf_data)

            # sys.getsizeof doesn't work for tensors
            if isinstance(leaf, torch.Tensor):
                return leaf.nbytes

            return sys.getsizeof(leaf)

        def get_item_size(
            value: Union[MultiModalKwargs, MultiModalKwargsItem,
                         Mapping[str, NestedTensors]]
        ) -> int:
            size = json_reduce_leaves(
                lambda a, b: a + b,
                json_map_leaves(get_leaf_size, value),
            )

            if debug:
                logger.debug("Calculated size of %s to be %.2f GiB",
                             type(value), size / GiB_bytes)

            return size

        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)

    def __init__(
        self,
        capacity_gb: float,
        *,
        debug_cache_hit_ratio_steps: Optional[int] = None,
    ) -> None:
        """
        Args:
            capacity_gb: The capacity of the cache in GiB.
            debug_cache_hit_ratio_steps: If set to a non-zero value, track
                cache hits and log the statistics every this many lookups.
        """
        super().__init__()

        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
        self.debug_cache_hits = 0
        self.debug_cache_total = 0

        self._cache = self.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
            debug=bool(debug_cache_hit_ratio_steps),
        )

    def _maybe_log_cache_stats(self) -> None:
        """Log the hit ratio and size of the cache once every N lookups."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        total = self.debug_cache_total
        if total > 0 and total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         self.debug_cache_hits / total)
            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
                         self._cache.currsize / GiB_bytes,
                         self._cache.maxsize / GiB_bytes)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)

        # Hit statistics are only tracked when debugging is enabled
        if self.debug_cache_hit_ratio_steps:
            if cache_key in self._cache:
                self.debug_cache_hits += 1

            self.debug_cache_total += 1

        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)
        self._cache[cache_key] = output_kwargs
998
999


1000
class BaseProcessingInfo:
    """
    Base class to provide the information necessary for data processing.

    Every accessor reads from the :class:`InputProcessingContext` that is
    stored on the instance as ``ctx``.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        """Store ``ctx`` for use by the accessors of this class."""
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` attribute of the context's model config."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the context."""
        tokenizer = self.ctx.tokenizer
        return tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config as reported by the context."""
        hf_config = self.ctx.get_hf_config()
        return hf_config

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HF processor, forwarding ``kwargs`` to the context.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        hf_processor = self.ctx.get_hf_processor(**kwargs)
        return hf_processor

    # NOTE(review): this class has no ABC base, so the decorator below
    # documents intent but does not by itself prevent instantiation.
    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError


# Type variable for the `BaseProcessingInfo` subclass that a
# multi-modal processor is parameterized by.
_I = TypeVar("_I", bound=BaseProcessingInfo)
1039

1040
1041

class BaseMultiModalProcessor(ABC, Generic[_I]):
1042
    """
1043
    Abstract base class to process multi-modal inputs to be used in vLLM.
1044
1045

    Not to be confused with :class:`transformers.ProcessorMixin`.
1046
1047
    """

1048
    def __init__(self,
1049
1050
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
1051
1052
1053
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
1054
1055
        super().__init__()

1056
1057
        self.info = info
        self.dummy_inputs = dummy_inputs
1058
1059
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks
1060

1061
1062
        self.data_parser = self._get_data_parser()

1063
    def __call__(
1064
        self,
1065
1066
        prompt: str,
        mm_data: MultiModalDataDict,
1067
        hf_processor_mm_kwargs: Mapping[str, object],
1068
    ) -> MultiModalInputs:
1069
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
1070

1071
1072
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that preprocesses multi-modal data items
        before they are passed to :meth:`_get_hf_mm_data`.

        The default is a plain :class:`MultiModalDataParser`.
        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers
        and returning it from an override of this method.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
1082
1083
1084
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
1085
1086
1087
1088
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.
        """
1089
        mm_items = self.data_parser.parse_mm_data(mm_data)
1090
        mm_config = self.info.ctx.get_mm_config()
1091
1092

        for modality, items in mm_items.items():
1093
            limit = mm_config.get_limit_per_prompt(modality)
1094
1095
1096
1097
1098
1099
1100
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items
1101

1102
1103
1104
1105
1106
1107
1108
1109
1110
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        Args:
            hf_inputs: The data produced by the HF processor.
            hf_processor_mm_kwargs: The keyword options associated with
                the HF processor call.

        Returns:
            A mapping of string keys (presumably the field names in
            ``hf_inputs``) to their :class:`MultiModalFieldConfig`.
        """
        raise NotImplementedError

1111
    @abstractmethod
1112
    def _get_prompt_updates(
1113
        self,
1114
        mm_items: MultiModalDataItems,
1115
1116
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
1117
    ) -> Sequence[PromptUpdate]:
1118
1119
        """
        Given the original multi-modal items for this modality
1120
        and HF-processed data, output the updates to perform.
1121

1122
1123
1124
1125
1126
1127
1128
1129
        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct  :class:`~vllm-multimodal.input.PlaceholderRange`
        for each multi-modal item.
1130
1131
        """
        raise NotImplementedError
1132

1133
    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholders in ``new_token_ids``.

        Delegates to the module-level :func:`find_mm_placeholders`
        with the arguments unchanged.
        """
        placeholders = find_mm_placeholders(
            mm_prompt_updates,
            new_token_ids,
            mm_item_counts,
        )
        return placeholders
1141

1142
    def _get_hf_mm_data(
1143
        self,
1144
        mm_items: MultiModalDataItems,
1145
1146
1147
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()
1148

1149
1150
1151
        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())
1152

1153
1154
        return processor_data, passthrough_data

1155
1156
1157
    def _call_hf_processor(
        self,
        prompt: str,
1158
1159
1160
1161
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
1162
    ) -> BatchFeature:
1163
1164
1165
1166
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
1167
1168
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
1169
1170
            dict(text=prompt, **mm_data),
            mm_kwargs,
1171
1172
        )

1173
    def _hf_processor_applies_updates(
1174
1175
1176
1177
1178
1179
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
1180
        Return whether the HF processor applies prompt updates.
1181
1182
1183
1184
1185
1186
1187
1188
1189

        For most HF processors, this should be :code:`True` when multi-modal
        data items are passed, but :code:`False` when multi-modal embeddings
        are passed.
        """
        return not any(
            isinstance(items, (EmbeddingItems, DictEmbeddingItems))
            for items in mm_items.values())

1190
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            The prompt token IDs, the multi-modal kwargs built from the
            remaining processor outputs, and whether prompt updates
            have been applied.
        """
        hf_mm_data, passthrough_data = self._get_hf_mm_data(mm_items)

        hf_outputs = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=hf_mm_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Passthrough data skips the HF processor and is merged in afterwards.
        hf_outputs.update(passthrough_data)

        # The unpacking requires `input_ids` to contain exactly one row.
        [prompt_ids] = hf_outputs.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(
            hf_outputs,
            hf_processor_mm_kwargs,
        )
        mm_kwargs = MultiModalKwargs.from_hf_inputs(hf_outputs, fields_config)

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_update_applied
1225

1226
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        This reuses :meth:`_apply_hf_processor_text_mm` with an empty
        :class:`MultiModalDataItems` and no HF processor kwargs,
        and keeps only the resulting prompt token IDs.
        """
        no_items = MultiModalDataItems({})
        outputs = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=no_items,
            hf_processor_mm_kwargs={},
        )

        prompt_ids, _mm_kwargs, _is_update_applied = outputs
        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.

        Returns:
            By default, ``prompt_tokens`` itself, unchanged.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`DummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

1271
        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
1272
            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
1285
        enable_hf_prompt_update: bool,
1286
    ) -> tuple[list[int], MultiModalKwargs, bool]:
1287
1288
1289
        """
        Apply the HF processor on the prompt text and multi-modal data.

1290
        In addition, return whether prompt updates have been applied
1291
1292
        (for most HF processors, this should be :code:`True`).

1293
        Note:
1294
1295
            If :code:`enable_hf_prompt_update=False`, we use HF processor
            to perform prompt updates if available; HF processor requires
1296
            that the prompt corresponds to multi-modal items.
1297
1298
        """
        if isinstance(prompt, str):
1299
            if enable_hf_prompt_update:
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

1310
        mm_kwargs = self._apply_hf_processor_mm_only(
1311
            mm_items=mm_items,
1312
1313
1314
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

1315
        return prompt_ids, mm_kwargs, False
1316
1317
1318

    def _cached_apply_hf_processor(
        self,
1319
        prompt: Union[str, list[int]],
1320
1321
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
1322
    ) -> tuple[list[int], MultiModalKwargs, bool]:
1323
1324
1325
1326
1327
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.
        """
        cache = self.cache
1328
        model_id = self.info.model_id
1329

1330
1331
        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
1332
1333
            return self._apply_hf_processor_main(
                prompt=prompt,
1334
1335
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
1336
                enable_hf_prompt_update=True,
1337
1338
            )

1339
        mm_maybe_cached_kw_items = {
1340
1341
1342
1343
1344
1345
1346
1347
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
1348
1349
1350
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
1351
1352
1353
1354
1355
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
1356
        mm_missing_data_items = self._to_mm_items(mm_missing_data)
1357

1358
        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
1359
        # so we can't apply prompt updates until the new multimodal
1360
1361
1362
1363
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            mm_missing_kwargs,
1364
            is_update_applied,
1365
        ) = self._apply_hf_processor_main(
1366
1367
            prompt=prompt,
            mm_items=mm_missing_data_items,
1368
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
1369
            enable_hf_prompt_update=False,
1370
1371
1372
1373
1374
1375
1376
        )

        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

1377
1378
1379
1380
1381
        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
1382
1383
1384
1385
1386
1387
1388
1389
1390
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
1391
                        kw_item,
1392
1393
1394
1395
                    )

                    mm_missing_next_idx[modality] += 1

1396
                merged_kw_items.append(kw_item)
1397
1398

        if self.enable_sanity_checks:
1399
            mm_missing_counts = mm_missing_data_items.get_all_counts()
1400
1401
1402
1403
1404
1405
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

1406
        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
1407

1408
        return prompt_ids, mm_kwargs, is_update_applied
1409

1410
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
    ) -> dict[str, Sequence[BoundPromptUpdate]]:
        """
        Bind each prompt update to the tokenizer from ``self.info``,
        then group the bound updates via :func:`full_groupby_modality`.
        """
        tokenizer = self.info.get_tokenizer()

        grouped = full_groupby_modality(
            update.bind(tokenizer) for update in prompt_updates)
        return dict(grouped)
1418

1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """
        Apply the matched prompt updates to a token-based prompt.

        Delegates to the module-level :func:`apply_token_matches` with the
        arguments unchanged (presumably kept as a method so that
        subclasses can customize it).
        """
        return apply_token_matches(prompt, mm_matches, mm_item_counts)

    def _apply_text_matches(
        self,
        prompt: str,
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> str:
        """
        Apply the matched prompt updates to a text-based prompt.

        Delegates to the module-level :func:`apply_text_matches` with the
        arguments unchanged (presumably kept as a method so that
        subclasses can customize it).
        """
        return apply_text_matches(prompt, mm_matches, mm_item_counts)

1435
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to ``token_ids``.

        Token-level matching is used when it finds at least as many matches
        as there are items for every modality; otherwise the prompt is
        decoded, updated as text, and encoded back.

        Returns:
            The updated token IDs, the corresponding text, and the
            placeholders found in the updated token IDs.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        token_matches_suffice = all(
            len(mm_token_matches.get(modality, ())) >= item_count
            for modality, item_count in mm_item_counts.items())

        if token_matches_suffice:
            mm_matches = mm_token_matches
            token_ids = self._apply_token_matches(
                token_ids,
                mm_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
        else:
            text = decode_tokens(tokenizer, token_ids)

            mm_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = self._apply_text_matches(
                text,
                mm_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)

        # Only the updates that produced a match are searched for afterwards.
        matched_updates = {
            modality: [match._origin for match in matches]
            for modality, matches in mm_matches.items()
        }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders
1505

1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
1529
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
1530
        mm_item_counts: Mapping[str, int],
1531
    ) -> None:
1532
1533
1534
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

1535
            if len(placeholders) != item_count:
1536
                raise RuntimeError(
1537
                    f"Expected there to be {item_count} prompt updates "
1538
                    f"corresponding to {item_count} {modality} items, but "
1539
                    f"instead found {len(placeholders)} prompt updates! "
1540
                    "Either the prompt text has missing/incorrect tokens for "
1541
1542
1543
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
1544
                    "`_call_hf_processor` and `_get_prompt_updates`).")
1545

1546
1547
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        If ``return_mm_hashes`` is set, a hash is additionally computed for
        every multi-modal item; otherwise ``mm_hashes`` in the result
        is ``None``.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes to be returned (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        mm_hashes = None
        if return_mm_hashes:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }

        prompt_ids, mm_kwargs, is_update_applied = (
            self._cached_apply_hf_processor(
                prompt,
                mm_items,
                hf_processor_mm_kwargs,
            ))

        mm_prompt_updates = self._bind_and_group_updates(
            self._get_prompt_updates(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if not is_update_applied:
            # The updates still have to be applied by us.
            prompt_ids, prompt, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
        else:
            # The updates are already in the token IDs;
            # only the placeholders need to be located.
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_updates,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            prompt = decode_tokens(self.info.get_tokenizer(), prompt_ids)

        mm_placeholder_ranges = {
            modality: [info.to_range() for info in placeholder_infos]
            for modality, placeholder_infos in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor variant for encoder-decoder models.

    The HF processor is applied to an encoder prompt derived from the
    input, while the decoder prompt is tokenized or decoded separately.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """
        Flag that defaults to ``False``.

        NOTE(review): the consumer of this flag is outside this class;
        judging by the name, it controls padding of the dummy encoder
        prompt. Confirm against the caller.
        """
        return False

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the decoder.

        By default, this is the input ``prompt`` unchanged.
        """
        return prompt

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are modified to fit encoder-decoder model:

        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Copy the input prompt text as decoder prompt inputs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            return_mm_hashes,
        )

        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)

        # Whichever form the decoder prompt comes in, derive the other one.
        if not isinstance(decoder_prompt, str):
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt_ids)
        else:
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               decoder_prompt,
                                               add_special_tokens=False)

        # The encoder results are kept under the `encoder_*` keys, then the
        # generic prompt keys are overwritten with the decoder prompt.
        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        mm_inputs.update(
            prompt=decoder_prompt,
            prompt_token_ids=decoder_prompt_ids,
        )
        return mm_inputs