processing.py 72.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import time
4
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
7
from dataclasses import dataclass, field, replace
8
from enum import Enum
9
from functools import lru_cache
10
11
12
13
14
15
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    NamedTuple,
    Protocol,
16
    TypeAlias,
17
18
19
    cast,
    overload,
)
20

21
import regex as re
22
import torch
23
from typing_extensions import TypeVar, assert_never
24

25
from vllm.logger import init_logger
26
from vllm.tokenizers import TokenizerLike
27
from vllm.transformers_utils.processor import cached_processor_from_config
28
29
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
30
from vllm.utils.jsontree import JSONTree, json_map_leaves
31

32
from .hasher import MultiModalHasher
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from .inputs import (
    MultiModalDataDict,
    MultiModalEncDecInputs,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalKwargsOptionalItems,
    MultiModalUUIDDict,
    PlaceholderRange,
)
from .parse import (
    DictEmbeddingItems,
    EmbeddingItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
50
51

if TYPE_CHECKING:
52
53
54
55
    from transformers.configuration_utils import PretrainedConfig
    from transformers.feature_extraction_utils import BatchFeature
    from transformers.processing_utils import ProcessorMixin

56
    from vllm.config import ModelConfig
57

58
    from .cache import BaseMultiModalProcessorCache
59
    from .profiling import BaseDummyInputsBuilder
60
61
62
63
64
65
66
67
else:
    PretrainedConfig = object
    BatchFeature = object
    ProcessorMixin = object

    ModelConfig = object

    BaseMultiModalProcessorCache = object
68

69
logger = init_logger(__name__)

# Either text or token IDs. As a constrained TypeVar, it ties together
# arguments and return values that must be of the same kind.
_S = TypeVar("_S", str, list[int])

PromptSeq: TypeAlias = str | list[int]
"""A token sequence (list of token IDs) or text."""
75

76

77
78
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: TokenizerLike,
    text: str,
    *,
    add_special_tokens: bool = True,
) -> list[int]:
    """
    Tokenize `text` with `tokenizer`, memoizing the result.

    The cache key is `(tokenizer, text, add_special_tokens)`, so `tokenizer`
    must be hashable. Because `lru_cache` hands back the same list object
    for repeated keys, callers must treat the returned list as read-only.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
    return token_ids
85
86
87
88


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: TokenizerLike,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Detokenize `token_ids` with `tokenizer`, memoizing the result.

    The IDs are passed as a tuple so that they are hashable and can form
    part of the cache key; they are converted back to a list for the
    tokenizer call.
    """
    id_list = list(token_ids)
    return tokenizer.decode(id_list, skip_special_tokens=skip_special_tokens)
95
96


97
98
99
100
101
102
def _seq2text(
    tokenizer: TokenizerLike | None,
    seq: PromptSeq,
    *,
    use_cache: bool = True,
) -> str:
    """
    Return `seq` as text, decoding it with `tokenizer` if it is token IDs.

    Args:
        tokenizer: Used only when `seq` is a list of token IDs.
        seq: Text (returned unchanged) or token IDs.
        use_cache: Whether to go through the memoized decoder.

    Raises:
        ValueError: If `seq` is token IDs but no tokenizer is available.
    """
    if isinstance(seq, str):
        return seq

    if tokenizer is None:
        raise ValueError("You cannot decode tokens when `skip_tokenizer_init=True`")

    return (
        _cached_decode(tokenizer, tuple(seq))
        if use_cache
        else tokenizer.decode(seq)
    )


115
116
117
118
119
120
def _seq2tokens(
    tokenizer: TokenizerLike | None,
    seq: PromptSeq,
    *,
    use_cache: bool = True,
) -> list[int]:
    """
    Return `seq` as token IDs, encoding it (without special tokens) if it
    is text.

    Args:
        tokenizer: Used only when `seq` is text.
        seq: Token IDs (returned as-is, not copied) or text.
        use_cache: Whether to go through the memoized encoder.

    Raises:
        ValueError: If `seq` is text but no tokenizer is available.
    """
    if not isinstance(seq, str):
        return seq

    if tokenizer is None:
        raise ValueError("You cannot encode text when `skip_tokenizer_init=True`")

    if use_cache:
        return _cached_encode(tokenizer, seq, add_special_tokens=False)

    return tokenizer.encode(seq, add_special_tokens=False)


133
134
135
class _GetMatchIndex(Protocol):
    """
    Signature of [`PromptIndex.get_match_index`][vllm.multimodal.processing.PromptIndex].

    Called with the tokenizer, the prompt and the position to search from;
    returns the resolved index in the prompt, or `None` if it does not
    resolve.
    """

    def __call__(
        self,
        tokenizer: TokenizerLike | None,
        prompt: PromptSeq,
        start_idx: int = 0,
    ) -> int | None: ...
140
141


142
143
144
@dataclass
class PromptIndex:
    """
    A position in the prompt that is computed from the prompt itself.

    Unlike a token sequence (or text) target, which is located by searching
    the prompt, this target resolves directly to an index by calling
    `get_match_index`.
    """

    get_match_index: _GetMatchIndex
147
148
149
150
151
152
153
154
155
156


class PromptIndexTargets:
    """Factories for common [`PromptIndex`][vllm.multimodal.processing.PromptIndex] targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty, but only when
        searching from the beginning of the prompt (`start_idx == 0`).
        Once the search has advanced past index 0, the start of the prompt
        is behind the search position and no longer matches.
        """

        def get_match_index(
            tokenizer: TokenizerLike | None,
            prompt: PromptSeq,
            start_idx: int = 0,
        ) -> int | None:
            # Returning 0 for a later `start_idx` would report a match that
            # lies before the search position (cf. `prefix` below, which
            # likewise only matches when `start_idx == 0`).
            return 0 if start_idx == 0 else None

        return PromptIndex(get_match_index)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        Only matches when searching from the beginning of the prompt
        (`start_idx == 0`) and the prompt actually starts with `seq`.
        """

        def get_match_index(
            tokenizer: TokenizerLike | None,
            prompt: PromptSeq,
            start_idx: int = 0,
        ) -> int | None:
            if start_idx != 0:
                return None

            prefix = seq

            if isinstance(prompt, str):
                # Make both `str`
                prefix = _seq2text(tokenizer, prefix, use_cache=False)
            else:
                # Make both `list[int]`
                prefix = _seq2tokens(tokenizer, prefix, use_cache=False)

            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt))
195
196


197
UpdateTarget: TypeAlias = PromptSeq | PromptIndex
"""
The token sequence or text to update, or a
[`PromptIndex`][vllm.multimodal.processing.PromptIndex] that resolves to a
position in the prompt.
"""

PromptUpdateTarget: TypeAlias = Callable[[int], UpdateTarget] | UpdateTarget
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text), or a prompt index.

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""

212

213
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """
    The content placed into the prompt by an update, together with an
    optional rule selecting which of its tokens receive embeddings.
    """

    full: _S
    """The complete token sequence (or text) placed by the update."""

    is_embed: Callable[[TokenizerLike | None, PromptSeq], torch.Tensor] | None = None
    """
    Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
    return a boolean mask over the tokens of `full` (text is tokenized
    first) marking the positions that are assigned embeddings.

    `None` (default) means every position of `full` is assigned an embedding.

    The embeddings are obtained by calling
    [`SupportsMultiModal.embed_multimodal`][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap `seq` so that all of its positions receive embeddings."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq`, assigning embeddings only to the tokens of `seq` whose
        IDs occur in the tokenization of `embed_text`.
        """

        def _mask_by_text(
            tokenizer: TokenizerLike | None,
            full: PromptSeq,
        ) -> torch.Tensor:
            wanted_ids = _seq2tokens(tokenizer, embed_text, use_cache=False)
            full_ids = _seq2tokens(tokenizer, full)
            return torch.isin(torch.tensor(full_ids), torch.tensor(wanted_ids))

        return PromptUpdateDetails(full=seq, is_embed=_mask_by_text)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq`, assigning embeddings only to the tokens of `seq` equal
        to `embed_token_id`.
        """

        def _mask_by_token_id(
            tokenizer: TokenizerLike | None,
            full: PromptSeq,
        ) -> torch.Tensor:
            full_ids = torch.tensor(_seq2tokens(tokenizer, full))
            return full_ids == embed_token_id

        return PromptUpdateDetails(full=seq, is_embed=_mask_by_token_id)
263
264


265
PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails
"""
The token sequence or text placed by an update.

When only part of that content corresponds to feature placeholders, wrap it
in [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails]
to specify which part.
"""

PromptUpdateContent: TypeAlias = Callable[[int], PromptUpdateInfo] | PromptUpdateInfo
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding content (token sequence, text, or
[`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails]).

For convenience, you can directly pass in the content instead of a
function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """How a prompt update is applied relative to its matched target."""

    # Place the content right after the matched target.
    INSERT = "insert"
    # Substitute the content for the matched target.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Base description of how placeholder tokens are placed into a prompt.

    Subclasses supply the content and the
    [`UpdateMode`][vllm.multimodal.processing.UpdateMode].
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptUpdateTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def _resolve_target(self, item_idx: int) -> UpdateTarget:
        # A callable target is evaluated per item; anything else is used as-is.
        raw_target = self.target
        return raw_target(item_idx) if callable(raw_target) else raw_target

    def _resolve_content(self, item_idx: int) -> PromptUpdateDetails:
        # Evaluate the per-item content, then normalize it to
        # `PromptUpdateDetails` so downstream code sees a single type.
        raw_content = self.content
        resolved = raw_content(item_idx) if callable(raw_content) else raw_content

        if isinstance(resolved, PromptUpdateDetails):
            return resolved

        return PromptUpdateDetails.from_seq(resolved)

    def resolve(self, item_idx: int) -> "ResolvedPromptUpdate":
        """
        Produce a copy of this update specialized to the item at `item_idx`
        within [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
        with its lazy attributes evaluated.
        """
        update_mode = self.mode
        resolved_target = self._resolve_target(item_idx)
        resolved_content = self._resolve_content(item_idx)

        return ResolvedPromptUpdate(
            modality=self.modality,
            item_idx=item_idx,
            mode=update_mode,
            target=resolved_target,
            content=resolved_content,
        )

345

346
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Describes placeholder tokens to insert into a prompt right after
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    Example:

    For each image, insert as many `<image>` feature placeholders as the
    vision encoder produces features, right after the `<s>` token:

    ```python
    PromptInsertion(
        modality="image",
        target="<s>",
        insertion="<image>" * image_feature_size,
    )
    ```

    The same, but at the very start of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.start(),
        insertion="<image>" * image_feature_size,
    )
    ```

    The same, but after a leading `Images:` prefix:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.prefix("Images:"),
        insertion="<image>" * image_feature_size,
    )
    ```

    The same, but at the very end of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.end(),
        insertion="<image>" * image_feature_size,
    )
    ```
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to place right after
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    A constant token sequence (or text) may be given directly when it does
    not depend on the item.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """Alias for `insertion`."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

    For each image, replace one `<image>` input placeholder in the prompt
    with a number of `<image>` feature placeholders
    equal to the feature size of the vision encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement="<image>" * image_feature_size,
    )
    ```

    As above, but further pad the feature placeholders with `<image_bos>`
    and `<image_eos>`, which are not supposed to be passed to the vision
    encoder. Only the tokens that also appear in the tokenization of
    `<image>` are assigned embeddings:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement=PromptUpdateDetails.select_text(
            "".join(
                [
                    "<image_bos>",
                    "<image>" * image_feature_size,
                    "<image_eos>",
                ]
            ),
            "<image>",
        ),
    )
    ```

    To avoid unnecessary tokenization during prompt replacement,
    we recommend passing token sequences instead of text:

    ```python
    PromptReplacement(
        modality="image",
        target=[image_token_id],
        replacement=PromptUpdateDetails.select_token_id(
            [image_bos_id] + [image_token_id] * image_feature_size + [image_eos_id],
            embed_token_id=image_token_id,
        ),
    )
    ```
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to replace
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """Alias for `replacement`."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
490
491


492
493
494
class _HasModalityAttr(Protocol):
    modality: str

495

496
497
class _HasModalityProp(Protocol):
    @property
498
    def modality(self) -> str: ...
499
500


501
# Anything exposing `modality: str`, as a plain attribute or as a property.
_M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp)


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group `values` by their `modality`, delegating to
    [`full_groupby`][vllm.utils.collection_utils.full_groupby].
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


513
514
515
516
517
518
519
class PromptTargetMatch(NamedTuple):
    """Half-open span `[start_idx, end_idx)` where an update target was found."""

    start_idx: int
    end_idx: int


@dataclass(frozen=True)
class ResolvedPromptUpdate:
    """
    A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] specialized
    to a single item, with every lazy attribute evaluated except for those
    that require tokenization.
    """

    modality: str
    """The modality for which the update is made."""

    item_idx: int
    """The index within `modality` of the item this update pertains to."""

    mode: UpdateMode
    """Defines how to update the prompt."""

    target: UpdateTarget
    """The token sequence (or text) to update."""

    content: PromptUpdateDetails = field(repr=False)
    """The placeholder tokens that are part of the update."""

    def iter_token_matches(
        self,
        prompt: list[int],
        tokenizer: TokenizerLike | None,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield every occurrence of `self.target` in the token-ID `prompt`."""
        target = self.target

        if not isinstance(target, PromptIndex):
            needle = _seq2tokens(tokenizer, target)
            for begin, finish in iter_token_matches(
                prompt, needle, start_idx=start_idx
            ):
                yield PromptTargetMatch(begin, finish)
            return

        # A prompt index resolves to a zero-width span at a single position.
        position = target.get_match_index(tokenizer, prompt, start_idx)
        if position is not None:
            yield PromptTargetMatch(position, position)

    def iter_text_matches(
        self,
        prompt: str,
        tokenizer: TokenizerLike | None,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield every occurrence of `self.target` in the text `prompt`."""
        target = self.target

        if not isinstance(target, PromptIndex):
            needle = _seq2text(tokenizer, target)
            for hit in re.finditer(re.escape(needle), prompt, pos=start_idx):
                yield PromptTargetMatch(hit.start(), hit.end())
            return

        # A prompt index resolves to a zero-width span at a single position.
        position = target.get_match_index(tokenizer, prompt, start_idx)
        if position is not None:
            yield PromptTargetMatch(position, position)

    def iter_matches(
        self,
        prompt: list[int] | str,
        tokenizer: TokenizerLike | None,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield every occurrence of `self.target` in `prompt` (text or token IDs)."""
        matcher = (
            self.iter_text_matches
            if isinstance(prompt, str)
            else self.iter_token_matches
        )
        return matcher(prompt, tokenizer, start_idx=start_idx)

    def with_target(self, target: UpdateTarget):
        """Return a copy of this update with `target` replaced."""
        return replace(self, target=target)

    def with_content(self, content: PromptUpdateInfo):
        """Return a copy with `content` replaced, wrapping raw sequences."""
        details = (
            content
            if isinstance(content, PromptUpdateDetails)
            else PromptUpdateDetails.from_seq(content)
        )
        return replace(self, content=details)

606

607
608
609
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
610
611


612
613
614
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    *,
    start_idx: int = 0,
) -> Generator[_TokenMatch]:
    """
    Yield each non-overlapping occurrence of `match_ids` in `token_ids`,
    scanning left to right from `start_idx`.

    An empty `match_ids` yields nothing.
    """
    width = len(match_ids)
    if not width:
        return

    last_start = len(token_ids) - width
    pos = start_idx

    while pos <= last_start:
        stop = pos + width

        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)

        # Resume after the match so that reported matches never overlap
        pos = stop
639
640


641
642
643
644
645
646
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Return a copy of `token_ids` in which every non-overlapping occurrence
    of `match_ids` is substituted by `new_ids`.

    An empty `match_ids` matches nothing, so the input is returned unchanged
    (as a new list).
    """
    pieces: list[list[int]] = []
    cursor = 0

    for hit in iter_token_matches(token_ids, match_ids):
        pieces += (token_ids[cursor : hit.start_idx], new_ids)
        cursor = hit.end_idx

    pieces.append(token_ids[cursor:])

    return flatten_2d_lists(pieces)


668
@dataclass
class PlaceholderFeaturesInfo:
    """A run of placeholder tokens found in a prompt for one multi-modal item."""

    # Modality of the item the placeholders belong to
    modality: str
    # Index of the item within its modality
    item_idx: int
    # Position of the first placeholder token in the prompt
    start_idx: int
    # The placeholder token IDs themselves
    tokens: list[int]
    # Optional mask selecting which of `tokens` receive embeddings
    is_embed: torch.Tensor | None

    @property
    def length(self) -> int:
        """Number of placeholder tokens in this run."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a `PlaceholderRange` spanning all of `tokens`."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        span_offset = self.start_idx
        span_length = self.length
        return PlaceholderRange(
            offset=span_offset,
            length=span_length,
            is_embed=self.is_embed,
        )
688
689


690
# `((modality, item_idx), (match, update_idx))` for a match selected to apply.
_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]


def _find_matches(
    prompt: _S,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
    *,
    prev_end_idx: int = 0,
    current_result: "MultiModalPromptUpdatesApplyResult",
) -> tuple[UpdateMode | None, list[_MatchToApply]]:
    """
    Find the next batch of matches to apply to `prompt`, searching from
    `prev_end_idx`.

    For every item that has no applied update yet in `current_result`, the
    first of its candidate updates that has a match is selected, together
    with its first match. All selected matches share one mode, taken from
    the first match found; updates of another mode are skipped.

    Returns:
        The shared mode (`None` if nothing matched) and the selected matches
        sorted by position. In `REPLACE` mode at most one non-empty match is
        returned, to avoid conflicting replacements.
    """
    mode: UpdateMode | None = None
    mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]()

    for modality, modality_updates in mm_prompt_updates.items():
        for item_idx, item_updates in enumerate(modality_updates):
            if current_result[modality][item_idx] is not None:
                continue  # Updates have already been applied for this item

            for update_idx, update in enumerate(item_updates):
                # All matches should share the same mode. Every match of
                # `update` has mode `update.mode`, so check it before
                # scanning the prompt rather than scanning in vain.
                if mode is not None and update.mode != mode:
                    continue

                match = next(
                    update.iter_matches(prompt, tokenizer, start_idx=prev_end_idx),
                    None,
                )
                if match is None:
                    continue

                if mode is None:
                    mode = update.mode

                mm_matches[(modality, item_idx)] = match, update_idx
                break  # Get only the first valid match per item

    # Prioritize earlier matches
    matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0])

    # To avoid conflicts, only replace one non-empty item at a time
    if mode == UpdateMode.REPLACE:
        matches_to_apply_ = list[_MatchToApply]()
        has_non_empty_matches = False

        for item in matches_to_apply:
            _, (match, _) = item
            if match.start_idx == match.end_idx:
                matches_to_apply_.append(item)
            elif not has_non_empty_matches:
                has_non_empty_matches = True
                matches_to_apply_.append(item)

        matches_to_apply = matches_to_apply_

    return mode, matches_to_apply
746
747


748
749
750
751
752
753
754
755
756
757
def _all_items_found(
    mm_item_counts: dict[str, int],
    mm_found_counts: dict[str, int],
) -> bool:
    return all(
        item_idx >= mm_item_counts[modality]
        for modality, item_idx in mm_found_counts.items()
    )


758
def _apply_matches(
    prompt: _S,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to `prompt`.

    Matches are located and applied in rounds via `_find_matches`; each
    round resumes searching after the end of the last applied match.

    Returns:
        The pieces of the updated prompt in order (the caller joins or
        flattens them), and for each modality a list holding, per item, the
        index of the update that was applied, or `None` if none was.
    """
    mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}

    out_seqs = list[str | list[int]]()
    out_result: MultiModalPromptUpdatesApplyResult = {
        m: [None] * len(items) for m, items in mm_prompt_updates.items()
    }

    def _all_applied() -> bool:
        mm_found_counts = {
            m: sum(r is not None for r in res) for m, res in out_result.items()
        }
        return _all_items_found(mm_item_counts, mm_found_counts)

    # Early exit if no items to find
    if _all_applied():
        return [prompt], out_result

    prev_end_idx = 0
    while True:
        mode, matches_to_apply = _find_matches(
            prompt,
            mm_prompt_updates,
            tokenizer,
            prev_end_idx=prev_end_idx,
            current_result=out_result,
        )

        if mode is None:
            break  # No more matches to find

        num_applied = 0
        for (modality, item_idx), (match, update_idx) in matches_to_apply:
            if mode == UpdateMode.INSERT:
                end_idx_to_insert = match.end_idx
            elif mode == UpdateMode.REPLACE:
                end_idx_to_insert = match.start_idx
            else:
                assert_never(mode)

            # A match whose insertion point lies before the end of the
            # previously applied match overlaps it. Applying it would move
            # `prev_end_idx` backwards and duplicate part of the prompt in
            # the output, so leave the item for a later round instead.
            if end_idx_to_insert < prev_end_idx:
                continue

            matched_update = mm_prompt_updates[modality][item_idx][update_idx]
            matched_content = matched_update.content.full

            out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
            out_seqs.append(
                _seq2text(tokenizer, matched_content)
                if isinstance(prompt, str)
                else _seq2tokens(tokenizer, matched_content)
            )
            out_result[modality][item_idx] = update_idx

            # Exclude overlapping matches
            prev_end_idx = match.end_idx
            num_applied += 1

        # Without progress, the next round would find the same matches again
        if num_applied == 0:
            break

        # Early exit if all items found
        if _all_applied():
            break

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs), out_result
822
823


824
def apply_token_matches(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to the token-ID `prompt`.

    Matches are exclusive even when several modalities share the same
    placeholder tokens; the modality listed earlier in `mm_prompt_updates`
    wins.

    Returns:
        The updated token IDs and, per modality and item, the index of the
        applied update (or `None`).
    """
    pieces, applied = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    merged = flatten_2d_lists(pieces)
    return merged, applied
839
840


841
def apply_text_matches(
    prompt: str,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to the text `prompt`.

    Matches are exclusive even when several modalities share the same
    placeholder tokens; the modality listed earlier in `mm_prompt_updates`
    wins.

    Returns:
        The updated text and, per modality and item, the index of the
        applied update (or `None`).
    """
    pieces, applied = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    merged = "".join(pieces)
    return merged, applied
856
857


858
def _iter_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each run of placeholder tokens found in `prompt`.

    The prompt is scanned left to right. At each position, the next
    unmatched item of every modality (in the order of `mm_prompt_updates`)
    is tried against the full content of each of its updates; the first
    exact match is yielded and scanning resumes right after it.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    num_items = {m: len(items) for m, items in mm_prompt_updates.items()}
    next_item = {modality: 0 for modality in mm_prompt_updates}

    if _all_items_found(num_items, next_item):
        return

    num_tokens = len(prompt)
    pos = 0

    while pos < num_tokens:
        matched = False

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = next_item[modality]
            if item_idx >= num_items.get(modality, 0):
                continue  # Every item of this modality is already located

            for update in modality_updates[item_idx]:
                details = update.content
                full_tokens = _seq2tokens(tokenizer, details.full)
                full_len = len(full_tokens)
                stop = pos + full_len

                # Skip empty content and content that would run past the end
                if full_len == 0 or stop > num_tokens:
                    continue

                if prompt[pos:stop] != full_tokens:
                    continue

                embed_mask = details.is_embed
                if embed_mask is not None:
                    embed_mask = embed_mask(tokenizer, details.full)

                yield PlaceholderFeaturesInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=pos,
                    tokens=full_tokens,
                    is_embed=embed_mask,
                )

                # Exclude overlapping matches
                pos = stop
                next_item[modality] += 1
                matched = True
                break

            if matched:
                break  # Restart the modality scan at the new position

        if not matched:
            pos += 1
        elif _all_items_found(num_items, next_item):
            return
925
926


927
928
def find_mm_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: TokenizerLike | None,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the placeholder runs in `prompt` and group them by modality.
    """
    found = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
    grouped = full_groupby_modality(found)
    return {modality: infos for modality, infos in grouped}


936
# Generic, unconstrained type variable.
_T = TypeVar("_T")

# HuggingFace config type a caller narrows to via `get_hf_config(typ)`;
# defaults to the base `PretrainedConfig`.
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
# HuggingFace processor type (a `ProcessorMixin` subclass); defaults to the
# base class.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
939
940
941
942
943
944
945
946
947


@dataclass(frozen=True)
class InputProcessingContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    tokenizer: TokenizerLike | None
    """The tokenizer used to tokenize the inputs."""

    def get_tokenizer(self) -> TokenizerLike:
        """
        Return the tokenizer of the model.

        Raises:
            ValueError: If no tokenizer is available
                (e.g. `skip_tokenizer_init=True`).
        """
        if self.tokenizer is None:
            raise ValueError(
                "You cannot pass text prompts when `skip_tokenizer_init=True`"
            )

        return self.tokenizer

    @overload
    def get_hf_config(self, /) -> PretrainedConfig: ...

    @overload
    def get_hf_config(
        self,
        typ: type[_C] | tuple[type[_C], ...],
        /,
    ) -> _C: ...

    def get_hf_config(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
    ) -> Any:
        """
        Get the HuggingFace configuration
        (`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Args:
            typ: The expected config class (or tuple of classes).
                Defaults to `transformers.PretrainedConfig`.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        if typ is None:
            from transformers.configuration_utils import PretrainedConfig

            typ = PretrainedConfig

        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError(
                "Invalid type of HuggingFace config. "
                f"Expected type: {typ}, but "
                f"found type: {type(hf_config)}"
            )

        return hf_config

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    @overload
    def get_hf_processor(self, /, **kwargs: object) -> ProcessorMixin: ...

    @overload
    def get_hf_processor(
        self,
        typ: type[_P] | tuple[type[_P], ...],
        /,
        **kwargs: object,
    ) -> _P: ...

    def get_hf_processor(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
        **kwargs: object,
    ) -> Any:
        """
        Get the HuggingFace processor
        (`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        Args:
            typ: The expected processor class (or tuple of classes).
                Defaults to `transformers.ProcessorMixin`.
            **kwargs: Forwarded to `cached_processor_from_config`.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        if typ is None:
            from transformers.processing_utils import ProcessorMixin

            typ = ProcessorMixin

        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Initialize a HuggingFace-like processor class, merging the
        keyword arguments with those in the model's configuration.

        Explicitly passed `kwargs` take precedence over the
        `mm_processor_kwargs` of the multimodal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = {**base_kwargs, **kwargs}

        return typ(**merged_kwargs)

    def _postprocess_output(
        self,
        output: JSONTree,
    ) -> JSONTree:
        """
        Cast every floating-point tensor leaf of `output` to the model dtype.

        All other leaves are returned unchanged.
        """

        def _postprocess_one(x: object):
            if isinstance(x, torch.Tensor):  # noqa: SIM102
                # This mimics the behavior of transformers.BatchFeature
                if x.is_floating_point():
                    x = x.to(dtype=self.model_config.dtype)

            return x

        return json_map_leaves(_postprocess_one, output)

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},  # not mutated here
        *,
        num_tries: int = 1,
        max_tries: int = 5,
    ) -> BatchFeature | JSONTree:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        `kwargs` is merged with the multi-modal processor kwargs of the model
        config and filtered through `get_allowed_kwarg_only_overrides` before
        being passed to `hf_processor`. If the call fails with the tokenizer's
        `RuntimeError("Already borrowed")`, it is retried after a short sleep,
        for at most `max_tries` attempts in total.

        Args:
            hf_processor: The processor to call.
            data: Inputs passed to the processor as keyword arguments.
            kwargs: Additional processor options.
            num_tries: The number of the current attempt (used for retries).
            max_tries: The maximum number of attempts.

        Returns:
            The processor output with floating-point tensors cast to the
            model dtype; a `BatchFeature` output is returned as a new
            `BatchFeature`.

        Raises:
            ValueError: If the processor call fails; the original exception
                is chained as the cause.
        """
        assert callable(hf_processor)

        mm_config = self.model_config.get_multimodal_config()
        merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)

        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            merged_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            output = hf_processor(**data, **allowed_kwargs, return_tensors="pt")
        except Exception as exc:
            # See https://github.com/huggingface/tokenizers/issues/537
            # `exc.args` may be empty; check it before indexing so that an
            # unrelated error is not masked by an IndexError raised here.
            if (
                isinstance(exc, RuntimeError)
                and exc.args
                and exc.args[0] == "Already borrowed"
                and num_tries < max_tries
            ):
                logger.warning(
                    "Failed to acquire tokenizer in current thread. "
                    "Retrying (%d/%d)...",
                    num_tries,
                    max_tries,
                )
                time.sleep(0.5)
                return self.call_hf_processor(
                    hf_processor,
                    data,
                    kwargs,
                    num_tries=num_tries + 1,
                    max_tries=max_tries,
                )

            msg = (
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={allowed_kwargs}"
            )

            raise ValueError(msg) from exc

        # this emulates output.to(dtype=self.model_config.dtype)
        from transformers.feature_extraction_utils import BatchFeature

        if isinstance(output, BatchFeature):
            output_ = self._postprocess_output(output.data)
            return BatchFeature(output_)

        logger.warning_once(
            "%s did not return `BatchFeature`. "
            "Make sure to match the behaviour of `ProcessorMixin` when "
            "implementing custom processors.",
            type(hf_processor).__name__,
        )

        return self._postprocess_output(output)


1163
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        # Gives access to the model config, the tokenizer and HF helpers.
        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The model identifier from the model configuration."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the processing context."""
        ctx = self.ctx
        return ctx.get_tokenizer()

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HuggingFace config of the model."""
        ctx = self.ctx
        return ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HuggingFace processor of the model.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        ctx = self.ctx
        return ctx.get_hf_processor(**kwargs)

    # NOTE(review): this class does not derive from `ABC`, so
    # `@abstractmethod` is not enforced at instantiation; calling the base
    # implementation raises `NotImplementedError` instead.
    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means an unlimited number of items.

        A modality that is missing from the returned dictionary is not
        supported at all.
        """
        raise NotImplementedError

    def get_allowed_mm_limits(self) -> Mapping[str, int]:
        """
        Return the maximum allowed number of items for each modality.

        For each supported modality this is the user-configured limit,
        capped by the supported limit unless the latter is `None`.
        """
        supported_limits = self.get_supported_mm_limits()
        mm_config = self.ctx.get_mm_config()

        def _clamp(modality: str, supported_limit: int | None) -> int:
            user_limit = mm_config.get_limit_per_prompt(modality)
            if supported_limit is None:
                return user_limit

            return min(user_limit, supported_limit)

        return {
            modality: _clamp(modality, supported_limit)
            for modality, supported_limit in supported_limits.items()
        }

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int] | None:
        """
        Return the maximum number of tokens per item for each modality.

        When `None` (the default) is returned, vLLM will generate dummy inputs
        (images/videos) at maximum possible sizes and process them to determine
        the maximum token count per modality.

        This approach works but can be very slow for certain models (e.g.,
        Qwen2.5-VL), leading to very long startup time. For better performance,
        each model can override this method to return pre-computed maximum token
        counts, avoiding the need for dummy input generation and processing.

        Note:
            The maximum number of tokens per item of each modality returned
            from this function should respect the model's maximum sequence
            length and the maximum number of items of each modality allowed,
            and agree with dummy inputs (images/videos) at maximum possible
            sizes.
        """
        return None

1243
1244

_I = TypeVar("_I", bound=BaseProcessingInfo)

MultiModalHashes: TypeAlias = dict[str, list[str]]
"""
The multi-modal hash of each item, keyed by modality, with a similar
structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""

MultiModalIsCached: TypeAlias = dict[str, list[bool]]
"""
Whether each item is cached, keyed by modality, with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""

MultiModalPromptUpdates: TypeAlias = Mapping[
    str, list[Sequence[ResolvedPromptUpdate]]
]
"""
The prompt updates of each item, keyed by modality, with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""

MultiModalPromptUpdatesApplyResult: TypeAlias = Mapping[str, list[int | None]]
"""
For the item `MultiModalPromptUpdates[k][i]`, the entry
`MultiModalPromptUpdatesApplyResult[k][i]` is the index of the
`ResolvedPromptUpdate` that was applied, or `None` if none of the
`ResolvedPromptUpdate` instances was applied.
"""

1272
1273

class MultiModalProcessingInfo(NamedTuple):
    """
    Bundle of per-item results produced while processing multi-modal data.

    `hashes` and `prompt_updates` are keyed by modality with one entry per
    item (see their type aliases).
    """

    # Processed keyword arguments of the items; presumably entries may be
    # missing (`None`) for some items — confirm against `MultiModalKwargsOptionalItems`.
    kwargs: MultiModalKwargsOptionalItems
    # Hash of each item.
    hashes: MultiModalHashes
    # Resolved prompt updates of each item.
    prompt_updates: MultiModalPromptUpdates

1278
1279

class BaseMultiModalProcessor(ABC, Generic[_I]):
1280
    """
1281
    Abstract base class to process multi-modal inputs to be used in vLLM.
1282

1283
    Not to be confused with `transformers.ProcessorMixin`.
1284
1285
    """

1286
1287
1288
1289
1290
    def __init__(
        self,
        info: _I,
        dummy_inputs: "BaseDummyInputsBuilder[_I]",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> None:
        """
        Args:
            info: Model-specific processing information.
            dummy_inputs: Builder of dummy inputs; among other things it
                supplies the dummy text used when only multi-modal data is
                passed to the HF processor.
            cache: Optional multi-modal processor cache, stored as
                `self.cache`.
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache

        # The parser reads `self.info`, so it must be built after `info`
        # has been assigned.
        self.data_parser = self._get_data_parser()

        # Query the limits once here instead of on every validation.
        self._supported_mm_limits = self.info.get_supported_mm_limits()
        self._allowed_mm_limits = self.info.get_allowed_mm_limits()

    @property
    def supported_mm_limits(self):
        """Per-modality item limits supported by the model (cached at init)."""
        limits = self._supported_mm_limits
        return limits

    @property
    def allowed_mm_limits(self):
        """Per-modality item limits allowed for a prompt (cached at init)."""
        limits = self._allowed_mm_limits
        return limits

1313
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Shorthand for `apply`, so that the processor can be called directly."""
        return self.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            mm_uuids=mm_uuids,
        )
1322

1323
1324
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        You can support additional modalities by creating a subclass
        of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
        that has additional subparsers.
        """
        model_config = self.info.ctx.model_config

        # When user-provided embeddings are enabled, their hidden size is
        # validated: embeddings with the right ndim but a wrong shape could
        # otherwise crash the model at inference time.
        if model_config.get_multimodal_config().enable_mm_embeds:
            expected_hidden_size = model_config.get_inputs_embeds_size()
        else:
            expected_hidden_size = None

        return MultiModalDataParser(expected_hidden_size=expected_hidden_size)
1342

1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
    def validate_num_items(
        self,
        modality: str,
        num_items: int,
    ) -> None:
        """
        Check that at most the permitted number of `modality` items is given.

        The effective limit is the smaller of the supported limit and the
        allowed limit. A supported limit of `None` means the model imposes
        no limit of its own, so only the allowed limit applies. A modality
        missing from either mapping has a limit of 0.

        Raises:
            ValueError: If `num_items` exceeds the effective limit. When the
                model itself could accept `num_items` items, the message
                also suggests raising `--limit-mm-per-prompt`.
        """
        supported_limit = self.supported_mm_limits.get(modality, 0)
        allowed_limit = self.allowed_mm_limits.get(modality, 0)

        if supported_limit is None:
            limit = allowed_limit
        else:
            limit = min(supported_limit, allowed_limit)

        if num_items > limit:
            msg = f"At most {limit} {modality}(s) may be provided in one prompt."

            # The user-configurable limit is the binding one whenever the
            # model supports unlimited items or at least `num_items` items,
            # so raising `--limit-mm-per-prompt` would help in both cases.
            if supported_limit is None or num_items <= supported_limit:
                msg += " Set `--limit-mm-per-prompt` to increase this limit."

            raise ValueError(msg)

1364
    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize
        [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
        to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        Raises:
            ValueError: If embeddings are passed while `--enable-mm-embeds`
                is not set, or if a modality has more items than permitted
                (see `validate_num_items`).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        mm_config = self.info.ctx.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            # Report the first modality (in iteration order) given as embeddings.
            embeds_modality = next(
                (
                    modality
                    for modality, items in mm_items.items()
                    if isinstance(items, (EmbeddingItems, DictEmbeddingItems))
                ),
                None,
            )
            if embeds_modality is not None:
                raise ValueError(
                    f"You must set `--enable-mm-embeds` to input "
                    f"`{embeds_modality}_embeds`"
                )

        # Item counts are validated only once every modality has passed the
        # embeddings check above.
        for modality, items in mm_items.items():
            self.validate_num_items(modality, len(items))

        return mm_items
1390

1391
1392
1393
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe the fields of the HF-processed data.

        Given the outputs of the HF processor, return the
        `MultiModalFieldConfig` (metadata) of each field, keyed by field name.
        """
        raise NotImplementedError

1400
    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Given the original multi-modal items for this modality
        and the HF-processed data, output the updates to perform.

        The information returned by this method is used to update token
        inputs, which bypass the HF processor. It is also used to update the
        output of the HF processor if the HF processor does not apply prompt
        updates to text inputs.

        Moreover, this information is critical to determine the token
        positions in order to construct
        [`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
        for each multi-modal item.
        """
        raise NotImplementedError
1422

1423
1424
1425
1426
1427
1428
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
        mm_item_counts: Mapping[str, int],
    ) -> MultiModalPromptUpdates:
        """
        Resolve each prompt update for every item of its modality.

        Returns a mapping from modality to a list with one entry per item
        (per `mm_item_counts`, 0 if absent); each entry holds the updates of
        that modality resolved for the item's index.
        """
        grouped: dict[str, list[Sequence[ResolvedPromptUpdate]]] = {}

        for modality, updates in full_groupby_modality(prompt_updates):
            num_items = mm_item_counts.get(modality, 0)
            grouped[modality] = [
                [update.resolve(item_idx) for update in updates]
                for item_idx in range(num_items)
            ]

        return grouped

    def _get_mm_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> MultiModalPromptUpdates:
        """
        Compute the resolved prompt updates of every multi-modal item.

        The updates returned by `_get_prompt_updates` are resolved for each
        item index and grouped by modality. A deprecation warning is logged
        (once) when an item has more than one prompt update.
        """
        unbound_prompt_updates = self._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates,
            mm_items.get_all_counts(),
        )

        for modality, prompt_updates in mm_prompt_updates.items():
            for item_idx, item_prompt_updates in enumerate(prompt_updates):
                if len(item_prompt_updates) > 1:
                    logger.warning_once(
                        "Detected %d prompt updates for `mm_items[%r][%s]`. "
                        "Multiple prompt updates per item is now "
                        "deprecated and may be removed in v0.13. "
                        "Instead, please specify dynamic update targets "
                        "in the same prompt update definition by passing "
                        "a function to `PromptUpdate.target`.",
                        # Report the number of updates of *this item*, not
                        # the number of items of the modality.
                        len(item_prompt_updates),
                        modality,
                        item_idx,
                    )

        return mm_prompt_updates

1470
    def _find_mm_placeholders(
        self,
        new_token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholder tokens of each multi-modal item in
        `new_token_ids`, using the model's tokenizer.
        """
        return find_mm_placeholders(
            new_token_ids,
            mm_prompt_updates,
            self.info.get_tokenizer(),
        )
1478

1479
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the multi-modal items into the data to be passed to the HF
        processor and the data to be passed through unchanged.

        Returns:
            A `(processor_data, passthrough_data)` tuple; later modalities
            overwrite earlier ones on key collisions.
        """
        processor_data: dict[str, object] = {}
        passthrough_data: dict[str, object] = {}

        for modality_items in mm_items.values():
            processor_data |= modality_items.get_processor_data()
            passthrough_data |= modality_items.get_passthrough_data()

        return processor_data, passthrough_data

1492
1493
1494
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and the
        associated multi-modal data.
        """
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        # `dict(**a, **b)` raises TypeError on duplicate keys rather than
        # silently letting one mapping override the other.
        processor_inputs = dict(text=prompt, **mm_data)
        processor_kwargs = dict(**mm_kwargs, **tok_kwargs)

        return self.info.ctx.call_hf_processor(
            hf_processor,
            processor_inputs,
            processor_kwargs,
        )

1511
    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Return whether the HF processor applies prompt updates.

        The default assumes that it does unless any modality is given as
        embeddings (`EmbeddingItems` / `DictEmbeddingItems`), which the HF
        processor does not see as raw multi-modal data.
        """
        return all(
            not isinstance(items, (EmbeddingItems, DictEmbeddingItems))
            for items in mm_items.values()
        )
1529

1530
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            A tuple of the prompt token IDs, the processed data (without
            `input_ids`, with the passthrough data merged in), and whether
            prompt updates have been applied.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        # A single prompt is processed, so the unpacking below expects
        # exactly one row of token IDs.
        input_ids = processed_data.pop("input_ids")
        (prompt_ids,) = input_ids.tolist()

        updates_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, processed_data, updates_applied
1563

1564
    def _apply_hf_processor_text_only(
        self,
        prompt_text: str,
        tokenization_kwargs: Mapping[str, object],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The text is processed together with an empty set of multi-modal
        items and no multi-modal processor kwargs; only the resulting token
        IDs are returned.
        """
        no_mm_items = MultiModalDataItems({})

        prompt_ids, _processed, _applied = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=no_mm_items,
            hf_processor_mm_kwargs={},
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        The default returns `prompt_tokens` unchanged, because most HF
        processors accept prompt text but not prompt tokens. Override this
        method if the HF processor adds or removes tokens unrelated to the
        multi-modal data, so that the result is consistent with
        [`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
        on the corresponding text.
        """
        # Identity by default; see the docstring for when to override.
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Apply the HF processor on the multi-modal data only.

        The HF processor requires the text to correspond to the multi-modal
        items, so dummy text is generated by
        [`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
        for the item counts; the resulting token IDs are discarded.
        """
        dummy_text = self.dummy_inputs.get_dummy_text(mm_items.get_all_counts())

        _prompt_ids, mm_processed_data, _applied = self._apply_hf_processor_text_mm(
            prompt_text=dummy_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return mm_processed_data
1626
1627
1628

    def _apply_hf_processor_main(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Apply the HF processor on the prompt and multi-modal data.

        In addition, return whether prompt updates have been applied.

        Note:
            Only when `prompt` is text and `enable_hf_prompt_update=True` are
            the text and the multi-modal items processed together, letting
            the HF processor perform prompt updates where it supports them
            (this requires the prompt to correspond to the multi-modal items).
            Otherwise the prompt (text or token IDs) and the multi-modal data
            are processed separately and no prompt updates are applied.
        """
        if not isinstance(prompt, str):
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)
        elif enable_hf_prompt_update:
            return self._apply_hf_processor_text_mm(
                prompt_text=prompt,
                mm_items=mm_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
            )
        else:
            prompt_ids = self._apply_hf_processor_text_only(prompt, tokenization_kwargs)

        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        # Processed separately, so the HF processor cannot have updated the prompt.
        return prompt_ids, mm_processed_data, False
1667

1668
    def _hash_mm_items(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalHashes:
        """Create MM hashes to be returned.

        Each hash covers the model ID, the item and the processor /
        tokenization kwargs. For modalities listed in `mm_uuids`, a provided
        UUID is used as-is when both kwargs mappings are empty; otherwise the
        UUID (instead of the raw item) is hashed together with the kwargs.
        `None` UUID entries fall back to hashing the item.

        Note: When overrides are provided via callers of `apply`,
        `_hash_mm_items` will be bypassed and the overrides will be used.
        """
        model_id = self.info.model_id
        uuids_by_modality = mm_uuids or {}

        def _hash_one(modality: str, value: object) -> str:
            # The processed inputs may depend on the kwargs, so they are
            # always part of the hash.
            return MultiModalHasher.hash_kwargs(
                model_id=model_id,
                **{modality: value},
                **hf_processor_mm_kwargs,
                **tokenization_kwargs,
            )

        hashes: MultiModalHashes = {}
        for modality, items in mm_items.items():
            if modality not in uuids_by_modality:
                # NOTE(review): this path hashes the elements obtained by
                # iterating `items`, whereas the UUID path below uses
                # `items.get_all_items_for_hash()`; confirm both agree.
                hashes[modality] = [_hash_one(modality, item) for item in items]
                continue

            modality_uuids = uuids_by_modality[modality]
            if isinstance(modality_uuids, str):
                modality_uuids = [modality_uuids]

            computed: list[str] = []
            for idx, item in enumerate(items.get_all_items_for_hash()):
                item_uuid = modality_uuids[idx]

                if item_uuid is not None and not (
                    hf_processor_mm_kwargs or tokenization_kwargs
                ):
                    # A provided ID without any kwargs is used verbatim.
                    computed.append(item_uuid)
                else:
                    # Hash the provided ID (cheaper than the raw item) when
                    # available, otherwise the item itself.
                    computed.append(
                        _hash_one(modality, item if item_uuid is None else item_uuid)
                    )

            hashes[modality] = computed

        return hashes
1733

1734
1735
    def _get_cache_missing_items(
        self,
        cache: BaseMultiModalProcessorCache,
        mm_data_items: MultiModalDataItems,
        mm_hashes: MultiModalHashes,
    ) -> tuple[MultiModalIsCached, MultiModalDataItems]:
        """
        Determine which multi-modal items are not in `cache`.

        Returns:
            A tuple of the per-item `is_cached` flags reported by
            `cache.is_cached`, and the uncached items normalized via
            `_to_mm_items`.

        Raises:
            ValueError: If an uncached item has no data (`None`).
        """
        # Query the cache for every modality before touching any item data.
        mm_is_cached: MultiModalIsCached = {
            modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
        }

        mm_missing_data = {}
        for modality, items_is_cached in mm_is_cached.items():
            missing_modality_data = []

            for idx, item_is_cached in enumerate(items_is_cached):
                if item_is_cached:
                    continue

                data = mm_data_items[modality][idx]
                if data is None:
                    raise ValueError(
                        f"Cache miss for {modality} at index {idx} "
                        f"but data is not provided."
                    )

                missing_modality_data.append(data)

            mm_missing_data[modality] = missing_modality_data

        return mm_is_cached, self._to_mm_items(mm_missing_data)
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """
        Re-target a prompt update retrieved from the cache at a new item index.

        The default implementation only rewrites `item_idx`. It uses
        `dataclasses.replace`, so a new object is returned and
        `cached_update` itself is not modified.

        Override this if other attributes of `ResolvedPromptUpdate`
        also need to be recomputed after retrieving from the cache.
        """
        retargeted = replace(cached_update, item_idx=new_item_idx)
        return retargeted

    def _merge_mm_kwargs(
        self,
        cache: BaseMultiModalProcessorCache,
        mm_hashes: MultiModalHashes,
        mm_is_cached: MultiModalIsCached,
        mm_missing_kwargs: MultiModalKwargsItems,
        mm_missing_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
        """
        Combine freshly processed items with cached ones, in input order.

        `mm_missing_kwargs` and `mm_missing_prompt_updates` hold only the
        items flagged as not cached in `mm_is_cached`, in their original
        relative order. Every item, cached or not, goes through
        `cache.get_and_update_item`. Its prompt updates are then re-targeted
        to the item's position via `_recompute_cached_prompt_update`.

        Returns:
            The merged per-modality kwargs and prompt updates.
        """
        # Touch every hash before any update, so that an item referenced by
        # this request cannot be evicted while the cache is being updated.
        for item_hashes in mm_hashes.values():
            for item_hash in item_hashes:
                cache.touch_sender_cache_item(item_hash)

        merged_kwargs = defaultdict[str, list[MultiModalKwargsItem | None]](list)
        merged_prompt_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](
            list
        )

        for modality, item_hashes in mm_hashes.items():
            missing_kwargs = mm_missing_kwargs.get(modality, [])
            missing_prompt_updates = mm_missing_prompt_updates.get(modality, [])
            # Cursor into the "missing" lists; it only advances on a miss.
            next_missing = 0

            for item_idx, item_hash in enumerate(item_hashes):
                if mm_is_cached[modality][item_idx]:
                    new_item = None
                else:
                    new_item = (
                        missing_kwargs[next_missing],
                        missing_prompt_updates[next_missing],
                    )
                    next_missing += 1

                kwargs, updates = cache.get_and_update_item(new_item, item_hash)

                merged_kwargs[modality].append(kwargs)
                merged_prompt_updates[modality].append(
                    [
                        self._recompute_cached_prompt_update(update, item_idx)
                        for update in updates
                    ]
                )

        return MultiModalKwargsItems(merged_kwargs), dict(merged_prompt_updates)

    def _apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Apply the HF processor on the full prompt and every multi-modal item,
        without consulting the processor cache.

        Returns:
            A tuple of:
            - the prompt token IDs,
            - the multi-modal kwargs, hashes, and prompt updates,
            - whether the prompt updates have already been applied to the
              token IDs.
        """
        prompt_ids, mm_processed_data, is_update_applied = (
            self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                enable_hf_prompt_update=True,
            )
        )

        fields_config = self._get_mm_fields_config(
            mm_processed_data, hf_processor_mm_kwargs
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data, fields_config
        )

        # Hash with the user-provided `mm_uuids` where given; otherwise the
        # hash depends on the data itself.
        mm_hashes = self._hash_mm_items(
            mm_data_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        mm_prompt_updates = self._get_mm_prompt_updates(
            mm_data_items, hf_processor_mm_kwargs, mm_kwargs
        )

        return (
            prompt_ids,
            MultiModalProcessingInfo(
                kwargs=mm_kwargs,
                hashes=mm_hashes,
                prompt_updates=mm_prompt_updates,
            ),
            is_update_applied,
        )

    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Only items whose hashes miss the cache are run through the HF
        processor; the others come from the cache. This falls back to
        `_apply_hf_processor` when there is no cache or when
        `_get_hf_mm_data` reports passthrough data.
        """
        cache = self.cache

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            # No reuse is possible; process every item from scratch.
            return self._apply_hf_processor(
                prompt=prompt,
                mm_data_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        mm_hashes = self._hash_mm_items(
            mm_data_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
            cache=cache,
            mm_data_items=mm_data_items,
            mm_hashes=mm_hashes,
        )

        # NOTE: `prompt` refers to all items, not only `mm_missing_data_items`,
        # so prompt updates cannot be applied here. They can only be applied
        # once the new items have been combined with the cached ones.
        prompt_ids, mm_missing_processed_data, is_update_applied = (
            self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_missing_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                enable_hf_prompt_update=False,
            )
        )

        missing_fields_config = self._get_mm_fields_config(
            mm_missing_processed_data, hf_processor_mm_kwargs
        )
        mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_missing_processed_data, missing_fields_config
        )
        mm_missing_prompt_updates = self._get_mm_prompt_updates(
            mm_missing_data_items, hf_processor_mm_kwargs, mm_missing_kwargs
        )

        mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
            cache,
            mm_hashes=mm_hashes,
            mm_is_cached=mm_is_cached,
            mm_missing_kwargs=mm_missing_kwargs,
            mm_missing_prompt_updates=mm_missing_prompt_updates,
        )

        return (
            prompt_ids,
            MultiModalProcessingInfo(
                kwargs=mm_kwargs,
                hashes=mm_hashes,
                prompt_updates=mm_prompt_updates,
            ),
            is_update_applied,
        )

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
        """Apply `mm_prompt_updates` to the token IDs in `prompt`.

        Delegates to `apply_token_matches` with this processor's tokenizer.
        """
        tokenizer = self.info.get_tokenizer()
        result = apply_token_matches(prompt, mm_prompt_updates, tokenizer)
        return result

    def _apply_text_matches(
        self,
        prompt: str,
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[str, MultiModalPromptUpdatesApplyResult]:
        """Apply `mm_prompt_updates` to the prompt text in `prompt`.

        Delegates to `apply_text_matches` with this processor's tokenizer.
        """
        tokenizer = self.info.get_tokenizer()
        result = apply_text_matches(prompt, mm_prompt_updates, tokenizer)
        return result

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to `token_ids` and locate the resulting
        multi-modal placeholders.

        Token-level matching is tried first. If any item is left unmatched,
        the original `token_ids` are decoded to text, updated at the string
        level, and re-encoded.

        Returns:
            The updated token IDs and the placeholders found in them,
            grouped by modality.
        """
        tokenizer = self.info.get_tokenizer()

        new_token_ids, match_result = self._apply_token_matches(
            token_ids, mm_prompt_updates
        )

        # If the search text does not represent a special token, it may map
        # to different token IDs in the prompt, because tokens can cross the
        # boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all.
        # ----
        # Searching for every possible tokenization of the search text is
        # inefficient. Instead, the updates are applied to the decoded text,
        # and the result is encoded back.
        has_unmatched = any(
            update_idx is None
            for update_idxs in match_result.values()
            for update_idx in update_idxs
        )
        if has_unmatched:
            prompt_text = _seq2text(tokenizer, token_ids, use_cache=False)
            new_text, match_result = self._apply_text_matches(
                prompt_text, mm_prompt_updates
            )
            new_token_ids = _seq2tokens(tokenizer, new_text, use_cache=False)

        # Keep, for each item, only the update that actually matched.
        matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list)
        for modality, update_idxs in match_result.items():
            for item_idx, update_idx in enumerate(update_idxs):
                assert update_idx is not None, (
                    "Failed to apply prompt replacement for "
                    f"mm_items[{modality!r}][{item_idx}]"
                )
                chosen = mm_prompt_updates[modality][item_idx][update_idx]
                matched_updates[modality].append([chosen])

        placeholders = self._find_mm_placeholders(
            new_token_ids, dict(matched_updates)
        )
        return new_token_ids, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargsOptionalItems,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that `mm_kwargs` has exactly one entry per input data item.

        Args:
            mm_kwargs: Processed keyword-argument items, grouped by modality.
            mm_item_counts: Expected number of items per modality.

        Raises:
            RuntimeError: If, for any modality in `mm_item_counts`, the number
                of kwargs items differs from the number of data items (in
                either direction).
        """
        for modality, item_count in mm_item_counts.items():
            num_items = len(mm_kwargs.get(modality, []))

            if num_items != item_count:
                # The mismatch may be in either direction, so the message
                # must not say "only found".
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but found {num_items}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`)."
                )

    def _validate_mm_updates(
        self,
        mm_updates: MultiModalPromptUpdates,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that there is one prompt-update entry per input data item.

        Raises:
            RuntimeError: If, for any modality in `mm_item_counts`, the number
                of prompt updates differs from the number of data items.
        """
        for modality, item_count in mm_item_counts.items():
            num_updates = len(mm_updates.get(modality, []))
            if num_updates == item_count:
                continue

            raise RuntimeError(
                f"Expected there to be {item_count} prompt updates "
                f"corresponding to {item_count} {modality} items, but "
                f"instead found {num_updates} prompt updates! "
                "This is likely because you forgot to include input "
                "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
                "in the prompt. If the model has a chat template, make "
                "sure you have applied it before calling `LLM.generate`."
            )

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that one placeholder was found in the prompt per data item.

        Raises:
            RuntimeError: If, for any modality in `mm_item_counts`, the number
                of placeholders differs from the number of data items.
        """
        for modality, item_count in mm_item_counts.items():
            num_placeholders = len(mm_placeholders.get(modality, []))
            if num_placeholders == item_count:
                continue

            raise RuntimeError(
                f"Expected there to be {item_count} prompt placeholders "
                f"corresponding to {item_count} {modality} items, but "
                f"instead found {num_placeholders} prompt placeholders! "
                "Make sure the implementation of `_call_hf_processor` and "
                "`_get_mm_fields_config` are consistent with each other."
            )

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsOptionalItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Ensure the prompt updates are reflected in `prompt_ids`, and locate
        the multi-modal placeholders.

        If `is_update_applied` is true, the placeholders are only searched
        for. Otherwise the updates are applied first. The kwargs, updates,
        and placeholders are all validated against the item counts.

        Returns:
            The (possibly updated) token IDs and the placeholders per modality.
        """
        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

        if is_update_applied:
            # The token IDs already contain the updates; just locate them.
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids, mm_prompt_updates
            )
        else:
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids, mm_prompt_updates
            )

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
        return prompt_ids, mm_placeholders

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        mm_items = self._to_mm_items(mm_data)
        tokenization_kwargs = (
            {} if tokenization_kwargs is None else tokenization_kwargs
        )

        prompt_ids, mm_info, is_update_applied = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # NOTE: `tokenization_kwargs` is not forwarded to this step.
        prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
            mm_items=mm_items,
            prompt_ids=prompt_ids,
            mm_kwargs=mm_info.kwargs,
            mm_prompt_updates=mm_info.prompt_updates,
            is_update_applied=is_update_applied,
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_info.kwargs,
            mm_hashes=mm_info.hashes,
            mm_placeholders={
                modality: [item.to_range() for item in placeholders]
                for modality, placeholders in mm_placeholders.items()
            },
        )


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    The HF processor runs on the encoder prompt built by
    `create_encoder_prompt`. The decoder receives the prompt built by
    `create_decoder_prompt`, which by default is the input prompt unchanged.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """Whether to pad the dummy encoder prompt. Defaults to `False`;
        subclasses may override."""
        return False

    def create_decoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        """Create input prompt for the decoder (the input prompt by default)."""
        return prompt

    def _get_enc_dec_inputs(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        encoder_inputs: MultiModalInputs,
    ):
        """
        Build encoder-decoder inputs from already-processed encoder inputs.

        The encoder token IDs are copied into `encoder_prompt_token_ids`.
        `prompt_token_ids` is then overwritten with the decoder prompt, which
        is tokenized without special tokens when given as text.

        Returns:
            A `MultiModalEncDecInputs` instance.
        """
        tokenizer = self.info.get_tokenizer()

        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        decoder_prompt_ids = (
            tokenizer.encode(decoder_prompt, add_special_tokens=False)
            if isinstance(decoder_prompt, str)
            else decoder_prompt
        )

        enc_dec_inputs = MultiModalEncDecInputs(
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs,
        )
        enc_dec_inputs["prompt_token_ids"] = decoder_prompt_ids
        return enc_dec_inputs

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are adapted for encoder-decoder models:

        1. Create the encoder prompt from the input prompt.
        2. Apply the HF processor on the encoder prompt.
        3. Build the decoder prompt with `create_decoder_prompt`
           (by default, the input prompt itself).
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        return self._get_enc_dec_inputs(
            prompt=prompt, mm_data=mm_data, encoder_inputs=encoder_inputs
        )