processing.py 55.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import json
3
import re
4
import sys
5
from abc import ABC, abstractmethod
6
from collections import defaultdict
7
8
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
9
from dataclasses import dataclass, field
10
from enum import Enum
11
from functools import lru_cache
12
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
13
                    TypeVar, Union, cast)
14

15
import torch
16
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
17
from typing_extensions import assert_never
18

19
from vllm.inputs import InputProcessingContext
20
from vllm.jsontree import json_map_leaves, json_reduce_leaves
21
from vllm.logger import init_logger
22
23
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
24
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
25

26
from .hasher import MultiModalHasher
27
28
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
29
                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
30
31
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
32
33
34

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
35

36
logger = init_logger(__name__)
37
38

# Type variable constrained to exactly `str` (text) or `list[int]` (token IDs).
_S = TypeVar("_S", str, list[int])
39
40
41

# Alias covering the two prompt representations used throughout this module.
PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
42

43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@dataclass
class PromptIndex:
    """Resolves to an index in the prompt."""
    # Callback that receives the tokenizer and the prompt (text or token IDs)
    # and returns the resolved index, or `None` when there is no match.
    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]


class PromptIndexTargets:
    """Factory helpers that build commonly used :class:`PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Build a target that resolves to index ``0``, i.e. before the first
        token (or character) of the prompt.

        The target matches unconditionally, so it also matches an empty prompt.
        """

        def _at_start(tok, prompt) -> int:
            return 0

        return PromptIndex(_at_start)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Build a target that resolves to the position right after ``seq``,
        provided the prompt begins with ``seq``.

        ``seq`` is converted with the tokenizer to the same representation
        (text or token IDs) as the prompt before the comparison is made.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prompt_is_text = isinstance(prompt, str)
            prefix_is_text = isinstance(seq, str)

            if prompt_is_text and not prefix_is_text:
                # Text prompt, token prefix: compare as `str`
                normalized = decode_tokens(tokenizer, seq)
            elif prefix_is_text and not prompt_is_text:
                # Token prompt, text prefix: compare as `list[int]`
                normalized = encode_tokens(tokenizer,
                                           seq,
                                           add_special_tokens=False)
            else:
                normalized = seq

            prefix_len = len(normalized)
            if prompt[:prefix_len] == normalized:
                return prefix_len

            return None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Build a target that resolves to ``len(prompt)``, i.e. after the last
        token (or character) of the prompt.

        The target matches unconditionally, so it also matches an empty prompt.
        """

        def _at_end(tok, prompt) -> int:
            return len(prompt)

        return PromptIndex(_at_end)


# Either a literal sequence to locate in the prompt, or a `PromptIndex`
# whose callback resolves to a position.
PromptTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update.
"""


105
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """
    The content of a prompt update, together with an optional rule that
    selects which positions of that content receive embeddings.
    """

    full: _S
    """The complete content (text or token IDs) of the update."""

    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
    """
    Callable that receives :attr:`full` bound to a tokenizer and returns a
    boolean mask of shape `(len(full),)` marking the positions of `full`
    that embeddings are assigned to.

    `None` (default) means that every position of `full` is assigned
    an embedding.

    The embeddings are obtained by calling
    :class:`SupportsMultiModal.get_multimodal_embeddings`.
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap ``seq`` without an embedding mask (all positions selected)."""
        return PromptUpdateDetails(seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap ``seq`` with a mask selecting the positions whose token ID is
        among the token IDs obtained by encoding ``embed_text``.
        """

        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
            candidate_ids = torch.tensor(
                encode_tokens(full.tokenizer, embed_text))
            sequence_ids = torch.tensor(full.token_ids)

            return torch.isin(sequence_ids, candidate_ids)

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap ``seq`` with a mask selecting the positions whose token ID
        equals ``embed_token_id``.
        """

        def is_embed(f: "_BoundPromptSequence") -> torch.Tensor:
            return torch.tensor(f.token_ids) == embed_token_id

        return PromptUpdateDetails(full=seq, is_embed=is_embed)
152
153


154
# Either a plain sequence, or `PromptUpdateDetails` when an embedding mask
# has to accompany the sequence.
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that is part of the update.

If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""
161

162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Either a callable mapping an item index to the update content,
# or the update content itself.
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """How a :class:`PromptUpdate` modifies the prompt at its target."""
    # Content is added after the target; the target itself is kept.
    INSERT = "insert"
    # Content takes the place of the target.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Abstract description of a prompt update involving placeholder tokens.

    Concrete subclasses provide :attr:`content` and :attr:`mode`.
    """

    modality: str
    """Name of the modality that this update applies to."""

    target: PromptTarget
    """The token sequence (or text), or prompt index, to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder content (or a callable producing it per item)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Whether the content is inserted at or replaces the target."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """Pair this update with ``tokenizer`` in a :class:`BoundPromptUpdate`."""
        return BoundPromptUpdate(self, tokenizer)

208

209
@dataclass
class PromptInsertion(PromptUpdate):
    """
    A :class:`PromptUpdate` that inserts placeholder tokens into a prompt,
    leaving the target in place.

    Example:

        After the ``<s>`` token, insert for each image as many ``<image>``
        feature placeholders as the feature size of the vision encoder:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target="<s>",
                insertion="<image>" * image_feature_size,
            )

        Insert the same tokens at the very start of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion="<image>" * image_feature_size,
            )

        Insert the same tokens right after the prefix ``Images:``:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix("Images:"),
                insertion="<image>" * image_feature_size,
            )

        Insert the same tokens at the very end of the prompt:

        .. code-block:: python

            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.end(),
                insertion="<image>" * image_feature_size,
            )
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    The token sequence (or text) to insert right after :attr:`target`.

    This may be a function of the index of the processed item within
    :attr:`modality`, or the sequence itself when it does not depend
    on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to insert, i.e. :attr:`insertion`."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptUpdateDetails.select_text(
                    "".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    embed_text="<image>",
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails.select_token_id(
                    ([image_bos_id] + [image_token_id] * image_feature_size
                     + [image_eos_id]),
                    embed_token_id=image_token_id,
                ),
            )
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the token sequence (or text) to replace :attr:`target`.

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content that replaces the target, i.e. :attr:`replacement`."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always :attr:`UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
346
347


348
349
350
351
352
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Results are kept in an LRU cache of up to 2048 entries keyed on the
    arguments, so all arguments must be hashable.
    """
    token_ids = encode_tokens(
        tokenizer,
        text,
        add_special_tokens=add_special_tokens,
    )
    return token_ids
358
359


360
361
362
363
364
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    ``token_ids`` is taken as a tuple so that it can be part of the cache
    key; it is converted back to a list before decoding. The LRU cache
    holds up to 2048 entries.
    """
    ids_as_list = list(token_ids)
    text = decode_tokens(
        tokenizer,
        ids_as_list,
        skip_special_tokens=skip_special_tokens,
    )
    return text
370
371
372
373
374


class _HasModalityAttr(Protocol):
    """Structural type for objects exposing `modality` as a plain attribute."""
    modality: str

375

376
class _HasModalityProp(Protocol):
    """Structural type for objects exposing `modality` as a read-only property."""

    @property
    def modality(self) -> str:
        """Name of the modality the object belongs to."""
        ...


# Bound to any object exposing `modality`, as an attribute or as a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group ``values`` with :func:`full_groupby`, keyed on ``modality``."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.

    Whichever representation was not supplied is computed on first access
    (via the cached encode/decode helpers) and then stored on the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    # At least one of the two representations below must be set;
    # this is enforced in `__post_init__`.
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """
        Bind ``seq`` to ``tokenizer``.

        A ``str`` is stored as the text form and a ``list`` as the token ID
        form. Any other type leaves both forms unset, which makes
        construction raise :class:`ValueError`.
        """
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    def __post_init__(self) -> None:
        """Raise :class:`ValueError` if neither representation was given."""
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is None:
            assert self._token_ids is not None
            # The cache key must be hashable, hence the tuple conversion
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """
        The token ID form, encoded from the text on first access
        (without adding special tokens).
        """
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer,
                                             self._text,
                                             add_special_tokens=False)

        return self._token_ids


437
@dataclass
class _BoundPromptContent:
    """Update content whose sequence has been bound to a tokenizer."""

    # The complete content of the update, convertible between text and IDs.
    full: _BoundPromptSequence

    # Optional callable computing an embedding mask from `full`.
    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
441
442


443
@dataclass
class BoundPromptUpdate:
    """
    A :class:`PromptUpdate` paired with a tokenizer, so that
    :attr:`target` and the result of :meth:`get_content` can be read
    either as a token sequence or as text.
    """
    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Per-item cache of bound content, used only when the content is
        # produced by a callable (see `get_content`).
        self._content_cache: dict[int, _BoundPromptContent] = {}

    @property
    def modality(self) -> str:
        """The modality of the wrapped update."""
        return self._origin.modality

    @property
    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """
        The target of the wrapped update.

        A :class:`PromptIndex` is returned unchanged; a sequence is bound
        to the tokenizer (a new bound object is built on each access).
        """
        raw_target = self._origin.target
        if not isinstance(raw_target, PromptIndex):
            return _BoundPromptSequence.from_seq(self.tokenizer, raw_target)

        return raw_target

    @property
    def content(self) -> PromptUpdateContent:
        """The unbound content of the wrapped update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """The update mode of the wrapped update."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Return the content for the item at ``item_idx`` within
        :attr:`modality`, bound to the tokenizer.

        When the content is a callable, its result is cached per item index;
        content that is not callable is bound anew on every call.
        """
        raw = self.content
        dynamic = callable(raw)
        cache_key = item_idx if dynamic else None

        if dynamic:
            if cache_key in self._content_cache:
                return self._content_cache[cache_key]

            raw = raw(item_idx)

        details = (raw if isinstance(raw, PromptUpdateDetails) else
                   PromptUpdateDetails.from_seq(raw))

        bound = _BoundPromptContent(
            full=_BoundPromptSequence.from_seq(self.tokenizer, details.full),
            is_embed=details.is_embed,
        )

        if cache_key is not None:
            self._content_cache[cache_key] = bound

        return bound
507
508


509
510
511
class _TokenMatch(NamedTuple):
    """Half-open index range `[start_idx, end_idx)` of a token match."""
    # Index of the first matched token.
    start_idx: int
    # Index one past the last matched token.
    end_idx: int
512
513


514
515
516
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Scan :code:`token_ids` from left to right and yield the position of
    every non-overlapping occurrence of :code:`match_ids`.

    Nothing is yielded when :code:`match_ids` is empty.
    """
    total_len = len(token_ids)
    pattern_len = len(match_ids)

    if pattern_len == 0:
        return

    # Last position at which a full-length window still fits
    last_start = total_len - pattern_len

    cursor = 0
    while cursor <= last_start:
        window_end = cursor + pattern_len

        if token_ids[cursor:window_end] != match_ids:
            cursor += 1
            continue

        yield _TokenMatch(start_idx=cursor, end_idx=window_end)

        # Resume after the match so that matches never overlap
        cursor = window_end
540
541


542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Return a copy of :code:`token_ids` in which every non-overlapping
    occurrence of :code:`match_ids` is substituted by :code:`new_ids`.

    An empty :code:`match_ids` never matches, so nothing is substituted.
    """
    pieces: list[list[int]] = []
    cursor = 0

    for found in iter_token_matches(token_ids, match_ids):
        # Untouched tokens before the match, then the substitution
        pieces += [token_ids[cursor:found.start_idx], new_ids]
        cursor = found.end_idx

    # Untouched tokens after the last match
    pieces.append(token_ids[cursor:])

    return flatten_2d_lists(pieces)


569
@dataclass(repr=False)
class PromptTargetMatch(ABC):
    """
    Abstract location in a prompt at which the target of a
    :class:`BoundPromptUpdate` was found.
    """
    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        """The modality of the update that produced this match."""
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt at which the match begins."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt at which the match ends (exclusive)."""
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


592
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
    """A zero-width match located at a resolved prompt index."""
    match_idx: int

    @property
    def start_idx(self) -> int:
        """Equal to :attr:`match_idx`."""
        idx = self.match_idx
        return idx

    @property
    def end_idx(self) -> int:
        """Equal to :attr:`match_idx`, since the match has no width."""
        idx = self.match_idx
        return idx


605
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
    """A match found in a token ID prompt, backed by a :class:`_TokenMatch`."""
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        """Start index of the underlying token match."""
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        """End index (exclusive) of the underlying token match."""
        token_match = self.match
        return token_match.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
    """A match found in a text prompt, backed by a regex match object."""
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        """Character offset at which the regex match starts."""
        begin, _ = self.match.span()
        return begin

    @property
    def end_idx(self) -> int:
        """Character offset at which the regex match ends (exclusive)."""
        _, finish = self.match.span()
        return finish

630

631
@dataclass
class PlaceholderFeaturesInfo:
    """Describes one run of placeholder tokens located in a prompt."""

    # Modality the placeholder belongs to
    modality: str
    # Index of the item within its modality
    item_idx: int
    # Position in the prompt at which the placeholder tokens begin
    start_idx: int
    # The placeholder token IDs themselves
    tokens: list[int]
    # Optional embedding mask over `tokens` (`None` if no mask was given)
    is_embed: Optional[torch.Tensor]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        token_count = len(self.tokens)
        return token_count

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` covering all the tokens."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        range_kwargs = dict(
            offset=self.start_idx,
            length=self.length,
            is_embed=self.is_embed,
        )
        return PlaceholderRange(**range_kwargs)
651
652
653
654


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Collect, for every update in :code:`prompt_updates`, the places where
    its target occurs in the token ID :code:`prompt`.

    Matches are returned grouped by update, in the order of
    :code:`prompt_updates`.
    """
    found: list[PromptTargetMatch] = []

    for update in prompt_updates:
        target = update.target

        if isinstance(target, PromptIndex):
            # Index targets resolve to at most one zero-width match
            resolved_idx = target.get_match_index(update.tokenizer, prompt)
            if resolved_idx is not None:
                found.append(_PromptTargetIndexMatch(update, resolved_idx))

            continue

        found.extend(
            _PromptTargetTokenMatch(update, token_match)
            for token_match in iter_token_matches(prompt, target.token_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Collect, for every update in :code:`prompt_updates`, the places where
    its target occurs in the text :code:`prompt`.

    Matches are returned grouped by update, in the order of
    :code:`prompt_updates`.
    """
    found: list[PromptTargetMatch] = []

    for update in prompt_updates:
        target = update.target

        if isinstance(target, PromptIndex):
            # Index targets resolve to at most one zero-width match
            resolved_idx = target.get_match_index(update.tokenizer, prompt)
            if resolved_idx is not None:
                found.append(_PromptTargetIndexMatch(update, resolved_idx))

            continue

        # The target text is matched literally, hence the escaping
        literal_pattern = re.escape(target.text)
        found.extend(
            _PromptTargetTextMatch(update, text_match)
            for text_match in re.finditer(literal_pattern, prompt))

    return found


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
    """
    Flatten :code:`mm_matches` and return the matches ordered by their
    start position in :code:`prompt`.

    Raises:
        ValueError: If two matches cover the same prompt position.
    """
    flattened: list[PromptTargetMatch] = []
    for modality_matches in mm_matches.values():
        flattened.extend(modality_matches)

    # For each prompt position, the match that already covers it (if any)
    occupied: list[Optional[PromptTargetMatch]] = [None] * len(prompt)

    for candidate in flattened:
        for pos in range(candidate.start_idx, candidate.end_idx):
            holder = occupied[pos]
            if holder is not None:
                raise ValueError("Found overlapping matches "
                                 f"({holder} and {candidate}) "
                                 f"at index={pos} of prompt={prompt}")

            occupied[pos] = candidate

    def _start_of(m: PromptTargetMatch) -> int:
        return m.start_idx

    return sorted(flattened, key=_start_of)


729
def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in :code:`mm_matches` to :code:`prompt` and return
    the resulting pieces (unchanged prompt slices interleaved with the
    update contents), in order.

    Matches of a modality are skipped once as many items as given by
    :code:`mm_item_counts` have been consumed for that modality.
    """
    pieces = list[Union[str, list[int]]]()
    cursor = 0
    # Number of items already consumed, per modality
    consumed = defaultdict[str, int](int)
    as_text = isinstance(prompt, str)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        first_item = consumed[modality]
        limit = mm_item_counts.get(modality, 0)
        if first_item >= limit:
            continue

        match_start = match.start_idx
        match_end = match.end_idx
        origin = match._origin
        mode = origin.mode

        if mode == UpdateMode.INSERT:
            # Keep the target; all remaining items go right after it
            keep_until = match_end
            budget = limit
        elif mode == UpdateMode.REPLACE:
            # Drop the target; a zero-width target takes all remaining
            # items, otherwise one item per matched target
            keep_until = match_start
            budget = 1 if match_start != match_end else limit
        else:
            assert_never(mode)

        pieces.append(prompt[cursor:keep_until])

        stop_item = min(first_item + budget, limit)

        for item_idx in range(first_item, stop_item):
            bound_full = origin.get_content(item_idx).full
            pieces.append(bound_full.text if as_text else bound_full.token_ids)

        cursor = match_end
        consumed[modality] = stop_item

    pieces.append(prompt[cursor:])

    return cast(list[_S], pieces)
776
777


778
def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the updates in :code:`mm_matches` to the token ID :code:`prompt`.

    The prompt object itself is returned when :code:`mm_matches` is empty.
    """
    if mm_matches:
        pieces = _apply_matches(prompt, mm_matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt
790
791


792
def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the updates in :code:`mm_matches` to the text :code:`prompt`.

    The prompt is returned unchanged when :code:`mm_matches` is empty.
    """
    if mm_matches:
        pieces = _apply_matches(prompt, mm_matches, mm_item_counts)
        return "".join(pieces)

    return prompt
804
805


806
def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Scan :code:`prompt` from left to right and yield every run of
    placeholder tokens that equals the content of one of the updates.

    At each position, modalities are tried in the iteration order of
    `mm_prompt_updates` and the first one that matches wins, so a position
    is never reported for more than one modality. Scanning resumes after
    the matched run.

    Updates whose content is empty never match.
    """
    total_len = len(prompt)
    # Index of the next item to look for, per modality
    next_item = defaultdict[str, int](int)

    cursor = 0
    while cursor < total_len:
        hit = None

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = next_item[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for update in modality_updates:
                content = update.get_content(item_idx)
                tokens = content.full.token_ids
                num_tokens = len(tokens)
                stop = cursor + num_tokens

                if num_tokens == 0 or stop > total_len:
                    continue
                if prompt[cursor:stop] != tokens:
                    continue

                hit = (modality, item_idx, content, tokens, stop)
                break

            if hit is not None:
                break

        if hit is None:
            cursor += 1
            continue

        modality, item_idx, content, tokens, stop = hit

        embed_mask = content.is_embed
        if embed_mask is not None:
            embed_mask = embed_mask(content.full)

        yield PlaceholderFeaturesInfo(
            modality=modality,
            item_idx=item_idx,
            start_idx=cursor,
            tokens=tokens,
            is_embed=embed_mask,
        )

        # Resume after the matched run so that matches never overlap
        cursor = stop
        next_item[modality] += 1
865
866


867
def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the placeholder tokens in :code:`prompt` and return them
    grouped by modality.
    """
    placeholders = _iter_placeholders(
        mm_prompt_updates,
        prompt,
        mm_item_counts,
    )
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


876
877
878
# Value type stored in the processing cache; the bound is written as a
# string (forward reference).
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")


879
880
class ProcessingCache:
    """
    Cache of processed multi-modal items, backed by an :class:`LRUCache`
    whose capacity is expressed in GiB.
    """

    @staticmethod
    def get_lru_cache(
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """
        Build an :class:`LRUCache` with a capacity of ``capacity_gb`` GiB in
        which each entry is sized by summing the sizes of its leaves.

        ``value_type`` is not used at runtime. When ``debug`` is true, each
        computed size is logged at debug level.
        """

        def _add(lhs: int, rhs: int) -> int:
            return lhs + rhs

        def get_item_size(
            value: Union[MultiModalKwargs, MultiModalKwargsItem,
                         Mapping[str, NestedTensors]]
        ) -> int:
            leaf_sizes = json_map_leaves(get_leaf_size, value)
            total = json_reduce_leaves(_add, leaf_sizes)

            if debug:
                logger.debug("Calculated size of %s to be %.2f GiB",
                             type(value), total / GiB_bytes)

            return total

        def get_leaf_size(leaf: object) -> int:
            # MultiModalKwargs is not a subclass of dict
            if isinstance(leaf, MultiModalKwargs):
                return get_item_size(leaf.data)

            # MultiModalKwargsItem is not a subclass of dict
            if isinstance(leaf, MultiModalKwargsItem):
                return get_item_size(
                    {name: elem.data
                     for name, elem in leaf.items()})

            # sys.getsizeof doesn't work for tensors
            if isinstance(leaf, torch.Tensor):
                return leaf.nbytes

            return sys.getsizeof(leaf)

        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)

    def __init__(
        self,
        capacity_gb: float,
        *,
        debug_cache_hit_ratio_steps: Optional[int] = None,
    ) -> None:
        """
        Create a cache with a capacity of ``capacity_gb`` GiB.

        A truthy ``debug_cache_hit_ratio_steps`` enables hit counting and
        periodic debug logging of the cache statistics.
        """
        super().__init__()

        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
        self.debug_cache_hits = 0
        self.debug_cache_total = 0

        debug_enabled = bool(debug_cache_hit_ratio_steps)
        self._cache = self.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
            debug=debug_enabled,
        )

    def _maybe_log_cache_stats(self) -> None:
        """
        Log the hit ratio and the cache size whenever the number of counted
        lookups is a positive multiple of the configured step count.
        """
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        lookups = self.debug_cache_total
        if lookups <= 0 or lookups % steps != 0:
            return

        logger.debug("ProcessingCache: hit_ratio = %.2f",
                     self.debug_cache_hits / lookups)
        logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
                     self._cache.currsize / GiB_bytes,
                     self._cache.maxsize / GiB_bytes)

    def _make_cache_key(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Hash the dependencies of a processed item into a cache key."""
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Look up a processed multi-modal item by its dependencies:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor

        The result of the underlying cache lookup is returned as is.
        """
        self._maybe_log_cache_stats()

        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)

        if self.debug_cache_hit_ratio_steps:
            self.debug_cache_hits += int(cache_key in self._cache)
            self.debug_cache_total += 1

        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Store a processed multi-modal item under the key derived from
        its dependencies (the same ones as in :meth:`get`).
        """
        cache_key = self._make_cache_key(model_id, modality, input_item,
                                         input_kwargs)
        self._cache[cache_key] = output_kwargs
999
1000


1001
class BaseProcessingInfo:
    """
    Base class to provide the information necessary for data processing.

    It wraps an :class:`InputProcessingContext` and exposes the model ID,
    tokenizer, HF config and HF processor through it.

    Note:
        This class does not derive from :class:`abc.ABC`, so the
        ``@abstractmethod`` marker on :meth:`get_supported_mm_limits` is
        not enforced at instantiation time; a subclass that does not
        override it only fails with :exc:`NotImplementedError` when the
        method is called.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The model identifier taken from the context's model config."""
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config obtained from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HF processor obtained from the processing context.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    def get_allowed_mm_limits(self) -> Mapping[str, int]:
        """
        Return the maximum allowed number of items for each modality.

        For every supported modality this is the user-configured
        per-prompt limit, capped by the model's supported limit when the
        model declares one.
        """
        supported_mm_limits = self.get_supported_mm_limits()
        mm_config = self.ctx.get_mm_config()

        allowed_limits = dict[str, int]()
        for modality, supported_limit in supported_mm_limits.items():
            user_limit = mm_config.get_limit_per_prompt(modality)

            # `None` means the model imposes no cap of its own.
            allowed_limits[modality] = (user_limit if supported_limit is None
                                        else min(user_limit, supported_limit))

        return allowed_limits

1052
1053

# Type variable for the processing-info object that a processor is
# parameterized over; it must be `BaseProcessingInfo` or a subclass of it.
_I = TypeVar("_I", bound=BaseProcessingInfo)
1054

1055
1056

class BaseMultiModalProcessor(ABC, Generic[_I]):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """

    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Store the processing info, the dummy-input builder, the optional
        processing cache and the sanity-check flag, then build the data
        parser via :meth:`_get_data_parser`.
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Shorthand for :meth:`apply` with the default options."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.
        """
        return MultiModalDataParser()

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises :exc:`ValueError` if the number of items of any modality
        exceeds what the model supports or what the user allows per prompt.
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)
        supported_mm_limits = self.info.get_supported_mm_limits()
        allowed_mm_limits = self.info.get_allowed_mm_limits()

        for modality, items in mm_items.items():
            # A modality missing from either mapping is treated as having
            # a limit of zero, i.e. any item of that modality is rejected.
            supported_limit = supported_mm_limits.get(modality, 0)
            allowed_limit = allowed_mm_limits.get(modality, 0)
            num_items = len(items)

            if supported_limit is not None and num_items > supported_limit:
                raise ValueError(
                    f"The model only supports at most {supported_limit} "
                    f"{modality} items, but you passed {num_items} "
                    f"{modality} items in the same prompt.")

            if num_items > allowed_limit:
                raise ValueError(
                    "You set or defaulted to "
                    f"'{json.dumps({modality: allowed_limit})}' in "
                    f"`--limit-mm-per-prompt`, but passed {num_items} "
                    f"{modality} items in the same prompt.")

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Given the HF-processed data, output the metadata of each field."""
        raise NotImplementedError

    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF processor does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.
        """
        raise NotImplementedError

    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholders of each modality in ``new_token_ids``.

        This is an overridable hook that delegates to the module-level
        :func:`find_mm_placeholders`.
        """
        return find_mm_placeholders(mm_prompt_updates, new_token_ids,
                                    mm_item_counts)

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the multi-modal items into the data to be passed to the
        HF processor and the data that bypasses it.

        Returns a ``(processor_data, passthrough_data)`` pair, each merged
        over all modalities.
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Return whether the HF processor applies prompt updates.

        For most HF processors, this should be :code:`True` when multi-modal
        data items are passed, but :code:`False` when multi-modal embeddings
        are passed.
        """
        return not any(
            isinstance(items, (EmbeddingItems, DictEmbeddingItems))
            for items in mm_items.values())

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        In addition, return whether prompt updates have been applied.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Data that bypasses the HF processor is merged into its output.
        processed_data.update(passthrough_data)

        # The single-target unpacking requires `input_ids` to hold
        # exactly one sequence.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_update_applied

    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The text is run through :meth:`_apply_hf_processor_text_mm` with an
        empty set of multi-modal items and no HF processor kwargs; only the
        resulting token IDs are kept.
        """
        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`BaseDummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        In addition, return whether prompt updates have been applied
        (for most HF processors, this should be :code:`True`).

        Note:
            If :code:`enable_hf_prompt_update=True` and the prompt is text,
            the prompt and the multi-modal items are passed to the HF
            processor together, which lets it perform prompt updates if
            available; HF processor requires that the prompt corresponds
            to multi-modal items.

            Otherwise (token prompts, or
            :code:`enable_hf_prompt_update=False`), the prompt and the
            multi-modal items are processed separately and :code:`False`
            is returned as the update flag.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, False

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        The cache is bypassed when no cache is configured or when any
        item carries passthrough data. Otherwise only the items missing
        from the cache are sent through the HF processor, their outputs
        are stored in the cache, and the cached and new outputs are merged
        in the original item order.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        # One entry per item; `None` marks a cache miss.
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we can't apply prompt updates until the new multimodal
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            mm_missing_kwargs,
            is_update_applied,
        ) = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_update=False,
        )

        # Per-modality cursor into the freshly processed (missing) items.
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly processed item must have been consumed.
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs, is_update_applied

    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
    ) -> dict[str, Sequence[BoundPromptUpdate]]:
        """
        Bind each prompt update to the tokenizer and group the bound
        updates by modality.
        """
        tokenizer = self.info.get_tokenizer()

        it = (update.bind(tokenizer) for update in prompt_updates)
        return dict(full_groupby_modality(it))

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """
        Apply the matched updates to a token prompt.

        Overridable hook that delegates to the module-level
        :func:`apply_token_matches`.
        """
        return apply_token_matches(prompt, mm_matches, mm_item_counts)

    def _apply_text_matches(
        self,
        prompt: str,
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> str:
        """
        Apply the matched updates to a text prompt.

        Overridable hook that delegates to the module-level
        :func:`apply_text_matches`.
        """
        return apply_text_matches(prompt, mm_matches, mm_item_counts)

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to ``token_ids``.

        Token-level matching is used when it finds enough matches for
        every modality; otherwise the updates are applied on the decoded
        text, which is then re-encoded.

        Returns the updated token IDs, the corresponding text, and the
        placeholders found in the updated token IDs.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in mm_token_matches.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = self._apply_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
            matched_updates = {
                modality: [match._origin for match in token_matches]
                for modality, token_matches in mm_token_matches.items()
            }
        else:
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = self._apply_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            matched_updates = {
                modality: [match._origin for match in token_matches]
                for modality, token_matches in mm_text_matches.items()
            }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that ``mm_kwargs`` holds exactly as many items per modality
        as there are input data items.

        Raises :exc:`RuntimeError` on any mismatch.
        """
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that exactly one placeholder was found per input data item
        of each modality.

        Raises :exc:`RuntimeError` on any mismatch.
        """
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt updates "
                    f"corresponding to {item_count} {modality} items, but "
                    f"instead found {len(placeholders)} prompt updates! "
                    "Either the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_updates`).")

    def _hash_mm_items(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> dict[str, list[str]]:
        """Create MM hashes to be returned (only used in V1)."""

        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        model_id = self.info.model_id

        return {
            modality: [
                MultiModalHasher.hash_kwargs(model_id=model_id,
                                             **{modality: item},
                                             **hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_items.items()
        }

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargs,
        is_update_applied: bool,
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Make sure the prompt updates are reflected in the token IDs.

        If ``is_update_applied`` is true, the placeholders are only located
        in ``prompt_ids``; otherwise the updates are applied first. In both
        cases the processed keyword arguments and the placeholders are
        validated against the number of input items.

        Returns the final token IDs, the decoded prompt text and the
        placeholders of each modality.
        """
        unbound_prompt_updates = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if is_update_applied:
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_updates,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt,
                mm_placeholders,
            ) = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        return prompt_ids, prompt, mm_placeholders

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        If ``return_mm_hashes`` is true, a hash of every multi-modal item
        is included in the result; otherwise ``mm_hashes`` is ``None``.
        """
        mm_items = self._to_mm_items(mm_data)

        mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
                     if return_mm_hashes else None)

        (
            prompt_ids,
            mm_kwargs,
            is_update_applied,
        ) = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            prompt_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            is_update_applied=is_update_applied,
        )

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    The HF processor is applied on the encoder prompt, while the decoder
    prompt is only tokenized (or decoded, if given as token IDs).
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """Flag that is :code:`False` by default; subclasses may override."""
        return False

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Create input prompt for the decoder."""
        return prompt

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are modified to fit encoder-decoder model:

        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Copy the input prompt text as decoder prompt inputs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            return_mm_hashes,
        )

        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        # Make both the text and the token IDs of the decoder prompt
        # available, whichever form it was given in.
        if isinstance(decoder_prompt, str):
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               decoder_prompt,
                                               add_special_tokens=False)
        else:
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

        # Start from the encoder inputs, exposing the encoder prompt under
        # dedicated keys, then overwrite `prompt`/`prompt_token_ids` with
        # the decoder prompt.
        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        mm_inputs.update({
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids
        })
        return mm_inputs