processor.py 63.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections import defaultdict
5
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
6
from dataclasses import dataclass, field, replace
7
from enum import Enum
8
from functools import lru_cache
9
10
11
12
13
from typing import (
    TYPE_CHECKING,
    Generic,
    NamedTuple,
    Protocol,
14
    TypeAlias,
15
16
    cast,
)
17

18
import regex as re
19
import torch
20
from typing_extensions import TypeVar, assert_never
21

22
from vllm.logger import init_logger
23
from vllm.tokenizers import TokenizerLike
24
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
25

26
27
from ..hasher import MultiModalHasher
from ..inputs import (
28
29
30
    MultiModalDataDict,
    MultiModalEncDecInputs,
    MultiModalFieldConfig,
31
    MultiModalHashes,
32
33
34
35
36
37
38
    MultiModalInputs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalKwargsOptionalItems,
    MultiModalUUIDDict,
    PlaceholderRange,
)
39
from ..parse import (
40
41
42
43
44
    DictEmbeddingItems,
    EmbeddingItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
45
46
47
48
49
from .context import (
    BaseProcessingInfo,
    get_current_request_id,
    timed_preprocessor_operation,
)
50
from .dummy_inputs import BaseDummyInputsBuilder
51
52

if TYPE_CHECKING:
53
54
    from transformers.feature_extraction_utils import BatchFeature

55
    from ..cache import BaseMultiModalProcessorCache
56
57
58
59
else:
    BatchFeature = object

    BaseMultiModalProcessorCache = object
60

61
# Module-level logger.
logger = init_logger(__name__)

# Constrained type variable: prompt text or a list of token IDs.
_S = TypeVar("_S", str, list[int])

# Prompts (and pieces of prompts) are handled either as text or as token IDs.
PromptSeq: TypeAlias = str | list[int]
"""A prompt sequence: either text (`str`) or a list of token IDs."""
68

69

70
71
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: TokenizerLike,
    text: str,
    *,
    add_special_tokens: bool = True,
) -> list[int]:
    """
    Tokenize `text` with `tokenizer`, memoizing the result.

    The cache is keyed on `(tokenizer, text, add_special_tokens)`, so the
    tokenizer must be hashable. The same list object is returned on every
    cache hit, so callers should not mutate it.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
    return token_ids
78
79
80
81


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: TokenizerLike,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Detokenize `token_ids` with `tokenizer`, memoizing the result.

    `token_ids` is a tuple so that it can be part of the cache key; it is
    converted back to a list before being handed to `tokenizer.decode`.
    """
    ids_as_list = [*token_ids]
    return tokenizer.decode(ids_as_list, skip_special_tokens=skip_special_tokens)
88
89


90
91
92
93
94
95
def _seq2text(
    tokenizer: TokenizerLike | None,
    seq: PromptSeq,
    *,
    use_cache: bool = True,
) -> str:
    """
    Convert `seq` to text.

    Text is returned unchanged. Token IDs are decoded with `tokenizer`,
    through the memoized `_cached_decode` unless `use_cache=False`.

    Raises:
        ValueError: If `seq` is a list of token IDs but `tokenizer` is `None`.
    """
    if not isinstance(seq, str):
        if tokenizer is None:
            raise ValueError(
                "You cannot decode tokens when `skip_tokenizer_init=True`"
            )

        # Lists are unhashable, so a tuple is used as the cache key.
        return (
            _cached_decode(tokenizer, tuple(seq))
            if use_cache
            else tokenizer.decode(seq)
        )

    return seq


108
109
110
111
112
113
def _seq2tokens(
    tokenizer: TokenizerLike | None,
    seq: PromptSeq,
    *,
    use_cache: bool = True,
) -> list[int]:
    """
    Convert `seq` to a list of token IDs.

    Token IDs are returned unchanged. Text is encoded with `tokenizer`
    without adding special tokens, through the memoized `_cached_encode`
    unless `use_cache=False`.

    Raises:
        ValueError: If `seq` is text but `tokenizer` is `None`.
    """
    if not isinstance(seq, str):
        return seq

    if tokenizer is None:
        raise ValueError("You cannot encode text when `skip_tokenizer_init=True`")

    if use_cache:
        return _cached_encode(tokenizer, seq, add_special_tokens=False)

    return tokenizer.encode(seq, add_special_tokens=False)


126
127
128
class _GetMatchIndex(Protocol):
    """
    Callable that resolves a location in a prompt.

    It receives the tokenizer, the prompt (text or token IDs) and the index
    from which the caller is searching. It returns the resolved index, or
    `None` if there is no match.
    """

    def __call__(
        self,
        tokenizer: TokenizerLike | None,
        prompt: PromptSeq,
        start_idx: int = 0,
    ) -> int | None: ...
133
134


135
136
137
@dataclass
class PromptIndex:
    """Resolves to an index in the prompt."""

    # Called as `get_match_index(tokenizer, prompt, start_idx)`; returns the
    # resolved index, or `None` if it cannot be resolved for that prompt.
    get_match_index: _GetMatchIndex
140
141
142
143
144
145
146
147
148
149


class PromptIndexTargets:
    """Factories for commonly used `PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.

        Like `prefix`, there is no match when the search begins after the
        start of the prompt (`start_idx != 0`). The resolved index therefore
        never precedes the position the caller is searching from.
        """

        def get_match_index(
            tokenizer: TokenizerLike | None,
            prompt: PromptSeq,
            start_idx: int = 0,
        ) -> int | None:
            # Index 0 lies before the region being searched when start_idx > 0.
            # Reporting it anyway would make callers that resume from
            # `start_idx` (e.g. `_apply_matches`) re-emit the already-consumed
            # part of the prompt.
            return 0 if start_idx == 0 else None

        return PromptIndex(get_match_index)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        The prefix is converted to the same representation as the prompt
        (text or token IDs) before comparison. There is no match when the
        search begins after the start of the prompt (`start_idx != 0`).
        """

        def get_match_index(
            tokenizer: TokenizerLike | None,
            prompt: PromptSeq,
            start_idx: int = 0,
        ) -> int | None:
            if start_idx != 0:
                return None

            prefix = seq

            if isinstance(prompt, str):
                # Make both `str`
                prefix = _seq2text(tokenizer, prefix, use_cache=False)
            else:
                # Make both `list[int]`
                prefix = _seq2tokens(tokenizer, prefix, use_cache=False)

            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt))
188
189


190
# A concrete target: a sequence to search for, or an index resolver.
UpdateTarget: TypeAlias = PromptSeq | PromptIndex
"""
The token sequence or text to update, or a
[`PromptIndex`][vllm.multimodal.processing.PromptIndex] that resolves to the
location in the prompt to update.
"""

195
# Either an `UpdateTarget`, or a function of the item index that returns one.
PromptUpdateTarget: TypeAlias = Callable[[int], UpdateTarget] | UpdateTarget
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).

If the target does not depend on the item, you can pass the token sequence
(or text) directly instead of a function.
"""

205

206
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Callable[[TokenizerLike | None, PromptSeq], torch.Tensor] | None = None
    """
    Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
    return a boolean mask with one entry per token of `full` (after
    tokenization) indicating which positions to assign embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    [`SupportsMultiModal.embed_multimodal`][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap `seq` so that every one of its positions receives embeddings."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq`, assigning embeddings only to positions whose token ID
        occurs in the tokenization of `embed_text`.
        """

        def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
            wanted_ids = _seq2tokens(tokenizer, embed_text, use_cache=False)
            full_ids = _seq2tokens(tokenizer, full)
            return torch.isin(torch.tensor(full_ids), torch.tensor(wanted_ids))

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """Wrap `seq`, assigning embeddings only to positions equal to `embed_token_id`."""

        def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
            full_ids = torch.tensor(_seq2tokens(tokenizer, full))
            return full_ids == embed_token_id

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_ids(
        seq: _S,
        embed_token_ids: list[int],
    ) -> "PromptUpdateDetails[_S]":
        """Wrap `seq`, assigning embeddings only to positions in `embed_token_ids`."""

        def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor:
            full_ids = torch.tensor(_seq2tokens(tokenizer, full))
            return torch.isin(full_ids, torch.tensor(embed_token_ids))

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

272

273
# Plain content (text or token IDs), or `PromptUpdateDetails` wrapping it.
PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails
"""
The token sequence or text that is part of the update.

If only part of the content corresponds to feature placeholders, you can
use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to
specify which part.
"""
281

282
# Either `PromptUpdateInfo`, or a function of the item index that returns one.
PromptUpdateContent: TypeAlias = Callable[[int], PromptUpdateInfo] | PromptUpdateInfo
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).

If the content does not depend on the item, you can pass the token sequence
(or text) directly instead of a function.
"""


class UpdateMode(str, Enum):
    """How a prompt update is applied at its matched target."""

    # Add the content right after the matched target.
    INSERT = "insert"
    # Substitute the content for the matched target.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptUpdateTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def _resolve_target(self, item_idx: int) -> UpdateTarget:
        """Evaluate `target` for the given item when it is a callable."""
        target = self.target
        return target(item_idx) if callable(target) else target

    def _resolve_content(self, item_idx: int) -> PromptUpdateDetails:
        """
        Evaluate `content` for the given item when it is a callable, then wrap
        a plain sequence in `PromptUpdateDetails`.
        """
        content = self.content
        resolved = content(item_idx) if callable(content) else content

        if isinstance(resolved, PromptUpdateDetails):
            return resolved

        return PromptUpdateDetails.from_seq(resolved)

    def resolve(self, item_idx: int) -> "ResolvedPromptUpdate":
        """
        Given the index of the processed item within
        [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
        output a copy of this object with its lazy attributes resolved.
        """
        return ResolvedPromptUpdate(
            modality=self.modality,
            item_idx=item_idx,
            mode=self.mode,
            target=self._resolve_target(item_idx),
            content=self._resolve_content(item_idx),
        )

353

354
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

    For each image, insert a number of `<image>` feature placeholders
    equal to the feature size of the vision encoder after the `<s>` token:

    ```python
    PromptInsertion(
        modality="image",
        target="<s>",
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens at the start of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.start(),
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens after a prefix `Images:`:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.prefix("Images:"),
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens at the end of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.end(),
        insertion="<image>" * image_feature_size,
    )
    ```
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to insert right after
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    If the insertion does not depend on the item, you can pass the token
    sequence (or text) directly instead of a function.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to insert (same as `insertion`)."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

    For each image, replace one `<image>` input placeholder in the prompt
    with a number of `<image>` feature placeholders
    equal to the feature size of the vision encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement="<image>" * image_feature_size,
    )
    ```

    As above, but further pad the feature placeholders with `<image_bos>`
    and `<image_eos>`, which are not supposed to be passed to the vision
    encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement=PromptUpdateDetails.select_text(
            "".join(
                [
                    "<image_bos>",
                    "<image>" * image_feature_size,
                    "<image_eos>",
                ]
            ),
            embed_text="<image>",
        ),
    )
    ```

    To avoid unnecessary tokenization during prompt replacement,
    we recommend passing token sequences instead of text:

    ```python
    PromptReplacement(
        modality="image",
        target=[image_token_id],
        replacement=PromptUpdateDetails.select_token_id(
            [image_bos_id] + [image_token_id] * image_feature_size + [image_eos_id],
            embed_token_id=image_token_id,
        ),
    )
    ```
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to replace
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The replacement content (same as `replacement`)."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
498
499


500
501
502
class _HasModalityAttr(Protocol):
    modality: str

503

504
505
class _HasModalityProp(Protocol):
    @property
506
    def modality(self) -> str: ...
507
508


509
# Anything carrying a `modality`, stored either as an attribute or a property.
_M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp)
510
511
512


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group `values` by their `modality`.

    This is [`full_groupby`][vllm.utils.collection_utils.full_groupby]
    keyed on each value's `modality`.
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


521
522
523
524
525
526
527
class PromptTargetMatch(NamedTuple):
    """Span `[start_idx, end_idx)` of an update target found in a prompt."""

    start_idx: int
    end_idx: int


@dataclass(frozen=True)
class ResolvedPromptUpdate:
    """
    A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] whose lazy
    attributes have been resolved, except for those related to tokenization.
    """

    modality: str
    """The modality for which the update is made."""

    item_idx: int
    """The index within `modality` of the item this update pertains to."""

    mode: UpdateMode
    """Defines how to update the prompt."""

    target: UpdateTarget
    """The token sequence (or text) to update."""

    content: PromptUpdateDetails = field(repr=False)
    """The placeholder tokens that are part of the update."""

    def iter_token_matches(
        self,
        prompt: list[int],
        tokenizer: TokenizerLike | None,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield each instance of `self.target` found in `prompt`."""
        target = self.target

        if not isinstance(target, PromptIndex):
            target_ids = _seq2tokens(tokenizer, target)
            # Resolves to the module-level `iter_token_matches` function.
            found_iter = iter_token_matches(prompt, target_ids, start_idx=start_idx)
            for found in found_iter:
                yield PromptTargetMatch(found.start_idx, found.end_idx)
            return

        # Index targets produce at most one (empty) match.
        idx = target.get_match_index(tokenizer, prompt, start_idx)
        if idx is not None:
            yield PromptTargetMatch(idx, idx)

    def iter_text_matches(
        self,
        prompt: str,
        tokenizer: TokenizerLike | None,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield each instance of `self.target` found in `prompt`."""
        target = self.target

        if not isinstance(target, PromptIndex):
            pattern = re.escape(_seq2text(tokenizer, target))
            for found in re.finditer(pattern, prompt, pos=start_idx):
                yield PromptTargetMatch(found.start(), found.end())
            return

        # Index targets produce at most one (empty) match.
        idx = target.get_match_index(tokenizer, prompt, start_idx)
        if idx is not None:
            yield PromptTargetMatch(idx, idx)

    def iter_matches(
        self,
        prompt: list[int] | str,
        tokenizer: TokenizerLike | None,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield each instance of `self.target` found in `prompt`."""
        if not isinstance(prompt, str):
            return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)

        return self.iter_text_matches(prompt, tokenizer, start_idx=start_idx)

    def with_target(self, target: UpdateTarget) -> "ResolvedPromptUpdate":
        """Return a copy of this update with `target` swapped out."""
        return replace(self, target=target)

    def with_content(self, content: PromptUpdateInfo) -> "ResolvedPromptUpdate":
        """Return a copy of this update with `content` swapped out.

        A plain sequence is wrapped in `PromptUpdateDetails` first.
        """
        details = (
            content
            if isinstance(content, PromptUpdateDetails)
            else PromptUpdateDetails.from_seq(content)
        )
        return replace(self, content=details)

614

615
616
617
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
618
619


620
621
622
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    *,
    start_idx: int = 0,
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of `match_ids` in `token_ids`, scanning from
    `start_idx`.

    Matches do not overlap: after a match, scanning resumes at its end.
    Note that empty matches are ignored.
    """
    n_match = len(match_ids)
    if n_match == 0:
        return

    # Last position at which a complete match could still begin.
    last_start = len(token_ids) - n_match

    pos = start_idx
    while pos <= last_start:
        stop = pos + n_match

        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)

        # Exclude overlapping matches
        pos = stop
647
648


649
650
651
652
653
654
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Return a copy of `token_ids` in which each non-overlapping occurrence
    of `match_ids` is replaced by `new_ids`.

    Note that empty matches are ignored.
    """
    pieces = list[list[int]]()
    cursor = 0

    for start_idx, end_idx in iter_token_matches(token_ids, match_ids):
        # Keep the untouched stretch before the match, then the replacement.
        pieces += [token_ids[cursor:start_idx], new_ids]
        cursor = end_idx

    pieces.append(token_ids[cursor:])

    return flatten_2d_lists(pieces)


676
@dataclass
class PlaceholderFeaturesInfo:
    """A run of placeholder tokens found in a prompt for one multi-modal item."""

    # The item's modality and its index within that modality.
    modality: str
    item_idx: int
    # Position in the prompt at which the placeholder tokens begin.
    start_idx: int
    # The placeholder token IDs.
    tokens: list[int]
    # Optional mask over `tokens` (from `PromptUpdateDetails.is_embed`);
    # `None` when the update did not specify one.
    is_embed: torch.Tensor | None

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Describe this run as a `PlaceholderRange`."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        return PlaceholderRange(
            offset=self.start_idx,
            length=len(self.tokens),
            is_embed=self.is_embed,
        )
696
697


698
# `((modality, item_idx), (match, update_idx))`: a target match found for an
# item, together with the index of the item's update that produced it.
_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
699
700


701
702
703
def _find_matches(
    prompt: _S,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
    *,
    prev_end_idx: int = 0,
    current_result: "MultiModalPromptUpdatesApplyResult",
) -> tuple[UpdateMode | None, list[_MatchToApply]]:
    """
    Find the next round of matches to apply to `prompt`.

    Items that already have an update applied (per `current_result`) are
    skipped. For every other item, its candidate updates are tried in
    order, searching from `prev_end_idx`. The first compatible match is
    kept. The mode of the first match found fixes the mode of the whole
    round, and matches of any other mode are skipped.

    Returns:
        The mode of this round (`None` if nothing matched), and the matches
        to apply, sorted by position. In `REPLACE` mode, at most one
        non-empty match is returned to avoid conflicting replacements.
    """
    mode: UpdateMode | None = None
    first_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]()

    for modality, modality_updates in mm_prompt_updates.items():
        for item_idx, item_updates in enumerate(modality_updates):
            if current_result[modality][item_idx] is not None:
                continue  # Updates have already been applied for this item

            key = (modality, item_idx)
            for update_idx, update in enumerate(item_updates):
                candidates = update.iter_matches(
                    prompt,
                    tokenizer,
                    start_idx=prev_end_idx,
                )
                for match in candidates:
                    # All matches in a round share the mode of the first one
                    if mode is None:
                        mode = update.mode
                    elif update.mode != mode:
                        continue

                    first_matches[key] = (match, update_idx)
                    break  # Keep only the first valid match of this update

                if key in first_matches:
                    break  # Keep only the first matching update of this item

    # Prioritize earlier matches
    ordered = sorted(first_matches.items(), key=lambda kv: kv[1][0])

    if mode != UpdateMode.REPLACE:
        return mode, ordered

    # To avoid conflicts, only replace one non-empty item at a time;
    # empty matches (e.g. from index targets) are always kept.
    kept = list[_MatchToApply]()
    seen_non_empty = False
    for entry in ordered:
        _, (match, _) = entry
        is_empty = match.start_idx == match.end_idx
        if is_empty or not seen_non_empty:
            kept.append(entry)
            seen_non_empty = seen_non_empty or not is_empty

    return mode, kept
754
755


756
757
758
759
760
761
762
763
764
765
def _all_items_found(
    mm_item_counts: dict[str, int],
    mm_found_counts: dict[str, int],
) -> bool:
    return all(
        item_idx >= mm_item_counts[modality]
        for modality, item_idx in mm_found_counts.items()
    )


766
def _apply_matches(
767
    prompt: _S,
768
    mm_prompt_updates: "MultiModalPromptUpdates",
769
    tokenizer: TokenizerLike | None,
770
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
771
    mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
772

773
    out_seqs = list[str | list[int]]()
774
    out_result: MultiModalPromptUpdatesApplyResult = {
775
        m: [None] * len(items) for m, items in mm_prompt_updates.items()
776
    }
777

778
    # Early exit if no items to find
779
780
781
782
783
784
    mm_found_counts = {
        m: sum(r is not None for r in res) for m, res in out_result.items()
    }
    if _all_items_found(mm_item_counts, mm_found_counts):
        return [prompt], out_result

785
786
    prev_end_idx = 0
    while True:
787
788
789
790
791
792
793
        mode, matches_to_apply = _find_matches(
            prompt,
            mm_prompt_updates,
            tokenizer,
            prev_end_idx=prev_end_idx,
            current_result=out_result,
        )
794

795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
        if mode is None:
            break  # No more matches to find

        for (modality, item_idx), (match, update_idx) in matches_to_apply:
            matched_update = mm_prompt_updates[modality][item_idx][update_idx]
            matched_content = matched_update.content.full

            if mode == UpdateMode.INSERT:
                end_idx_to_insert = match.end_idx
            elif mode == UpdateMode.REPLACE:
                end_idx_to_insert = match.start_idx
            else:
                assert_never(mode)

            out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
            out_seqs.append(
                _seq2text(tokenizer, matched_content)
                if isinstance(prompt, str)
                else _seq2tokens(tokenizer, matched_content)
            )
            out_result[modality][item_idx] = update_idx

            # Exclude overlapping matches
            prev_end_idx = match.end_idx

        # Early exit if all items found
        mm_found_counts = {
            m: sum(r is not None for r in res) for m, res in out_result.items()
        }
        if _all_items_found(mm_item_counts, mm_found_counts):
            break
826
827
828

    out_seqs.append(prompt[prev_end_idx:])

829
    return cast(list[_S], out_seqs), out_result
830
831


832
def apply_token_matches(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to the token IDs in `prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Returns:
        The updated token IDs. Also, for each item, the index of the update
        that was applied (`None` where none was applied).
    """
    pieces, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    new_prompt = flatten_2d_lists(pieces)
    return new_prompt, result
847
848


849
def apply_text_matches(
    prompt: str,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to the text `prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Returns:
        The updated text. Also, for each item, the index of the update that
        was applied (`None` where none was applied).
    """
    pieces, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    new_prompt = "".join(pieces)
    return new_prompt, result
864
865


866
def _iter_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in `prompt`.

    The prompt is scanned from left to right. At each position, the next
    not-yet-found item of each modality is tried, in the order of
    `mm_prompt_updates`. For that item, each of its updates is tried in
    order. The first update whose full content occurs at the position is
    reported, and scanning resumes right after it.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
    next_item_idx = {modality: 0 for modality in mm_prompt_updates}

    if _all_items_found(item_counts, next_item_idx):
        return

    prompt_len = len(prompt)

    def match_item_at(
        pos: int,
        modality: str,
        item_idx: int,
        item_updates: Sequence[ResolvedPromptUpdate],
    ) -> PlaceholderFeaturesInfo | None:
        # Return info for the first of `item_updates` whose full content
        # occurs at `pos`, or `None` if none of them does.
        for update in item_updates:
            content = update.content
            full_tokens = _seq2tokens(tokenizer, content.full)
            n_tokens = len(full_tokens)
            end = pos + n_tokens

            if n_tokens == 0 or end > prompt_len:
                continue
            if prompt[pos:end] != full_tokens:
                continue

            is_embed = content.is_embed
            if is_embed is not None:
                is_embed = is_embed(tokenizer, content.full)

            return PlaceholderFeaturesInfo(
                modality=modality,
                item_idx=item_idx,
                start_idx=pos,
                tokens=full_tokens,
                is_embed=is_embed,
            )

        return None

    pos = 0
    while pos < prompt_len:
        hit: PlaceholderFeaturesInfo | None = None

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = next_item_idx[modality]
            if item_idx >= item_counts.get(modality, 0):
                continue  # Every item of this modality has been found

            hit = match_item_at(pos, modality, item_idx, modality_updates[item_idx])
            if hit is not None:
                break

        if hit is None:
            pos += 1
            continue

        yield hit

        # Exclude overlapping matches
        pos += len(hit.tokens)
        next_item_idx[hit.modality] += 1

        if _all_items_found(item_counts, next_item_idx):
            return
933
934


935
936
def find_mm_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the placeholder tokens of each multi-modal item in `prompt`,
    grouped by modality.
    """
    placeholders = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


944
945
946
# Per-modality list of flags, one per item.
MultiModalIsCached = dict[str, list[bool]]
"""
The `is_cached` flag of each item, laid out like
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""

950
# modality -> per-item candidate updates (tried in order).
MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]]
"""
The prompt updates of each item, laid out like
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""

956
MultiModalPromptUpdatesApplyResult = Mapping[str, list[int | None]]
957
958
959
960
961
962
963
"""
For an item `MultiModalPromptUpdates[k][i]`,
`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the
`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the
`ResolvedPromptUpdate` instances have been applied.
"""

964
965
# The processing-info type that a multi-modal processor is parametrized with.
_I = TypeVar("_I", bound=BaseProcessingInfo)

966
967

class MultiModalProcessingInfo(NamedTuple):
968
    kwargs: MultiModalKwargsOptionalItems
969
    hashes: MultiModalHashes
970
971
    prompt_updates: MultiModalPromptUpdates

972
973

class BaseMultiModalProcessor(ABC, Generic[_I]):
974
    """
975
    Abstract base class to process multi-modal inputs to be used in vLLM.
976

977
    Not to be confused with `transformers.ProcessorMixin`.
978
979
    """

980
981
982
983
984
    def __init__(
        self,
        info: _I,
        dummy_inputs: "BaseDummyInputsBuilder[_I]",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> None:
        """
        Args:
            info: Model-specific processing information.
            dummy_inputs: Builder used to create dummy inputs, e.g. the dummy
                text that accompanies multi-modal data in
                `_apply_hf_processor_mm_only`.
            cache: Optional processor cache. When `None`, every request goes
                through the uncached `_apply_hf_processor` path.
        """
        super().__init__()

        # Collaborators supplied by the caller.
        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache

        # `_get_data_parser` reads the model config through `self.info`,
        # so it must run after the attributes above are set.
        self.data_parser = self._get_data_parser()

        # Look the limits up once here rather than on every request.
        self._supported_mm_limits = self.info.get_supported_mm_limits()
        self._allowed_mm_limits = self.info.get_allowed_mm_limits()

    @property
    def supported_mm_limits(self):
        return self._supported_mm_limits

    @property
    def allowed_mm_limits(self):
        return self._allowed_mm_limits

1007
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Shorthand for `self.apply(...)` with the same arguments."""
        return self.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            mm_uuids=mm_uuids,
        )
1016

1017
1018
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that normalizes multi-modal data items before they
        are passed to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        Override this and return a subclass of
        [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
        with extra subparsers to support additional modalities.
        """
        model_config = self.info.ctx.model_config

        # When embeddings may be passed in directly (`enable_mm_embeds`), the
        # parser checks their hidden size. Embeddings with the right `ndim`
        # but the wrong shape would otherwise only fail (or crash) at
        # inference time.
        if model_config.get_multimodal_config().enable_mm_embeds:
            hidden_size = model_config.get_inputs_embeds_size()
        else:
            hidden_size = None

        return MultiModalDataParser(expected_hidden_size=hidden_size)
1036

1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
    def validate_num_items(
        self,
        modality: str,
        num_items: int,
    ) -> None:
        """
        Raise `ValueError` if `num_items` exceeds the per-prompt limit for
        `modality`.

        The effective limit is the smaller of the supported and the allowed
        limit. A supported limit of `None` falls back to the allowed limit,
        and a modality missing from either mapping gets a limit of 0.
        """
        supported = self.supported_mm_limits.get(modality, 0)
        allowed = self.allowed_mm_limits.get(modality, 0)
        if supported is None:
            supported = allowed

        effective_limit = min(supported, allowed)
        if num_items <= effective_limit:
            return

        message = (
            f"At most {effective_limit} {modality}(s) may be provided in one prompt."
        )
        # Only suggest the CLI flag if raising the user limit would help.
        if num_items <= supported:
            message += " Set `--limit-mm-per-prompt` to increase this limit."
        raise ValueError(message)

1058
    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Convert a
        [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
        into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
        for
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        Raises `ValueError` if embeddings are passed while `enable_mm_embeds`
        is off, or if any modality exceeds its per-prompt item limit.
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        mm_config = self.info.ctx.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            # Report the first modality that was given as embeddings.
            embeds_modality = next(
                (
                    modality
                    for modality, items in mm_items.items()
                    if isinstance(items, (EmbeddingItems, DictEmbeddingItems))
                ),
                None,
            )
            if embeds_modality is not None:
                raise ValueError(
                    f"You must set `--enable-mm-embeds` to input "
                    f"`{embeds_modality}_embeds`"
                )

        # The limit checks run only after the embeddings check has passed
        # for every modality.
        for modality, items in mm_items.items():
            self.validate_num_items(modality, len(items))

        return mm_items
1084

1085
1086
1087
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe the fields of the HF-processed data.

        Return a mapping from field name to
        [`MultiModalFieldConfig`][vllm.multimodal.inputs.MultiModalFieldConfig].
        `MultiModalKwargsItems.from_hf_inputs` uses it to build the per-item
        keyword arguments from `hf_inputs`.
        """
        raise NotImplementedError

1094
    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Describe the prompt updates to perform for the multi-modal items.

        Args:
            mm_items: The original multi-modal items.
            hf_processor_mm_kwargs: Keyword arguments for the HF processor.
            out_mm_kwargs: The HF-processed keyword arguments of the items.

        The updates are applied to token inputs, which bypass the HF
        processor. They are also applied to the output of the HF processor
        when the HF processor does not apply prompt updates to text inputs
        itself.

        They are also essential for locating the token positions used to
        construct a
        [`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
        for each multi-modal item.
        """
        raise NotImplementedError
1116

1117
1118
1119
1120
1121
1122
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
        mm_item_counts: Mapping[str, int],
    ) -> MultiModalPromptUpdates:
        """
        Resolve every unbound prompt update against each item index of its
        modality and collect the results per modality.

        A modality missing from `mm_item_counts` is treated as having zero
        items, so it maps to an empty list.
        """
        grouped: dict[str, list[Sequence[ResolvedPromptUpdate]]] = {}
        for modality, updates in full_groupby_modality(prompt_updates):
            num_items = mm_item_counts.get(modality, 0)
            grouped[modality] = [
                [update.resolve(item_idx) for update in updates]
                for item_idx in range(num_items)
            ]
        return grouped

    def _get_mm_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> MultiModalPromptUpdates:
        """
        Compute the resolved prompt updates for each multi-modal item.

        The unbound updates from `_get_prompt_updates` are bound to each
        item index and grouped by modality. A deprecation warning is logged
        (once) when an item ends up with more than one prompt update.
        """
        unbound_prompt_updates = self._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates,
            mm_items.get_all_counts(),
        )

        for modality, prompt_updates in mm_prompt_updates.items():
            for item_idx, item_prompt_updates in enumerate(prompt_updates):
                # Count the updates of this single item. Using
                # `len(prompt_updates)` here would report the number of
                # items of the modality instead.
                num_item_updates = len(item_prompt_updates)
                if num_item_updates > 1:
                    logger.warning_once(
                        "Detected %d prompt updates for `mm_items[%r][%s]`. "
                        "Multiple prompt updates per item is now "
                        "deprecated and may be removed in v0.13. "
                        "Instead, please specify dynamic update targets "
                        "in the same prompt update definition by passing "
                        "a function to `PromptUpdate.target`.",
                        num_item_updates,
                        modality,
                        item_idx,
                    )

        return mm_prompt_updates

1164
    def _find_mm_placeholders(
        self,
        new_token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """Locate item placeholders in `new_token_ids` using the model tokenizer."""
        return find_mm_placeholders(
            new_token_ids,
            mm_prompt_updates,
            self.info.get_tokenizer(),
        )
1172

1173
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the items' data into two parts: data to feed to the HF
        processor, and passthrough data that bypasses it and is merged into
        the processor output afterwards.
        """
        processor_data: dict[str, object] = {}
        passthrough_data: dict[str, object] = {}

        for modality_items in mm_items.values():
            processor_data.update(modality_items.get_processor_data())
            passthrough_data.update(modality_items.get_passthrough_data())

        return processor_data, passthrough_data

1186
1187
1188
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on the prompt text together with the given
        multi-modal data, timed under the "hf_processor" label.
        """
        with timed_preprocessor_operation(self.info.ctx, "hf_processor"):
            hf_processor = self.info.get_hf_processor(**mm_kwargs)
            # `dict(...)` (rather than `{...}`) is deliberate: duplicate keys
            # raise `TypeError` instead of being silently overridden.
            processor_inputs = dict(text=prompt, **mm_data)
            processor_kwargs = dict(**mm_kwargs, **tok_kwargs)
            return self.info.ctx.call_hf_processor(
                hf_processor,
                processor_inputs,
                processor_kwargs,
            )
1205

1206
    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Report whether the HF processor applies prompt updates by itself.

        The default returns `True` only if none of the items were given as
        embeddings, since most HF processors apply the updates for raw data
        items but not for precomputed embeddings.
        """
        embedding_types = (EmbeddingItems, DictEmbeddingItems)
        return all(
            not isinstance(items, embedding_types) for items in mm_items.values()
        )
1224

1225
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Run the HF processor on the prompt text and the multi-modal data in
        a single call.

        Returns:
            The prompt token IDs (popped from the output's `input_ids`), the
            remaining processed data with the passthrough data merged in, and
            whether the HF processor applied the prompt updates.
        """
        hf_data, bypass_data = self._get_hf_mm_data(mm_items)

        outputs = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=hf_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        # Data that skipped the HF processor is added back to its output.
        outputs.update(bypass_data)

        # The processor is called on a single prompt, so there is one row.
        [token_ids] = outputs.pop("input_ids").tolist()

        updates_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return token_ids, outputs, updates_applied
1258

1259
    def _apply_hf_processor_text_only(
        self,
        prompt_text: str,
        tokenization_kwargs: Mapping[str, object],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The HF processor is called with an empty set of multi-modal items
        (`MultiModalDataItems({})`) and empty processor kwargs. Only the
        resulting token IDs are returned; the processed data and the
        "updates applied" flag are discarded.
        """
        # No multi-modal data accompanies the text here; the items are empty.
        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
1290
1291
1292
        with the output of
        [`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
        on the
1293
1294
1295
1296
1297
1298
1299
1300
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Apply the HF processor on the multi-modal data only.

        The HF processor needs text that corresponds to the multi-modal items,
        so dummy text built by
        [`DummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
        for the item counts is passed along. Only the processed data is
        returned; the dummy text's token IDs are discarded.
        """
        dummy_text = self.dummy_inputs.get_dummy_text(mm_items.get_all_counts())

        _, processed, _ = self._apply_hf_processor_text_mm(
            prompt_text=dummy_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )
        return processed
1321
1322
1323

    def _apply_hf_processor_main(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Apply the HF processor on the prompt and multi-modal data.

        In addition, return whether prompt updates have been applied.

        Note:
            If `enable_hf_prompt_update=True` and `prompt` is text, the HF
            processor is applied to the text and the multi-modal data
            together, so it can perform the prompt updates itself (if it
            supports them). This requires the prompt to correspond to the
            multi-modal items.

            Otherwise (the flag is `False`, or `prompt` is a list of token
            IDs), the prompt and the multi-modal data are processed
            separately and the returned flag is always `False`.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                    tokenization_kwargs=tokenization_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt, tokenization_kwargs)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        # The prompt was processed without the multi-modal data, so the
        # data is processed on its own and no prompt updates have been applied.
        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, mm_processed_data, False
1362

1363
    def _hash_mm_items(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalHashes:
        """
        Compute one hash per multi-modal item, keyed by modality.

        For a modality present in `mm_uuids` (a single string counts as a
        one-element list):
        - a non-`None` UUID is returned as-is when neither
          `hf_processor_mm_kwargs` nor `tokenization_kwargs` is non-empty;
        - otherwise the UUID (or the item data, if the UUID is `None`) is
          hashed together with the model ID and those kwargs, because the
          kwargs can change the processed outputs.

        Items of modalities without UUIDs are always hashed from their data
        together with the model ID and the kwargs.
        """
        model_id = self.info.model_id
        uuids_by_modality = mm_uuids or {}
        has_extra_kwargs = bool(hf_processor_mm_kwargs or tokenization_kwargs)

        def _hash_one(modality: str, value: object) -> str:
            # Any key clash between these kwargs raises `TypeError`.
            return MultiModalHasher.hash_kwargs(
                model_id=model_id,
                **{modality: value},
                **hf_processor_mm_kwargs,
                **tokenization_kwargs,
            )

        hashes: MultiModalHashes = {}
        for modality, items in mm_items.items():
            if modality not in uuids_by_modality:
                hashes[modality] = [_hash_one(modality, item) for item in items]
                continue

            provided_uuids = uuids_by_modality[modality]
            if isinstance(provided_uuids, str):
                provided_uuids = [provided_uuids]

            modality_hashes: list[str] = []
            for idx, item in enumerate(items.get_all_items_for_hash()):
                item_uuid = provided_uuids[idx]
                if item_uuid is not None and not has_extra_kwargs:
                    modality_hashes.append(item_uuid)
                else:
                    # Hashing the provided UUID (when available) is cheaper
                    # than hashing the raw item data.
                    source = item if item_uuid is None else item_uuid
                    modality_hashes.append(_hash_one(modality, source))
            hashes[modality] = modality_hashes

        return hashes
1428

1429
1430
    def _get_cache_missing_items(
        self,
        cache: BaseMultiModalProcessorCache,
        mm_data_items: MultiModalDataItems,
        mm_hashes: MultiModalHashes,
    ) -> tuple[MultiModalIsCached, MultiModalDataItems]:
        """
        Check which items are already in `cache`.

        Returns the per-modality `is_cached` flags and the data of the items
        that are not cached, re-parsed through `_to_mm_items`.

        Raises:
            ValueError: If an uncached item has no data (`None`).
        """
        mm_is_cached = {
            modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
        }

        uncached_data: dict[str, list[object]] = {}
        for modality, cached_flags in mm_is_cached.items():
            collected: list[object] = []
            for idx, is_cached in enumerate(cached_flags):
                if is_cached:
                    continue
                data = mm_data_items[modality][idx]
                if data is None:
                    raise ValueError(
                        f"Cache miss for {modality} at index {idx} "
                        f"but data is not provided."
                    )
                collected.append(data)
            uncached_data[modality] = collected

        return mm_is_cached, self._to_mm_items(uncached_data)
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """
        Rebind a prompt update fetched from the cache to the item index it
        has in the current request.

        The default returns a copy with only `item_idx` replaced. Override
        this if other attributes of `ResolvedPromptUpdate` must also be
        recomputed.
        """
        rebound = replace(cached_update, item_idx=new_item_idx)
        return rebound

1474
1475
    def _merge_mm_kwargs(
        self,
        cache: BaseMultiModalProcessorCache,
        mm_hashes: MultiModalHashes,
        mm_is_cached: MultiModalIsCached,
        mm_missing_kwargs: MultiModalKwargsItems,
        mm_missing_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
        """
        Combine the freshly processed (cache-miss) items with the cached
        ones, keeping the original item order within each modality.

        Each item goes through `cache.get_and_update_item`: cached items are
        passed as `None`, and missing items as their newly processed
        `(kwargs, prompt_updates)` pair. The returned prompt updates are
        rebound to their current item index.
        """
        # Touch every hash first, so that inserting the new items below
        # cannot evict an entry that this request still needs.
        for modality_hashes in mm_hashes.values():
            for item_hash in modality_hashes:
                cache.touch_sender_cache_item(item_hash)

        # Position of the next unconsumed freshly processed item per modality.
        next_missing_pos = defaultdict[str, int](int)

        kwargs_by_modality = defaultdict[str, list[MultiModalKwargsItem | None]](
            list
        )
        updates_by_modality = defaultdict[
            str, list[Sequence[ResolvedPromptUpdate]]
        ](list)

        for modality, modality_hashes in mm_hashes.items():
            fresh_kwargs = mm_missing_kwargs.get(modality, [])
            fresh_updates = mm_missing_prompt_updates.get(modality, [])

            for item_idx, item_hash in enumerate(modality_hashes):
                if mm_is_cached[modality][item_idx]:
                    new_entry = None
                else:
                    pos = next_missing_pos[modality]
                    new_entry = fresh_kwargs[pos], fresh_updates[pos]
                    next_missing_pos[modality] = pos + 1

                item_kwargs, item_updates = cache.get_and_update_item(
                    new_entry, item_hash
                )

                kwargs_by_modality[modality].append(item_kwargs)
                updates_by_modality[modality].append(
                    [
                        self._recompute_cached_prompt_update(update, item_idx)
                        for update in item_updates
                    ]
                )

        return MultiModalKwargsItems(kwargs_by_modality), dict(updates_by_modality)
1524
1525
1526

    def _apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Process the prompt and all multi-modal items without the cache.

        The HF processor may apply the prompt updates itself
        (`enable_hf_prompt_update=True`). The items are then hashed (using
        `mm_uuids` where provided) and their prompt updates are computed.

        Returns the prompt token IDs, the processing info, and whether
        prompt updates have already been applied.
        """
        prompt_ids, processed_data, updates_applied = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            enable_hf_prompt_update=True,
        )

        fields_config = self._get_mm_fields_config(
            processed_data, hf_processor_mm_kwargs
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(processed_data, fields_config)

        # Provided UUIDs take precedence; otherwise hash the item data.
        with timed_preprocessor_operation(self.info.ctx, "hashing"):
            mm_hashes = self._hash_mm_items(
                mm_data_items,
                hf_processor_mm_kwargs,
                tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        mm_prompt_updates = self._get_mm_prompt_updates(
            mm_data_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )

        processing_info = MultiModalProcessingInfo(
            mm_kwargs, mm_hashes, mm_prompt_updates
        )
        return prompt_ids, processing_info, updates_applied
1573

1574
1575
    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Like `_apply_hf_processor`, but only the items missing from the
        processor cache go through the HF processor. Results for the other
        items come from the cache, and new results are stored in it.

        Falls back to `_apply_hf_processor` when there is no cache or when
        some of the data bypasses the HF processor (passthrough data).
        """
        cache = self.cache

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor(
                prompt=prompt,
                mm_data_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        with timed_preprocessor_operation(self.info.ctx, "hashing"):
            mm_hashes = self._hash_mm_items(
                mm_data_items,
                hf_processor_mm_kwargs,
                tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
            mm_is_cached, uncached_items = self._get_cache_missing_items(
                cache=cache,
                mm_data_items=mm_data_items,
                mm_hashes=mm_hashes,
            )

        # `prompt` describes all items, not just `uncached_items`, so the HF
        # processor must not apply prompt updates here. They are applied once
        # the new items have been merged with the cached ones.
        prompt_ids, uncached_processed, updates_applied = (
            self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=uncached_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                enable_hf_prompt_update=False,
            )
        )

        uncached_fields_config = self._get_mm_fields_config(
            uncached_processed, hf_processor_mm_kwargs
        )
        uncached_kwargs = MultiModalKwargsItems.from_hf_inputs(
            uncached_processed, uncached_fields_config
        )

        uncached_prompt_updates = self._get_mm_prompt_updates(
            uncached_items,
            hf_processor_mm_kwargs,
            uncached_kwargs,
        )

        with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
            mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
                cache,
                mm_hashes=mm_hashes,
                mm_is_cached=mm_is_cached,
                mm_missing_kwargs=uncached_kwargs,
                mm_missing_prompt_updates=uncached_prompt_updates,
            )

        processing_info = MultiModalProcessingInfo(
            mm_kwargs, mm_hashes, mm_prompt_updates
        )
        return prompt_ids, processing_info, updates_applied
1658

1659
1660
1661
    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
        """
        Apply the prompt updates by matching on the token IDs of `prompt`.

        Returns the updated token IDs and, for each item, the index of the
        update that was applied (or `None`).
        """
        return apply_token_matches(
            prompt,
            mm_prompt_updates,
            self.info.get_tokenizer(),
        )
1666
1667
1668
1669

    def _apply_text_matches(
        self,
        prompt: str,
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[str, MultiModalPromptUpdatesApplyResult]:
        """
        Apply the prompt updates by matching on the text of `prompt`.

        Returns the updated text and, for each item, the index of the update
        that was applied (or `None`).
        """
        return apply_text_matches(
            prompt,
            mm_prompt_updates,
            self.info.get_tokenizer(),
        )
1674

1675
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to `token_ids` and locate the resulting
        placeholders.

        Token-level matching is tried first. If any item is left without an
        applied update, the whole prompt is decoded, updated at the text
        level and re-encoded.

        Raises:
            AssertionError: If an item still has no applied update.
        """
        tokenizer = self.info.get_tokenizer()

        new_token_ids, match_result = self._apply_token_matches(
            token_ids,
            mm_prompt_updates,
        )

        token_match_incomplete = any(
            update_idx is None
            for update_idxs in match_result.values()
            for update_idx in update_idxs
        )
        if token_match_incomplete:
            # A search text that is not a special token may be tokenized
            # differently inside the prompt, because tokens can cross its
            # boundaries. For example, when searching for "foo" in "food",
            # if "food" is a single token, the token ID of "foo" never
            # appears. Rather than searching for every possible
            # tokenization, update the decoded text and encode it again.
            prompt_text = _seq2text(tokenizer, token_ids, use_cache=False)
            new_text, match_result = self._apply_text_matches(
                prompt_text,
                mm_prompt_updates,
            )
            new_token_ids = _seq2tokens(tokenizer, new_text, use_cache=False)

        # Keep only the update actually applied to each item. A modality
        # gets an entry only if it has at least one item.
        applied_updates: dict[str, list[Sequence[ResolvedPromptUpdate]]] = {}
        for modality, update_idxs in match_result.items():
            for item_idx, update_idx in enumerate(update_idxs):
                assert update_idx is not None, (
                    "Failed to apply prompt replacement for "
                    f"mm_items[{modality!r}][{item_idx}]"
                )
                chosen = mm_prompt_updates[modality][item_idx][update_idx]
                applied_updates.setdefault(modality, []).append([chosen])

        placeholders = self._find_mm_placeholders(new_token_ids, applied_updates)

        return new_token_ids, placeholders
1726

1727
1728
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargsOptionalItems,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that each modality has as many keyword-argument items as it
        has data items.

        Raises:
            RuntimeError: If the counts differ, in either direction.
        """
        for modality, item_count in mm_item_counts.items():
            num_kwargs_items = len(mm_kwargs.get(modality, []))
            if num_kwargs_items == item_count:
                continue

            # The mismatch can be too many items as well as too few, so the
            # message must not say "only found".
            raise RuntimeError(
                f"Expected there to be {item_count} {modality} items in "
                f"keyword arguments corresponding to {item_count} "
                f"{modality} data items, but found {num_kwargs_items}! "
                "There is likely a problem with your "
                "implementation of merged multi-modal processor for this "
                "model (usually arising from an inconsistency between "
                "`_call_hf_processor` and `_get_mm_fields_config`)."
            )
1745

1746
    def _validate_mm_updates(
1747
        self,
1748
        mm_updates: MultiModalPromptUpdates,
1749
        mm_item_counts: Mapping[str, int],
1750
    ) -> None:
1751
        for modality, item_count in mm_item_counts.items():
1752
            placeholders = mm_updates.get(modality, [])
1753

1754
            if len(placeholders) != item_count:
1755
                raise RuntimeError(
1756
                    f"Expected there to be {item_count} prompt updates "
1757
                    f"corresponding to {item_count} {modality} items, but "
1758
                    f"instead found {len(placeholders)} prompt updates! "
1759
1760
1761
                    "This is likely because you forgot to include input "
                    "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
                    "in the prompt. If the model has a chat template, make "
1762
1763
                    "sure you have applied it before calling `LLM.generate`."
                )
1764

1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt placeholders "
                    f"corresponding to {item_count} {modality} items, but "
                    f"instead found {len(placeholders)} prompt placeholders! "
                    "Make sure the implementation of `_call_hf_processor` and "
1779
1780
                    "`_get_mm_fields_config` are consistent with each other."
                )
1781

1782
1783
1784
1785
    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
1786
        mm_kwargs: MultiModalKwargsOptionalItems,
1787
        mm_prompt_updates: MultiModalPromptUpdates,
1788
        is_update_applied: bool,
1789
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
1790
        mm_item_counts = mm_items.get_all_counts()
1791
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
1792
        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
1793

1794
        if is_update_applied:
1795
1796
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
1797
                mm_prompt_updates,
1798
            )
1799
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
1800
        else:
1801
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
1802
                prompt_ids,
1803
                mm_prompt_updates,
1804
            )
1805
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
1806

1807
        return prompt_ids, mm_placeholders
1808
1809
1810

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Turn a prompt and its raw multi-modal data into vLLM model inputs.

        Processing happens in three stages:

        1. The HF processor is run on the prompt and the multi-modal data
           together, producing token IDs and processed tensors.
        2. Sequences in the token IDs are found and replaced with placeholder
           tokens, whose count equals the feature size that the multi-modal
           encoder outputs for each item.
        3. The positions of those placeholder tokens are extracted from the
           processed token IDs.
        """
        # Timing stats are only registered when a request ID is available.
        current_request = get_current_request_id()
        if current_request is not None:
            self.info.ctx.create_timing_stats(current_request)

        mm_items = self._to_mm_items(mm_data)
        tok_kwargs = {} if tokenization_kwargs is None else tokenization_kwargs

        prompt_ids, mm_info, is_update_applied = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs=tok_kwargs,
            mm_uuids=mm_uuids,
        )

        with timed_preprocessor_operation(self.info.ctx, "prompt_update"):
            prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
                mm_items=mm_items,
                prompt_ids=prompt_ids,
                mm_kwargs=mm_info.kwargs,
                mm_prompt_updates=mm_info.prompt_updates,
                is_update_applied=is_update_applied,
            )

        # Convert each placeholder to the range form MultiModalInputs expects.
        placeholder_ranges: dict[str, list[PlaceholderRange]] = {}
        for modality, found in mm_placeholders.items():
            placeholder_ranges[modality] = [item.to_range() for item in found]

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_info.kwargs,
            mm_hashes=mm_info.hashes,
            mm_placeholders=placeholder_ranges,
        )
1874
1875
1876
1877
1878
1879


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    The HF processor and prompt updates are applied to an encoder prompt
    derived from the user prompt. The decoder gets its own prompt, which by
    default is the user prompt unchanged.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    def create_decoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        """
        Create input prompt for the decoder.

        The default implementation returns ``prompt`` unchanged; subclasses
        may override it.
        """
        return prompt

    def _get_enc_dec_inputs(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        encoder_inputs: MultiModalInputs,
    ) -> MultiModalEncDecInputs:
        """
        Combine processed encoder inputs with the decoder prompt.

        Args:
            prompt: The original (user) prompt.
            mm_data: The raw multi-modal data.
            encoder_inputs: Result of the base ``apply`` on the encoder prompt.

        Returns:
            Encoder-decoder inputs. ``encoder_prompt_token_ids`` holds the
            processed encoder token IDs, and ``prompt_token_ids`` holds the
            decoder token IDs.
        """
        decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
        if isinstance(decoder_prompt_raw, str):
            # Only text prompts need the tokenizer. Token-ID prompts are used
            # as-is, so no tokenizer is requested for them.
            tokenizer = self.info.get_tokenizer()
            decoder_prompt_ids = tokenizer.encode(
                decoder_prompt_raw, add_special_tokens=False
            )
        else:
            decoder_prompt_ids = decoder_prompt_raw

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs,
        )
        # `encoder_inputs` also carries `prompt_token_ids` (the encoder's);
        # overwrite that entry with the decoder token IDs.
        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
        return mm_inputs

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.
        The main processing steps are modified to fit encoder-decoder model:
        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Copy the input prompt text as decoder prompt inputs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        return self._get_enc_dec_inputs(
            prompt=prompt,
            mm_data=mm_data,
            encoder_inputs=encoder_inputs,
        )