processing.py 54.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import re
3
import sys
4
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
7
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
8
from dataclasses import dataclass, field
9
from enum import Enum
10
from functools import lru_cache
11
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
12
                    TypeVar, Union, cast)
13

14
import torch
15
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
16
from typing_extensions import assert_never
17

18
from vllm.inputs import InputProcessingContext
19
from vllm.jsontree import json_map_leaves, json_reduce_leaves
20
from vllm.logger import init_logger
21
22
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
23
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
24

25
from .hasher import MultiModalHasher
26
27
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
28
                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
29
30
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
31
32
33

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
34

35
logger = init_logger(__name__)
36
37

_S = TypeVar("_S", str, list[int])
38
39
40

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
41

42

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@dataclass
class PromptIndex:
    """Resolves to an index in the prompt."""
    # Called with a tokenizer and the prompt (text or token IDs); returns the
    # index it resolves to, or `None` when the prompt does not match.
    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]


class PromptIndexTargets:
    """Factory methods that build commonly used :class:`PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: 0)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prefix = seq

            # Convert the prefix to the same representation as the prompt
            # so that the slice comparison below is meaningful.
            if isinstance(prompt, str):
                if not isinstance(prefix, str):
                    # Make both `str`
                    prefix = decode_tokens(tokenizer, prefix)
            else:
                if isinstance(prefix, str):
                    # Make both `list[int]`
                    prefix = encode_tokens(tokenizer,
                                           prefix,
                                           add_special_tokens=False)

            # The match index is the position right after the prefix;
            # `None` signals that the prompt does not start with it.
            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tok, prompt: len(prompt))


PromptTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update.
"""


104
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
    """
    Given :attr:`full`, return a boolean mask of shape `(len(full),)`
    indicating which positions of `full` to assign embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    :class:`SupportsMultiModal.get_multimodal_embeddings`.
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap `seq` without an embedding mask (`is_embed` is `None`)."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq` with a mask that selects every token of the full sequence
        whose ID appears in the tokenization of `embed_text`.
        """

        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
            embed_token_ids = encode_tokens(full.tokenizer, embed_text)

            return torch.isin(
                torch.tensor(full.token_ids),
                torch.tensor(embed_token_ids),
            )

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq` with a mask that selects every token of the full sequence
        whose ID equals `embed_token_id`.
        """
        return PromptUpdateDetails(
            full=seq,
            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
        )
151
152


153
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that are part of the update.

If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""
160

161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """How a :class:`PromptUpdate` modifies the prompt."""
    INSERT = "insert"  # reported by `PromptInsertion.mode`
    REPLACE = "replace"  # reported by `PromptReplacement.mode`


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """Return a :class:`BoundPromptUpdate` pairing this update with
        `tokenizer`."""
        return BoundPromptUpdate(
            _origin=self,
            tokenizer=tokenizer,
        )

207

208
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

        For each image, insert a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder after the ``<s>`` token:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="<s>",
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the start of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens after a prefix ``Images:``:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix("Images:"),
                insertion="<image>" * image_feature_size,
            )

        Insert these tokens at the end of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.end(),
                insertion="<image>" * image_feature_size,
            )
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to insert right after :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to insert (same object as :attr:`insertion`)."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptUpdateDetails.select_text(
                    "".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    embed_text="<image>",
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails.select_token_id(
                    [image_bos_id] + [image_token_id] * image_feature_size
                    + [image_eos_id],
                    embed_token_id=image_token_id,
                ),
            )
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to replace :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to substitute (same object as :attr:`replacement`)."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
345
346


347
348
349
350
351
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    The cache is keyed on all arguments, so `tokenizer` must be hashable.
    The returned list is shared between cache hits.
    """
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)
357
358


359
360
361
362
363
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    `token_ids` is a tuple (rather than a list) so that it can be part of
    the cache key; it is converted back to a list for the actual call.
    """
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)
369
370
371
372
373


class _HasModalityAttr(Protocol):
    """Structural type: exposes `modality` as a plain attribute."""
    modality: str

374

375
class _HasModalityProp(Protocol):
376

377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
    @property
    def modality(self) -> str:
        ...


_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group `values` by their `modality` using :func:`full_groupby`.

    Returns whatever :func:`full_groupby` produces for that key.
    """

    def modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)

    # At least one of these is set on construction (see `__post_init__`);
    # the other is filled in lazily on first access.
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Bind `seq` (text or token IDs) to `tokenizer`."""
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is None:
            assert self._token_ids is not None
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token IDs, encoded from the text (without special tokens)
        on first access."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer,
                                             self._text,
                                             add_special_tokens=False)

        return self._token_ids


436
@dataclass
class _BoundPromptContent:
    """The content of a prompt update after binding it to a tokenizer."""
    # The full content, convertible between text and token IDs.
    full: _BoundPromptSequence
    # Optional callable producing the embedding mask for `full`.
    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
440
441


442
@dataclass
class BoundPromptUpdate:
    """
    A :class:`PromptUpdate` bound to a tokenizer to automatically convert
    :attr:`target` and the result of :meth:`get_content` between
    token sequence and text representations.
    """
    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Per-item cache of bound content, used only when the content
        # is given as a callable of the item index.
        self._content_cache = dict[int, _BoundPromptContent]()

    @property
    def modality(self) -> str:
        """The modality of the underlying update."""
        return self._origin.modality

    @property
    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """The token sequence (or text) to update."""
        target = self._origin.target

        # Index targets are tokenizer-independent and returned as-is.
        if isinstance(target, PromptIndex):
            return target

        return _BoundPromptSequence.from_seq(self.tokenizer, target)

    @property
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Given the index of the processed item within :attr:`modality`,
        output the token sequence (or text) to update.
        """
        content = self.content
        if callable(content):
            cache_key = item_idx
            if cache_key in self._content_cache:
                return self._content_cache[cache_key]

            content = content(item_idx)
        else:
            # Static content is not cached per item.
            cache_key = None

        if not isinstance(content, PromptUpdateDetails):
            content = PromptUpdateDetails.from_seq(content)

        bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
                                                   content.full)
        bound_content = _BoundPromptContent(full=bound_full,
                                            is_embed=content.is_embed)

        if cache_key is not None:
            self._content_cache[cache_key] = bound_content

        return bound_content
506
507


508
509
510
class _TokenMatch(NamedTuple):
    """A matched span of tokens, as yielded by `iter_token_matches`."""
    start_idx: int  # index of the first matched token
    end_idx: int  # index one past the last matched token
511
512


513
514
515
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are scanned left to right and do not overlap.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
539
540


541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Return a copy of :code:`token_ids` in which every occurrence of
    :code:`match_ids` has been substituted with :code:`new_ids`.

    Note that empty matches are ignored.
    """
    pieces: list[list[int]] = []
    cursor = 0

    for found in iter_token_matches(token_ids, match_ids):
        # Untouched tokens before this occurrence, then the substitute.
        pieces.append(token_ids[cursor:found.start_idx])
        pieces.append(new_ids)
        cursor = found.end_idx

    # Whatever follows the last occurrence is kept unchanged.
    pieces.append(token_ids[cursor:])

    return flatten_2d_lists(pieces)


568
@dataclass(repr=False)
class PromptTargetMatch(ABC):
    """A location in the prompt where the target of an update was found."""
    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        """The modality of the update that produced this match."""
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Start of the matched span in the prompt."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """End (exclusive) of the matched span in the prompt."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


591
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
    """A zero-width match at a single index (from a :class:`PromptIndex`)."""
    match_idx: int

    @property
    def start_idx(self) -> int:
        return self.match_idx

    @property
    def end_idx(self) -> int:
        # Same as `start_idx`: the match covers no tokens/characters.
        return self.match_idx


604
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
    """A match found in a token-ID prompt, backed by a `_TokenMatch`."""
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
    """A match found in a text prompt, backed by an `re.Match`."""
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

629

630
@dataclass
class PlaceholderFeaturesInfo:
    """A run of placeholder tokens found in a prompt for one item."""
    modality: str
    item_idx: int
    start_idx: int
    tokens: list[int]
    is_embed: Optional[torch.Tensor]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` covering all of
        :attr:`tokens`."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
            is_embed=self.is_embed,
        )
650
651
652
653


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """Return each target of :code:`prompt_updates` found in :code:`prompt`."""

    def get_matches(update: BoundPromptUpdate) -> Sequence[PromptTargetMatch]:
        target = update.target

        # Index targets resolve to at most one zero-width match.
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        return [
            _PromptTargetTokenMatch(update, match)
            for match in iter_token_matches(prompt, target.token_ids)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """Return each target of :code:`prompt_updates` found in :code:`prompt`."""

    def get_matches(update: BoundPromptUpdate) -> Sequence[PromptTargetMatch]:
        target = update.target

        # Index targets resolve to at most one zero-width match.
        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        # The target text is matched literally, hence `re.escape`.
        return [
            _PromptTargetTextMatch(update, match)
            for match in re.finditer(re.escape(target.text), prompt)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same position of the prompt.
    """
    matches = [m for matches in mm_matches.values() for m in matches]

    # For each prompt position, the match that has claimed it (if any).
    seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)

    for match in matches:
        # Zero-width matches (start == end) claim no positions.
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


728
def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    Returns the pieces of the updated prompt in order; the caller is
    responsible for joining them.
    """
    out_seqs = list[Union[str, list[int]]]()
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        # Skip the match once every item of this modality has been placed.
        item_start_idx = next_idx_by_modality[modality]
        max_item_count = mm_item_counts.get(modality, 0)
        if item_start_idx >= max_item_count:
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx
        origin = match._origin
        mode = origin.mode

        if mode == UpdateMode.INSERT:
            # Keep the target itself (slice up to `end_idx`) and place
            # all remaining items after it.
            out_seqs.append(prompt[prev_end_idx:end_idx])
            num_inserts = max_item_count
        elif mode == UpdateMode.REPLACE:
            # Drop the target (slice up to `start_idx`). A zero-width match
            # receives all remaining items; otherwise one item per match.
            out_seqs.append(prompt[prev_end_idx:start_idx])
            num_inserts = max_item_count if start_idx == end_idx else 1
        else:
            assert_never(mode)

        item_end_idx = min(item_start_idx + num_inserts, max_item_count)

        for item_idx in range(item_start_idx, item_end_idx):
            content = origin.get_content(item_idx)
            # Use the representation that matches the prompt type.
            insert_seq = (content.full.text if isinstance(prompt, str) else
                          content.full.token_ids)

            out_seqs.append(insert_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] += item_end_idx - item_start_idx

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs)
775
776


777
def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, :code:`prompt` itself is returned.
    """
    if not mm_matches:
        return prompt

    token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)
789
790


791
def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, :code:`prompt` is returned unchanged.
    """
    if not mm_matches:
        return prompt

    texts = _apply_matches(prompt, mm_matches, mm_item_counts)

    return "".join(texts)
803
804


805
def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_updates in mm_prompt_updates.items():
            # Only look for the next not-yet-found item of this modality.
            item_idx = item_idx_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for update_info in modality_updates:
                content = update_info.get_content(item_idx)
                content_tokens_full = content.full.token_ids
                content_len_full = len(content_tokens_full)
                end_idx_full = start_idx + content_len_full

                # Skip empty content and content that cannot fit in
                # the rest of the prompt.
                if content_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] == content_tokens_full:
                    # Evaluate the embedding mask only for actual matches.
                    content_is_embed = content.is_embed
                    if content_is_embed is not None:
                        content_is_embed = content_is_embed(content.full)

                    yield PlaceholderFeaturesInfo(
                        modality=modality,
                        item_idx=item_idx,
                        start_idx=start_idx,
                        tokens=content_tokens_full,
                        is_embed=content_is_embed,
                    )

                    # Exclude overlapping matches
                    start_idx = end_idx_full
                    item_idx_by_modality[modality] += 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1
864
865


866
def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Find the placeholder tokens in :code:`prompt` and group them by modality.
    """
    it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)

    return dict(full_groupby_modality(it))


875
876
877
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")


878
879
class ProcessingCache:
    """
    A size-bounded LRU cache of processed multi-modal items, keyed by a
    hash of the inputs they were computed from.
    """

    @staticmethod
    def get_lru_cache(
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """
        Create an :class:`LRUCache` whose capacity is `capacity_gb` GiB,
        with entry sizes measured by summing the sizes of their leaves.

        `value_type` is only used for typing; `debug` enables a debug log
        line for each size calculation.
        """

        def get_leaf_size(leaf: object) -> int:
            # MultiModalKwargs is not a subclass of dict
            if isinstance(leaf, MultiModalKwargs):
                return get_item_size(leaf.data)

            # MultiModalKwargsItem is not a subclass of dict
            if isinstance(leaf, MultiModalKwargsItem):
                leaf_data = {k: v.data for k, v in leaf.items()}
                return get_item_size(leaf_data)

            # sys.getsizeof doesn't work for tensors
            if isinstance(leaf, torch.Tensor):
                return leaf.nbytes

            return sys.getsizeof(leaf)

        def get_item_size(
            value: Union[MultiModalKwargs, MultiModalKwargsItem,
                         Mapping[str, NestedTensors]]
        ) -> int:
            # Map every leaf to its size, then sum the sizes.
            size = json_reduce_leaves(
                lambda a, b: a + b,
                json_map_leaves(get_leaf_size, value),
            )

            if debug:
                logger.debug("Calculated size of %s to be %.2f GiB",
                             type(value), size / GiB_bytes)

            return size

        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)

    def __init__(
        self,
        capacity_gb: float,
        *,
        debug_cache_hit_ratio_steps: Optional[int] = None,
    ) -> None:
        """
        Args:
            capacity_gb: Capacity of the cache in GiB.
            debug_cache_hit_ratio_steps: If set (and non-zero), log cache
                statistics every this many lookups.
        """
        super().__init__()

        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
        self.debug_cache_hits = 0
        self.debug_cache_total = 0

        self._cache = self.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
            debug=bool(debug_cache_hit_ratio_steps),
        )

    def _maybe_log_cache_stats(self) -> None:
        """Log hit ratio and size every `debug_cache_hit_ratio_steps`
        lookups; does nothing when that option is unset."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        total = self.debug_cache_total
        if total > 0 and total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         self.debug_cache_hits / total)
            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
                         self._cache.currsize / GiB_bytes,
                         self._cache.maxsize / GiB_bytes)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)

        # Hit/total counters are only maintained when debugging is enabled.
        if self.debug_cache_hit_ratio_steps:
            if cache_key in self._cache:
                self.debug_cache_hits += 1

            self.debug_cache_total += 1

        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)
        self._cache[cache_key] = output_kwargs
998
999


1000
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        """Keep a reference to the input processing context."""
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The model identifier taken from the context's model config."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config obtained from the context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HF processor obtained from the context,
        forwarding ``kwargs`` unchanged.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    def get_allowed_mm_limits(self) -> Mapping[str, int]:
        """
        Return the maximum allowed number of items for each modality.

        For every supported modality, this is the per-prompt limit from
        the multi-modal config, capped by the supported limit whenever
        the latter is not `None`.
        """
        supported = self.get_supported_mm_limits()
        mm_config = self.ctx.get_mm_config()

        allowed: dict[str, int] = {}
        for modality, cap in supported.items():
            limit = mm_config.get_limit_per_prompt(modality)

            # `None` means the model itself imposes no cap.
            if cap is not None:
                limit = min(limit, cap)

            allowed[modality] = limit

        return allowed

1051
1052

# Type variable bound to :class:`BaseProcessingInfo`; it parameterizes
# :class:`BaseMultiModalProcessor` with the concrete info class it uses.
_I = TypeVar("_I", bound=BaseProcessingInfo)
1053

1054
1055

class BaseMultiModalProcessor(ABC, Generic[_I]):
1056
    """
1057
    Abstract base class to process multi-modal inputs to be used in vLLM.
1058
1059

    Not to be confused with :class:`transformers.ProcessorMixin`.
1060
1061
    """

1062
    def __init__(self,
1063
1064
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
1065
1066
1067
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
1068
1069
        super().__init__()

1070
1071
        self.info = info
        self.dummy_inputs = dummy_inputs
1072
1073
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks
1074

1075
1076
        self.data_parser = self._get_data_parser()

1077
    def __call__(
1078
        self,
1079
1080
        prompt: str,
        mm_data: MultiModalDataDict,
1081
        hf_processor_mm_kwargs: Mapping[str, object],
1082
    ) -> MultiModalInputs:
1083
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
1084

1085
1086
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that preprocesses multi-modal data items
        before they are passed to :meth:`_get_hf_mm_data`.

        Override this to support additional modalities: return an instance
        of a :class:`MultiModalDataParser` subclass that has additional
        subparsers.
        """
        data_parser = MultiModalDataParser()
        return data_parser

    def _to_mm_items(
1096
1097
1098
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
1099
1100
1101
1102
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.
        """
1103
        mm_items = self.data_parser.parse_mm_data(mm_data)
1104
1105
        supported_mm_limits = self.info.get_supported_mm_limits()
        allowed_mm_limits = self.info.get_allowed_mm_limits()
1106
1107

        for modality, items in mm_items.items():
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
            supported_limit = supported_mm_limits.get(modality, 0)
            allowed_limit = allowed_mm_limits.get(modality, 0)
            num_items = len(items)

            if supported_limit is not None and num_items > supported_limit:
                raise ValueError(
                    f"The model only supports at most {supported_limit} "
                    f"{modality} items, but you passed {num_items} "
                    f"{modality} items in the same prompt.")

            if num_items > allowed_limit:
1119
                raise ValueError(
1120
1121
                    f"You set or defaulted to {modality}={allowed_limit} "
                    f"in --limit-mm-per-prompt`, but passed {num_items} "
1122
1123
1124
                    f"{modality} items in the same prompt.")

        return mm_items
1125

1126
1127
1128
1129
1130
1131
1132
1133
1134
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        The returned mapping is passed, together with ``hf_inputs``, to
        :meth:`MultiModalKwargs.from_hf_inputs` when the HF processor
        output is converted into multi-modal keyword arguments.

        Args:
            hf_inputs: The data outputted by the HF processor.
            hf_processor_mm_kwargs: The options that were passed to
                the HF processor.
        """
        raise NotImplementedError

1135
    @abstractmethod
1136
    def _get_prompt_updates(
1137
        self,
1138
        mm_items: MultiModalDataItems,
1139
1140
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
1141
    ) -> Sequence[PromptUpdate]:
1142
1143
        """
        Given the original multi-modal items for this modality
1144
        and HF-processed data, output the updates to perform.
1145

1146
1147
1148
1149
1150
1151
1152
1153
        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct  :class:`~vllm-multimodal.input.PlaceholderRange`
        for each multi-modal item.
1154
1155
        """
        raise NotImplementedError
1156

1157
    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholders in the updated token IDs.

        Delegates to the module-level :func:`find_mm_placeholders`,
        forwarding all three arguments unchanged.
        """
        placeholders = find_mm_placeholders(
            mm_prompt_updates,
            new_token_ids,
            mm_item_counts,
        )
        return placeholders
1165

1166
    def _get_hf_mm_data(
1167
        self,
1168
        mm_items: MultiModalDataItems,
1169
1170
1171
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()
1172

1173
1174
1175
        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())
1176

1177
1178
        return processor_data, passthrough_data

1179
1180
1181
    def _call_hf_processor(
        self,
        prompt: str,
1182
1183
1184
1185
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
1186
    ) -> BatchFeature:
1187
1188
1189
1190
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
1191
1192
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
1193
1194
            dict(text=prompt, **mm_data),
            mm_kwargs,
1195
1196
        )

1197
    def _hf_processor_applies_updates(
1198
1199
1200
1201
1202
1203
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
1204
        Return whether the HF processor applies prompt updates.
1205
1206
1207
1208
1209
1210
1211
1212
1213

        For most HF processors, this should be :code:`True` when multi-modal
        data items are passed, but :code:`False` when multi-modal embeddings
        are passed.
        """
        return not any(
            isinstance(items, (EmbeddingItems, DictEmbeddingItems))
            for items in mm_items.values())

1214
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Run the HF processor once on the prompt text together with
        the multi-modal data.

        Returns a triple of the prompt token IDs, the multi-modal keyword
        arguments built from the remaining processor outputs (plus any
        passthrough data), and whether prompt updates have been applied.
        """
        hf_mm_data, passthrough_data = self._get_hf_mm_data(mm_items)

        hf_outputs = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=hf_mm_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        hf_outputs.update(passthrough_data)

        # `input_ids` must hold exactly one sequence;
        # the unpacking fails otherwise.
        [prompt_ids] = hf_outputs.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(hf_outputs,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargs.from_hf_inputs(hf_outputs, fields_config)

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_update_applied
1249

1250
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        This goes through :meth:`_apply_hf_processor_text_mm` with an empty
        collection of multi-modal items and no HF processor options, and
        keeps only the resulting prompt token IDs.
        """
        token_ids, _mm_kwargs, _is_update_applied = (
            self._apply_hf_processor_text_mm(
                prompt_text=prompt_text,
                mm_items=MultiModalDataItems({}),
                hf_processor_mm_kwargs={},
            ))

        return token_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.

        The default implementation returns ``prompt_tokens`` unchanged
        (the same list object, not a copy).
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`DummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

1295
        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
1296
            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
1309
        enable_hf_prompt_update: bool,
1310
    ) -> tuple[list[int], MultiModalKwargs, bool]:
1311
1312
1313
        """
        Apply the HF processor on the prompt text and multi-modal data.

1314
        In addition, return whether prompt updates have been applied
1315
1316
        (for most HF processors, this should be :code:`True`).

1317
        Note:
1318
1319
            If :code:`enable_hf_prompt_update=False`, we use HF processor
            to perform prompt updates if available; HF processor requires
1320
            that the prompt corresponds to multi-modal items.
1321
1322
        """
        if isinstance(prompt, str):
1323
            if enable_hf_prompt_update:
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

1334
        mm_kwargs = self._apply_hf_processor_mm_only(
1335
            mm_items=mm_items,
1336
1337
1338
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

1339
        return prompt_ids, mm_kwargs, False
1340
1341
1342

    def _cached_apply_hf_processor(
        self,
1343
        prompt: Union[str, list[int]],
1344
1345
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
1346
    ) -> tuple[list[int], MultiModalKwargs, bool]:
1347
1348
1349
1350
1351
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.
        """
        cache = self.cache
1352
        model_id = self.info.model_id
1353

1354
1355
        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
1356
1357
            return self._apply_hf_processor_main(
                prompt=prompt,
1358
1359
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
1360
                enable_hf_prompt_update=True,
1361
1362
            )

1363
        mm_maybe_cached_kw_items = {
1364
1365
1366
1367
1368
1369
1370
1371
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
1372
1373
1374
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
1375
1376
1377
1378
1379
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
1380
        mm_missing_data_items = self._to_mm_items(mm_missing_data)
1381

1382
        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
1383
        # so we can't apply prompt updates until the new multimodal
1384
1385
1386
1387
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            mm_missing_kwargs,
1388
            is_update_applied,
1389
        ) = self._apply_hf_processor_main(
1390
1391
            prompt=prompt,
            mm_items=mm_missing_data_items,
1392
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
1393
            enable_hf_prompt_update=False,
1394
1395
1396
1397
1398
1399
1400
        )

        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

1401
1402
1403
1404
1405
        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
1406
1407
1408
1409
1410
1411
1412
1413
1414
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
1415
                        kw_item,
1416
1417
1418
1419
                    )

                    mm_missing_next_idx[modality] += 1

1420
                merged_kw_items.append(kw_item)
1421
1422

        if self.enable_sanity_checks:
1423
            mm_missing_counts = mm_missing_data_items.get_all_counts()
1424
1425
1426
1427
1428
1429
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

1430
        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
1431

1432
        return prompt_ids, mm_kwargs, is_update_applied
1433

1434
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
    ) -> dict[str, Sequence[BoundPromptUpdate]]:
        """
        Bind each prompt update to the tokenizer and group the bound
        updates by modality.
        """
        tokenizer = self.info.get_tokenizer()

        bound_updates = (item.bind(tokenizer) for item in prompt_updates)
        grouped = full_groupby_modality(bound_updates)

        return dict(grouped)
1442

1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """
        Apply the matched prompt updates to a tokenized prompt.

        Delegates to the module-level :func:`apply_token_matches`,
        forwarding all three arguments unchanged.
        """
        return apply_token_matches(prompt, mm_matches, mm_item_counts)

    def _apply_text_matches(
        self,
        prompt: str,
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> str:
        """
        Apply the matched prompt updates to a text prompt.

        Delegates to the module-level :func:`apply_text_matches`,
        forwarding all three arguments unchanged.
        """
        return apply_text_matches(prompt, mm_matches, mm_item_counts)

1459
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to the token IDs.

        Returns a triple of the updated token IDs, the corresponding
        prompt text, and the placeholders found in the updated token IDs
        for each modality.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }

        # Token-level updates are only usable if every modality has at
        # least as many token matches as it has items.
        token_matches_suffice = all(
            len(mm_token_matches.get(modality, ())) >= item_count
            for modality, item_count in mm_item_counts.items())

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        if token_matches_suffice:
            mm_matches = mm_token_matches

            token_ids = self._apply_token_matches(
                token_ids,
                mm_matches,
                mm_item_counts,
            )
            text = decode_tokens(tokenizer, token_ids)
        else:
            text = decode_tokens(tokenizer, token_ids)

            mm_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = self._apply_text_matches(
                text,
                mm_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)

        # Only the updates that actually matched are searched for
        # when locating the placeholders.
        matched_updates = {
            modality: [match._origin for match in matches]
            for modality, matches in mm_matches.items()
        }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders
1529

1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
1553
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
1554
        mm_item_counts: Mapping[str, int],
1555
    ) -> None:
1556
1557
1558
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

1559
            if len(placeholders) != item_count:
1560
                raise RuntimeError(
1561
                    f"Expected there to be {item_count} prompt updates "
1562
                    f"corresponding to {item_count} {modality} items, but "
1563
                    f"instead found {len(placeholders)} prompt updates! "
1564
                    "Either the prompt text has missing/incorrect tokens for "
1565
1566
1567
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
1568
                    "`_call_hf_processor` and `_get_prompt_updates`).")
1569

1570
1571
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        When ``return_mm_hashes`` is set, a hash is additionally computed
        for every multi-modal item and returned in the result; otherwise
        the hashes in the result are :code:`None`.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes to be returned (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        mm_hashes = None
        if return_mm_hashes:
            model_id = self.info.model_id
            mm_hashes = {}
            for modality, items in mm_items.items():
                mm_hashes[modality] = [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]

        prompt_ids, mm_kwargs, is_update_applied = (
            self._cached_apply_hf_processor(
                prompt,
                mm_items,
                hf_processor_mm_kwargs,
            ))

        unbound_prompt_updates = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if not is_update_applied:
            # The HF processor left the prompt untouched,
            # so the updates have to be applied here.
            prompt_ids, prompt, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
        else:
            # The updates are already part of the token IDs;
            # only the placeholders remain to be located.
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_updates,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)

        mm_placeholder_ranges = {}
        for modality, placeholders in mm_placeholders.items():
            mm_placeholder_ranges[modality] = [
                placeholder.to_range() for placeholder in placeholders
            ]

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
1676
1677
1678
1679
        """
        Create input prompt for the encoder. HF processor will be applied on 
        this prompt during profiling and generation.
        """
1680
1681
        raise NotImplementedError

1682
1683
1684
1685
    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """
        Whether the dummy encoder prompt should be padded.

        The base implementation returns :code:`False`.
        NOTE(review): presumably consulted when building dummy inputs
        for profiling; confirm against the callers.
        """
        return False

1686
1687
1688
1689
1690
1691
1692
1693
    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the decoder.

        The default implementation returns ``prompt`` unchanged and
        does not look at ``mm_data``.
        """
        return prompt

1694
1695
1696
1697
1698
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are modified to fit encoder-decoder model:

        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Create the decoder prompt (by default, a copy of the input
           prompt) and make sure both its text and its token IDs are
           available.

        In the result, ``prompt`` and ``prompt_token_ids`` refer to the
        decoder prompt, while the processed encoder prompt is stored under
        ``encoder_prompt`` and ``encoder_prompt_token_ids``.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            return_mm_hashes,
        )

        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        if not isinstance(decoder_prompt, str):
            # Already tokenized: keep the IDs and recover the text.
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt_ids)
        else:
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               decoder_prompt,
                                               add_special_tokens=False)

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)

        # The entries copied over from the encoder inputs are replaced
        # by the decoder prompt.
        decoder_fields = {
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids
        }
        mm_inputs.update(decoder_fields)

        return mm_inputs