"vllm/vscode:/vscode.git/clone" did not exist on "cf349c4a97adb36354bdc2b14448ea55279d1575"
processing.py 60.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections import defaultdict
5
6
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
7
from dataclasses import dataclass, field
8
from enum import Enum
9
from functools import lru_cache
10
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
11
                    TypeVar, Union, cast)
12

13
import regex as re
14
import torch
15
from typing_extensions import assert_never
16

17
from vllm.inputs import InputProcessingContext
18
from vllm.logger import init_logger
19
20
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
21
from vllm.utils import flatten_2d_lists, full_groupby
22

23
from .cache import MultiModalCache
24
from .hasher import MultiModalHasher
25
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
26
27
28
                     MultiModalFieldConfig, MultiModalInputs,
                     MultiModalKwargsItem, MultiModalKwargsItems,
                     PlaceholderRange)
29
30
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
31
32

if TYPE_CHECKING:
33
34
35
36
    from transformers.configuration_utils import PretrainedConfig
    from transformers.feature_extraction_utils import BatchFeature
    from transformers.processing_utils import ProcessorMixin

37
    from .profiling import BaseDummyInputsBuilder
38

39
logger = init_logger(__name__)

# Constrained to exactly `str` or `list[int]`, so generic signatures keep the
# text case and the token-ID case apart.
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""


class _GetMatchIndex(Protocol):
    """
    Signature of the callback stored in `PromptIndex.get_match_index`.

    Given a prompt (text or token IDs) and the position from which the
    search starts, return the resolved index in the prompt, or `None`
    if the target cannot be resolved.
    """

    def __call__(
        self,
        tokenizer: AnyTokenizer,
        prompt: PromptSeq,
        start_idx: int = 0,
    ) -> Optional[int]:
        ...


@dataclass
class PromptIndex:
    """Resolves to an index in the prompt."""

    # Callback mapping `(tokenizer, prompt, start_idx)` to an index in
    # `prompt`, or to `None` when there is no match.
    get_match_index: _GetMatchIndex


class PromptIndexTargets:
    """Factory methods for commonly used `PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty, provided that
        the search begins at the start of the prompt (`start_idx == 0`).
        A search that begins later cannot resolve to an earlier position.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
            start_idx: int = 0,
        ) -> Optional[int]:
            # Consistent with `prefix`: resolving to index 0 after the search
            # has moved past it would report a match *behind* `start_idx`,
            # which callers would then splice in at the wrong position.
            return 0 if start_idx == 0 else None

        return PromptIndex(get_match_index)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        Only matches when the search starts at the beginning of the prompt
        (`start_idx == 0`) and the prompt begins with `seq`. If `seq` and the
        prompt use different representations (text vs. token IDs), `seq` is
        converted to the prompt's representation before comparing.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
            start_idx: int = 0,
        ) -> Optional[int]:
            if start_idx != 0:
                return None

            prefix = seq

            if isinstance(prompt, str):
                if not isinstance(prefix, str):
                    # Make both `str`
                    prefix = decode_tokens(tokenizer, prefix)
            else:
                if isinstance(prefix, str):
                    # Make both `list[int]`
                    prefix = encode_tokens(tokenizer,
                                           prefix,
                                           add_special_tokens=False)

            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt))


UpdateTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update, or a
[`PromptIndex`][vllm.multimodal.processing.PromptIndex] that resolves to a
position in the prompt.
"""

PromptUpdateTarget = Union[
    Callable[[int], UpdateTarget],
    UpdateTarget,
]
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
    """
    Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
    return a boolean mask with one entry per token of `full`, indicating
    which positions of `full` to assign embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    [`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Use `seq` as the full content, assigning embeddings to all of it."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Use `seq` as the full content, but only assign embeddings to the
        tokens of `seq` whose IDs occur in the tokenization of `embed_text`.
        """

        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
            # Tokenize without special tokens, the same way
            # `_BoundPromptSequence.token_ids` tokenizes `full`. Otherwise a
            # special token the tokenizer may add by default (e.g. BOS) would
            # also be selected wherever it appears in `full`.
            embed_token_ids = encode_tokens(full.tokenizer,
                                            embed_text,
                                            add_special_tokens=False)

            return torch.isin(
                torch.tensor(full.token_ids),
                torch.tensor(embed_token_ids),
            )

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Use `seq` as the full content, but only assign embeddings to the
        tokens equal to `embed_token_id`.
        """
        return PromptUpdateDetails(
            full=seq,
            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
        )


PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that is part of the update.

If only part of the content corresponds to feature placeholders, you can
use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to
specify which part.
"""

PromptUpdateContent = Union[
    Callable[[int], PromptUpdateInfo],
    PromptUpdateInfo,
]
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """Defines how the content of a prompt update is applied to a prompt."""

    # Place the content right after the matched target.
    INSERT = "insert"
    # Substitute the matched target with the content.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.

    Subclasses provide the placeholder content via `content` and how it is
    applied via `mode`. Call `resolve` to bind the update to a single item.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptUpdateTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def _resolve_target(
        self,
        tokenizer: AnyTokenizer,
        item_idx: int,
    ) -> Union["_BoundPromptSequence", PromptIndex]:
        # Evaluate per-item targets, then bind sequences to the tokenizer.
        # `PromptIndex` targets are passed through unchanged.
        raw_target = self.target
        resolved = raw_target(item_idx) if callable(raw_target) else raw_target

        if isinstance(resolved, PromptIndex):
            return resolved
        return _BoundPromptSequence.from_seq(tokenizer, resolved)

    def _resolve_content(
        self,
        tokenizer: AnyTokenizer,
        item_idx: int,
    ) -> "_BoundPromptContent":
        # Evaluate per-item content and normalize plain sequences into
        # `PromptUpdateDetails` before binding them to the tokenizer.
        raw_content = self.content
        info = raw_content(item_idx) if callable(raw_content) else raw_content

        details = (info if isinstance(info, PromptUpdateDetails) else
                   PromptUpdateDetails.from_seq(info))

        return _BoundPromptContent(
            full=_BoundPromptSequence.from_seq(tokenizer, details.full),
            is_embed=details.is_embed,
        )

    def resolve(
        self,
        tokenizer: AnyTokenizer,
        item_idx: int,
    ) -> "ResolvedPromptUpdate":
        """
        Given the index of the processed item within
        [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
        output a copy of this object with its lazy attributes resolved.
        """
        mode = self.mode
        resolved_target = self._resolve_target(tokenizer, item_idx)
        resolved_content = self._resolve_content(tokenizer, item_idx)

        return ResolvedPromptUpdate(
            modality=self.modality,
            item_idx=item_idx,
            mode=mode,
            target=resolved_target,
            content=resolved_content,
        )


@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

    For each image, insert a number of `<image>` feature placeholders
    equal to the feature size of the vision encoder after the `<s>` token:

    ```python
    PromptInsertion(
        modality="image",
        target="<s>",
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens at the start of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.start(),
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens after a prefix `Images:`:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.prefix("Images:"),
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens at the end of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.end(),
        insertion="<image>" * image_feature_size,
    )
    ```
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to insert right after
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def mode(self) -> UpdateMode:
        return UpdateMode.INSERT

    @property
    def content(self) -> PromptUpdateContent:
        return self.insertion


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

    For each image, replace one `<image>` input placeholder in the prompt
    with a number of `<image>` feature placeholders
    equal to the feature size of the vision encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement="<image>" * image_feature_size,
    )
    ```

    As above, but further pad the feature placeholders with `<image_bos>`
    and `<image_eos>`, which are not supposed to be passed to the vision
    encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement=PromptUpdateDetails.select_text(
            "".join([
                "<image_bos>",
                "<image>" * image_feature_size,
                "<image_eos>",
            ]),
            embed_text="<image>",
        ),
    )
    ```

    To avoid unnecessary tokenization during prompt replacement,
    we recommend passing token sequences instead of text:

    ```python
    PromptReplacement(
        modality="image",
        target=[image_token_id],
        replacement=PromptUpdateDetails.select_token_id(
            [image_bos_id] + [image_token_id] * image_feature_size
            + [image_eos_id],
            embed_token_id=image_token_id,
        ),
    )
    ```
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to replace
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        return UpdateMode.REPLACE


@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized wrapper around `encode_tokens`.

    Up to 2048 argument combinations are cached. Repeated calls return the
    same list object, so callers must not mutate the result.
    """
    token_ids = encode_tokens(tokenizer, text,
                              add_special_tokens=add_special_tokens)
    return token_ids


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized wrapper around `decode_tokens`.

    The token IDs are taken as a tuple so that they can serve as part of the
    cache key; they are converted back to a list for `decode_tokens`.
    """
    id_list = list(token_ids)
    text = decode_tokens(tokenizer, id_list,
                         skip_special_tokens=skip_special_tokens)
    return text


class _HasModalityAttr(Protocol):
    """Structural type for objects with a plain `modality` attribute."""

    modality: str


class _HasModalityProp(Protocol):
    """Structural type for objects exposing `modality` as a property."""

    @property
    def modality(self) -> str:
        ...


# Any object exposing a string `modality`, either as a plain attribute or as
# a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group `values` by their `modality`.

    Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
    with each value's `modality` as the key.
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A [`PromptSeq`][vllm.multimodal.processing.PromptSeq] bound to a
    tokenizer, so that its text and token ID representations can be derived
    from one another on demand (and are then cached on the instance).
    """

    tokenizer: AnyTokenizer = field(repr=False)

    # At least one of these is set at construction time; the other one is
    # filled in lazily by the `text` / `token_ids` properties.
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Bind `seq` (text or a list of token IDs) to `tokenizer`."""
        as_text = seq if isinstance(seq, str) else None
        as_ids = seq if isinstance(seq, list) else None

        return _BoundPromptSequence(tokenizer=tokenizer,
                                    _text=as_text,
                                    _token_ids=as_ids)

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """Token IDs, encoded without special tokens on first access."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer,
                                         self._text,
                                         add_special_tokens=False)
        return self._token_ids


@dataclass
class _BoundPromptContent:
    """The content of a prompt update, bound to a tokenizer."""

    # The full content to place into the prompt.
    full: _BoundPromptSequence
    # Optional callback returning a boolean mask over `full` that selects the
    # positions to assign embeddings to; `None` selects every position.
    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]


class PromptTargetMatch(NamedTuple):
    """
    A half-open span `[start_idx, end_idx)` of a prompt matching a target.

    `start_idx == end_idx` for targets that resolve to a single position.
    """

    start_idx: int
    end_idx: int


@dataclass(frozen=True)
class ResolvedPromptUpdate:
    """
    A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its
    lazy attributes resolved, apart from those related to tokenization.
    """

    modality: str
    """The modality for which the update is made."""

    item_idx: int
    """The index within `modality` of the item this update pertains to."""

    mode: UpdateMode
    """Defines how to update the prompt."""

    target: Union[_BoundPromptSequence, PromptIndex]
    """The token sequence (or text) to update, or the index to resolve."""

    content: _BoundPromptContent = field(repr=False)
    """The placeholder tokens that are part of the update."""

    def iter_token_matches(
        self,
        prompt: list[int],
        tokenizer: AnyTokenizer,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """
        Yield each instance of `self.target` found in `prompt`.

        Empty sequence targets never match; a `PromptIndex` target yields at
        most one (empty) match at the index it resolves to.
        """
        target = self.target

        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(tokenizer, prompt, start_idx)
            if match_idx is not None:
                yield PromptTargetMatch(match_idx, match_idx)

            return

        for match in iter_token_matches(prompt,
                                        target.token_ids,
                                        start_idx=start_idx):
            yield PromptTargetMatch(match.start_idx, match.end_idx)

    def iter_text_matches(
        self,
        prompt: str,
        tokenizer: AnyTokenizer,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """
        Yield each instance of `self.target` found in `prompt`.

        Empty sequence targets never match; a `PromptIndex` target yields at
        most one (empty) match at the index it resolves to.
        """
        target = self.target

        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(tokenizer, prompt, start_idx)
            if match_idx is not None:
                yield PromptTargetMatch(match_idx, match_idx)

            return

        target_text = target.text
        if not target_text:
            # An empty pattern would match at every position. Ignore it, to
            # be consistent with `iter_token_matches`, which ignores empty
            # token targets.
            return

        for match in re.finditer(re.escape(target_text), prompt,
                                 pos=start_idx):
            yield PromptTargetMatch(match.start(), match.end())

    def iter_matches(
        self,
        prompt: Union[list[int], str],
        tokenizer: AnyTokenizer,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield each instance of `self.target` found in `prompt`."""
        if isinstance(prompt, str):
            return self.iter_text_matches(prompt,
                                          tokenizer,
                                          start_idx=start_idx)

        return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)


class _TokenMatch(NamedTuple):
    """Half-open span `[start_idx, end_idx)` of a token-ID match."""

    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    *,
    start_idx: int = 0,
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of `match_ids` in `token_ids`, scanning from
    `start_idx` onwards.

    Matches never overlap: once a match is found, scanning resumes right
    after it. Note that empty matches are ignored.
    """
    window = len(match_ids)
    if window == 0:
        return

    last_start = len(token_ids) - window
    pos = start_idx
    while pos <= last_start:
        stop = pos + window
        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)
        pos = stop


def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Replace each occurrence of `match_ids` in `token_ids`
    with `new_ids`.

    Occurrences are found left to right without overlap.
    Note that empty matches are ignored.
    """
    pieces = list[list[int]]()
    cursor = 0

    for found in iter_token_matches(token_ids, match_ids):
        pieces.append(token_ids[cursor:found.start_idx])
        pieces.append(new_ids)
        cursor = found.end_idx

    pieces.append(token_ids[cursor:])
    return flatten_2d_lists(pieces)


@dataclass
class PlaceholderFeaturesInfo:
    """Where the placeholder tokens of one multimodal item are in a prompt."""

    modality: str
    item_idx: int
    start_idx: int
    tokens: list[int]
    is_embed: Optional[torch.Tensor]

    @property
    def length(self) -> int:
        """The number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a `PlaceholderRange` spanning all of `tokens`."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
            is_embed=self.is_embed,
        )


# `((modality, item_idx), (match, update_idx))`: the match found for one item,
# together with the index of that item's update that produced it.
_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]


def _find_matches(
    prompt: _S,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
    *,
    prev_end_idx: int = 0,
    current_result: "MultiModalPromptUpdatesApplyResult",
) -> tuple[Optional[UpdateMode], list[_MatchToApply]]:
    """
    Find the next batch of updates to apply to `prompt`.

    For each item that `current_result` does not yet record as updated,
    take the first match at or after `prev_end_idx` produced by any of the
    item's updates (tried in order). All returned matches share the mode of
    the first match found; matches of other modes are skipped.

    Returns that mode (`None` if nothing matched) together with the matches
    sorted by position. In replace mode, at most one non-empty match is
    kept so that replacements cannot conflict.
    """
    batch_mode: Optional[UpdateMode] = None
    first_match = dict[tuple[str, int], tuple[PromptTargetMatch, int]]()

    for modality, per_item_updates in mm_prompt_updates.items():
        for item_idx, item_updates in enumerate(per_item_updates):
            if current_result[modality][item_idx] is not None:
                continue  # This item has already been updated

            key = (modality, item_idx)
            for update_idx, update in enumerate(item_updates):
                if key in first_match:
                    break  # Only the first valid match per item is used

                candidates = update.iter_matches(prompt,
                                                 tokenizer,
                                                 start_idx=prev_end_idx)
                for candidate in candidates:
                    if batch_mode is None:
                        batch_mode = update.mode
                    if batch_mode == update.mode:
                        first_match[key] = (candidate, update_idx)
                        break

    # Earlier matches (by start, then end index) come first
    ordered = sorted(first_match.items(), key=lambda entry: entry[1][0])

    if batch_mode != UpdateMode.REPLACE:
        return batch_mode, ordered

    # To avoid conflicts, only replace one non-empty item at a time
    kept = list[_MatchToApply]()
    seen_non_empty = False
    for entry in ordered:
        span = entry[1][0]
        is_empty = span.start_idx == span.end_idx
        if is_empty or not seen_non_empty:
            kept.append(entry)
            seen_non_empty = seen_non_empty or not is_empty

    return batch_mode, kept


def _apply_matches(
    prompt: _S,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to `prompt`.

    Returns the pieces of the updated prompt in order (the caller joins
    them), and for each modality a list holding, per item, the index of the
    update that was applied to it (`None` if no update was applied).
    """
    prompt_len = len(prompt)
    is_text = isinstance(prompt, str)

    out_seqs = list[Union[str, list[int]]]()
    out_result: MultiModalPromptUpdatesApplyResult = {
        m: [None] * len(items)
        for m, items in mm_prompt_updates.items()
    }

    start_idx = prev_end_idx = 0
    while start_idx < max(prompt_len, 1):  # Allow inserts into empty prompt
        mode, matches_to_apply = _find_matches(
            prompt,
            mm_prompt_updates,
            tokenizer,
            prev_end_idx=prev_end_idx,
            current_result=out_result,
        )

        if mode is None:
            break  # No remaining item matches anywhere after `prev_end_idx`

        applied_any = False
        for (modality, item_idx), (match, update_idx) in matches_to_apply:
            if mode == UpdateMode.INSERT:
                end_idx_to_insert = match.end_idx
            elif mode == UpdateMode.REPLACE:
                end_idx_to_insert = match.start_idx
            else:
                assert_never(mode)

            if end_idx_to_insert < prev_end_idx:
                # The splice point lies inside a region that has already
                # been emitted (overlapping/nested targets). Applying it
                # would move `prev_end_idx` backwards and emit that region
                # twice. Leave the item unapplied instead.
                continue

            matched_update = mm_prompt_updates[modality][item_idx][update_idx]
            matched_content = matched_update.content

            out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
            out_seqs.append(matched_content.full.text if is_text else
                            matched_content.full.token_ids)
            out_result[modality][item_idx] = update_idx
            applied_any = True

            # Exclude overlapping matches
            start_idx = prev_end_idx = match.end_idx

        if not applied_any:
            # Apart from the fixed inputs, `_find_matches` depends only on
            # `prev_end_idx` and `out_result`, neither of which changed.
            # Every later iteration would find the same thing, so stop
            # instead of rescanning the prompt once per remaining position.
            break

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs), out_result


def apply_token_matches(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to `prompt`.

    Returns the updated token IDs together with, per modality and item,
    the index of the update that was applied.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.
    """
    pieces, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    updated = flatten_2d_lists(pieces)
    return updated, result


def apply_text_matches(
    prompt: str,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to `prompt`.

    Returns the updated text together with, per modality and item,
    the index of the update that was applied.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.
    """
    pieces, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    updated = "".join(pieces)
    return updated, result


def _iter_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in `prompt`.

    Items of each modality are located in order: at every position, only
    the next not-yet-found item of each modality is tried.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    n_tokens = len(prompt)
    n_items = {m: len(items) for m, items in mm_prompt_updates.items()}
    next_item = defaultdict[str, int](lambda: 0)

    pos = 0
    while pos < n_tokens:
        matched = False

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = next_item[modality]
            if item_idx >= n_items.get(modality, 0):
                continue  # Every item of this modality has been found

            for update in modality_updates[item_idx]:
                content = update.content
                tokens = content.full.token_ids
                stop = pos + len(tokens)

                if not tokens or stop > n_tokens:
                    continue
                if prompt[pos:stop] != tokens:
                    continue

                is_embed_fn = content.is_embed
                yield PlaceholderFeaturesInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=pos,
                    tokens=tokens,
                    is_embed=(None if is_embed_fn is None else
                              is_embed_fn(content.full)),
                )

                # Exclude overlapping matches
                pos = stop
                next_item[modality] += 1
                matched = True
                break

            if matched:
                break  # Restart the scan from the new position

        if not matched:
            pos += 1


def find_mm_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """Locate the placeholder tokens in `prompt`, grouped by modality."""
    placeholders = _iter_placeholders(prompt, mm_prompt_updates)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


class ProcessingCache(MultiModalCache):
    """LRU cache of `MultiModalKwargsItem`s whose capacity is `capacity_gb`."""

    def __init__(self, capacity_gb: float) -> None:
        super().__init__()

        cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
        self._cache = cache

        # Expose the underlying cache's operations directly on this object.
        self.get = cache.get
        self.put = cache.put
        self.reset = cache.clear


# Either a processed item or a string (presumably the item's hash, going by
# the name -- TODO confirm at the call sites).
_CacheItemOrHash = Union[MultiModalKwargsItem, str]


class BaseProcessingInfo:
928
    """Base class to provide the information necessary for data processing."""
929

930
931
    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()
932

933
934
935
936
937
938
939
        self.ctx = ctx

    @property
    def model_id(self) -> str:
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
940
941
        return self.ctx.tokenizer

942
    def get_hf_config(self) -> "PretrainedConfig":
943
944
        return self.ctx.get_hf_config()

945
    def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin":
946
947
948
949
950
951
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

952
953
954
955
956
957
958
959
960
961
962
963
    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

964
965
966
967
968
969
970
971
972
973
974
975
976
977
    def get_allowed_mm_limits(self) -> Mapping[str, int]:
        """Return the maximum allowed number of items for each modality."""
        supported_mm_limits = self.get_supported_mm_limits()
        mm_config = self.ctx.get_mm_config()

        allowed_limits = dict[str, int]()
        for modality, supported_limit in supported_mm_limits.items():
            user_limit = mm_config.get_limit_per_prompt(modality)

            allowed_limits[modality] = (user_limit if supported_limit is None
                                        else min(user_limit, supported_limit))

        return allowed_limits

978
979
980
981
982
983
984
985
986
987
988
989
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Optional[Mapping[str, int]]:
        """
        Return the maximum number of tokens per item of for each modality.
        
        When `None` (the default) is returned, vLLM will generate dummy inputs
        (images/videos) at maximum possible sizes and process them to determine
        the maximum token count per modality.

990
991
992
993
994
        This approach works but can be very slow for certain models (e.g.,
        Qwen2.5-VL), leading to very long startup time. For better performance,
        each model can override this method to return pre-computed maximum token
        counts, avoiding the need for dummy input generation and processing.

995
996
997
998
999
1000
        Note:
            The maximum number of tokens per item of each modality returned 
            from this function should respect the model's maximum sequence
            length and the maximum number of items of each modality allowed,
            and agree with dummy inputs (images/videos) at maximum possible
            sizes.
1001
1002
1003
        """
        return None

_I = TypeVar("_I", bound=BaseProcessingInfo)

MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]:
for each modality, one hash string per multi-modal item.
"""

MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]]
"""
A collection of prompt updates with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]:
for each modality, one sequence of candidate `ResolvedPromptUpdate`
instances per multi-modal item.
"""

MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]]
"""
For an item `MultiModalPromptUpdates[k][i]`,
`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the
`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the
`ResolvedPromptUpdate` instances have been applied.
"""

class MultiModalProcessingInfo(NamedTuple):
    """
    Multi-modal outputs gathered while processing a single prompt.

    Each field is organized per modality.
    """

    # Processed keyword arguments, one entry per multi-modal item.
    kwargs: MultiModalKwargsItems
    # Hash string of each multi-modal item.
    hashes: MultiModalHashes
    # Resolved candidate prompt updates of each multi-modal item.
    prompt_updates: MultiModalPromptUpdates

class BaseMultiModalProcessor(ABC, Generic[_I]):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with `transformers.ProcessorMixin`.
    """

    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None) -> None:
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache

        self.data_parser = self._get_data_parser()

        # Avoid unnecessary recomputation
        self._supported_mm_limits = self.info.get_supported_mm_limits()
        self._allowed_mm_limits = self.info.get_allowed_mm_limits()

    @property
    def supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Per-modality limits supported by the model (computed at init)."""
        return self._supported_mm_limits

    @property
    def allowed_mm_limits(self) -> Mapping[str, int]:
        """Per-modality limits allowed for a prompt (computed at init)."""
        return self._allowed_mm_limits

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Equivalent to `apply` with default `tokenization_kwargs`."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        You can support additional modalities by creating a subclass
        of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
        that has additional subparsers.
        """
        return MultiModalDataParser()

    def validate_num_items(
        self,
        modality: str,
        num_items: int,
    ) -> None:
        """
        Raise `ValueError` if `num_items` exceeds the item limit of
        `modality`.

        Modalities missing from the limit mappings are treated as having a
        limit of 0. The error message suggests `--limit-mm-per-prompt`
        whenever the model itself could accept `num_items` items, i.e. when
        the model's supported limit is `None` (unlimited) or at least
        `num_items`.
        """
        supported_limit = self.supported_mm_limits.get(modality, 0)
        allowed_limit = self.allowed_mm_limits.get(modality, 0)

        # `None` means the model imposes no limit of its own, so only the
        # configured (allowed) limit applies.
        if supported_limit is None:
            limit = allowed_limit
        else:
            limit = min(supported_limit, allowed_limit)

        if num_items > limit:
            msg = (f"At most {limit} {modality}(s) may be provided in "
                   "one prompt.")

            # Raising the configured limit only helps when the model can
            # accept this many items (always true if it is unlimited).
            if supported_limit is None or num_items <= supported_limit:
                msg += " Set `--limit-mm-per-prompt` to increase this limit."

            raise ValueError(msg)

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize
        [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
        to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        The number of items of each modality is validated with
        `validate_num_items`, which raises `ValueError` on violation.
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        for modality, items in mm_items.items():
            self.validate_num_items(modality, len(items))

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: "BatchFeature",
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Given the HF-processed data, output the metadata of each field."""
        raise NotImplementedError

    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct
        [`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
        for each multi-modal item.
        """
        raise NotImplementedError

    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
        mm_item_counts: Mapping[str, int],
    ) -> MultiModalPromptUpdates:
        """
        Resolve every unbound prompt update once per item of its modality.

        For each modality that has prompt updates, the result holds one list
        per item index (`0 .. mm_item_counts[modality] - 1`). Each list
        contains that modality's updates resolved for the item. Modalities
        without an entry in `mm_item_counts` get an empty list.
        """
        tokenizer = self.info.get_tokenizer()

        return {
            modality:
            [[update.resolve(tokenizer, item_idx) for update in updates]
             for item_idx in range(mm_item_counts.get(modality, 0))]
            for modality, updates in full_groupby_modality(prompt_updates)
        }

    def _get_mm_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> MultiModalPromptUpdates:
        """
        Get the prompt updates from `_get_prompt_updates` and resolve them
        per multi-modal item.

        Logs a one-time deprecation warning for any item that ends up with
        more than one candidate prompt update.
        """
        unbound_prompt_updates = self._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates,
            mm_items.get_all_counts(),
        )

        for modality, prompt_updates in mm_prompt_updates.items():
            for item_idx, item_prompt_updates in enumerate(prompt_updates):
                if len(item_prompt_updates) > 1:
                    # Report the number of updates for *this item*, not the
                    # number of items of the modality.
                    logger.warning_once(
                        "Detected %d prompt updates for `mm_items[%r][%s]`. "
                        "Multiple prompt updates per item is now "
                        "deprecated and may be removed in v0.13. "
                        "Instead, please specify dynamic update targets "
                        "in the same prompt update definition by passing "
                        "a function to `PromptUpdate.target`.",
                        len(item_prompt_updates),
                        modality,
                        item_idx,
                    )

        return mm_prompt_updates

    def _find_mm_placeholders(
        self,
        new_token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """Delegate to the module-level `find_mm_placeholders`."""
        return find_mm_placeholders(new_token_ids, mm_prompt_updates)

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the parsed items into `(processor_data, passthrough_data)`.

        `processor_data` is passed to the HF processor. `passthrough_data`
        bypasses it and is merged into the processed output afterwards. Keys
        from later modalities overwrite equal keys from earlier ones.
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> "BatchFeature":
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            dict(**mm_kwargs, **tok_kwargs),
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Return whether the HF processor applies prompt updates.

        For most HF processors, this should be `True` when multi-modal
        data items are passed, but `False` when multi-modal embeddings
        are passed.
        """
        return not any(
            isinstance(items, (EmbeddingItems, DictEmbeddingItems))
            for items in mm_items.values())

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], "BatchFeature", bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        In addition, return whether prompt updates have been applied.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        processed_data.update(passthrough_data)

        # The trailing comma unpacks the single sequence of the batch.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        is_update_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, processed_data, is_update_applied

    def _apply_hf_processor_text_only(
        self,
        prompt_text: str,
        tokenization_kwargs: Mapping[str, object],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The base implementation calls the HF processor with an empty
        `MultiModalDataItems` and empty `hf_processor_mm_kwargs`, and
        returns only the resulting token IDs.
        """
        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of
        [`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
        on the
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> "BatchFeature":
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        [`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
        to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return mm_processed_data

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], "BatchFeature", bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        In addition, return whether prompt updates have been applied
        (for most HF processors, this should be `True`).

        Note:
            If `enable_hf_prompt_update=True` and `prompt` is text, the HF
            processor is called on the prompt and the multi-modal items
            together, so it may perform the prompt updates itself. This
            requires that the prompt corresponds to the multi-modal items.
            Otherwise, the prompt and the multi-modal items are processed
            separately and the returned flag is always `False`.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                    tokenization_kwargs=tokenization_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(
                prompt, tokenization_kwargs)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, mm_processed_data, False

    def _get_cache_missing_items(
        self,
        cache: ProcessingCache,
        mm_data_items: MultiModalDataItems,
        mm_hashes: MultiModalHashes,
    ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
        """
        Look up each item hash in `cache`.

        Returns:
            - per modality, a list holding the cached item on a hit or the
              hash string on a miss, in the original item order;
            - the items that missed the cache, re-parsed through
              `_to_mm_items`.
        """
        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
            modality: [(h if (v := cache.get(h)) is None else v)
                       for h in hashes]
            for modality, hashes in mm_hashes.items()
        }

        mm_missing_idxs = {
            modality: [
                idx for idx, item_or_hash in enumerate(items_or_hashes)
                if isinstance(item_or_hash, str)
            ]
            for modality, items_or_hashes in mm_cache_items_or_hashes.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }

        return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)

    def _hash_mm_items(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> MultiModalHashes:
        """Create MM hashes to be returned (only used in V1)."""
        model_id = self.info.model_id

        return {
            modality: [
                MultiModalHasher.hash_kwargs(model_id=model_id,
                                             **{modality: item},
                                             **hf_processor_mm_kwargs,
                                             **tokenization_kwargs)
                for item in items
            ]
            for modality, items in mm_items.items()
        }

    def _merge_mm_kwargs(
        self,
        cache: ProcessingCache,
        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
        mm_missing_kwargs: MultiModalKwargsItems,
    ) -> MultiModalKwargsItems:
        """
        Combine cached items with newly processed ones, in the original
        item order.

        Each hash placeholder in `mm_cache_items_or_hashes` is filled with
        the next unused item of `mm_missing_kwargs` for that modality. That
        item is also stored in `cache` under the hash.
        """
        mm_missing_next_idx = defaultdict[str, int](lambda: 0)

        merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
        for modality, items_or_hashes in mm_cache_items_or_hashes.items():
            for item_or_hash in items_or_hashes:
                if isinstance(item_or_hash, str):
                    kw_item = mm_missing_kwargs[modality][
                        mm_missing_next_idx[modality]]
                    cache.put(item_or_hash, kw_item)
                    mm_missing_next_idx[modality] += 1
                else:
                    kw_item = item_or_hash

                merged_items[modality].append(kw_item)

        return MultiModalKwargsItems(merged_items)

    def _apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Apply the HF processor on the full prompt without using the cache.

        Returns the prompt token IDs, the processed multi-modal info, and
        whether prompt updates have already been applied.
        """
        (
            prompt_ids,
            mm_processed_data,
            is_update_applied,
        ) = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            enable_hf_prompt_update=True,
        )

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            self._get_mm_fields_config(mm_processed_data,
                                       hf_processor_mm_kwargs),
        )

        mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                        tokenization_kwargs)

        mm_prompt_updates = self._get_mm_prompt_updates(
            mm_data_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )

        mm_info = MultiModalProcessingInfo(
            kwargs=mm_kwargs,
            hashes=mm_hashes,
            prompt_updates=mm_prompt_updates,
        )

        return prompt_ids, mm_info, is_update_applied

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        The cache is bypassed (falling back to `_apply_hf_processor`) when
        no cache is configured or when any item carries passthrough data.
        """
        cache = self.cache

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor(
                prompt=prompt,
                mm_data_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
            )

        mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                        tokenization_kwargs)
        (
            mm_cache_items_or_hashes,
            mm_missing_data_items,
        ) = self._get_cache_missing_items(
            cache=cache,
            mm_data_items=mm_data_items,
            mm_hashes=mm_hashes,
        )

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we can't apply prompt updates until the new multimodal
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            mm_missing_processed_data,
            is_update_applied,
        ) = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            enable_hf_prompt_update=False,
        )

        mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_missing_processed_data,
            self._get_mm_fields_config(mm_missing_processed_data,
                                       hf_processor_mm_kwargs),
        )

        mm_kwargs = self._merge_mm_kwargs(
            cache,
            mm_cache_items_or_hashes=mm_cache_items_or_hashes,
            mm_missing_kwargs=mm_missing_kwargs,
        )

        mm_prompt_updates = self._get_mm_prompt_updates(
            mm_data_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )

        mm_info = MultiModalProcessingInfo(
            kwargs=mm_kwargs,
            hashes=mm_hashes,
            prompt_updates=mm_prompt_updates,
        )

        return prompt_ids, mm_info, is_update_applied

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
        """Apply the prompt updates at the token level."""
        tokenizer = self.info.get_tokenizer()
        return apply_token_matches(prompt, mm_prompt_updates, tokenizer)

    def _apply_text_matches(
        self,
        prompt: str,
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[str, MultiModalPromptUpdatesApplyResult]:
        """Apply the prompt updates at the text level."""
        tokenizer = self.info.get_tokenizer()
        return apply_text_matches(prompt, mm_prompt_updates, tokenizer)

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply `mm_prompt_updates` to `token_ids` and locate the resulting
        placeholders.

        Token-level matching is tried first. If any item could not be
        updated that way, the updates are redone on the decoded text, which
        is then re-encoded without special tokens.

        Returns the updated token IDs, their text, and the placeholders
        grouped by modality.
        """
        tokenizer = self.info.get_tokenizer()

        new_token_ids, match_result = self._apply_token_matches(
            token_ids,
            mm_prompt_updates,
        )

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        if all(
                all(update_idx is not None for update_idx in update_idxs)
                for update_idxs in match_result.values()):
            new_text = decode_tokens(tokenizer, new_token_ids)
        else:
            new_text, match_result = self._apply_text_matches(
                decode_tokens(tokenizer, token_ids),
                mm_prompt_updates,
            )

            new_token_ids = encode_tokens(
                tokenizer,
                new_text,
                add_special_tokens=False,
            )

        # Narrow each item's candidate updates down to the one that was
        # applied, and search for placeholders using only those.
        matched_updates = defaultdict[
            str, list[Sequence[ResolvedPromptUpdate]]](list)
        for modality, update_idxs in match_result.items():
            for item_idx, update_idx in enumerate(update_idxs):
                assert update_idx is not None, (
                    "Failed to apply prompt replacement for "
                    f"mm_items[{modality!r}][{item_idx}]")

                matched_updates[modality].append(
                    [mm_prompt_updates[modality][item_idx][update_idx]])

        placeholders = self._find_mm_placeholders(
            new_token_ids,
            dict(matched_updates),
        )

        return new_token_ids, new_text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargsItems,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Raise `RuntimeError` if, for any modality, the number of processed
        kwargs items differs from the number of input items.
        """
        for modality, item_count in mm_item_counts.items():
            items = mm_kwargs.get(modality, [])

            if len(items) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but only found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Raise `RuntimeError` if, for any modality, the number of located
        placeholders differs from the number of input items.
        """
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count:
                # NOTE: If you are a model developer, this can also arise from
                # an inconsistency between `_call_hf_processor` and
                # `_get_mm_fields_config` implementations
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt updates "
                    f"corresponding to {item_count} {modality} items, but "
                    f"instead found {len(placeholders)} prompt updates! "
                    "This is likely because you forgot to include input "
                    "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
                    "in the prompt. If the model has a chat template, make "
                    "sure you have applied it before calling `LLM.generate`.")

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Make sure the prompt updates are reflected in the prompt, then
        locate and validate the placeholders.

        If the HF processor already applied the updates
        (`is_update_applied`), the placeholders are only located. Otherwise
        the updates are applied via `_apply_prompt_updates`.

        Returns the token IDs, the decoded prompt, and the placeholders per
        modality. Raises `RuntimeError` if the item counts are inconsistent.
        """
        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if is_update_applied:
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
                mm_prompt_updates,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt,
                mm_placeholders,
            ) = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        return prompt_ids, prompt, mm_placeholders

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        mm_items = self._to_mm_items(mm_data)

        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        (
            prompt_ids,
            mm_info,
            is_update_applied,
        ) = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        # NOTE: tokenization_kwargs are not required to init processor
        prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
            mm_items=mm_items,
            prompt_ids=prompt_ids,
            mm_kwargs=mm_info.kwargs,
            mm_prompt_updates=mm_info.prompt_updates,
            is_update_applied=is_update_applied,
        )

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_info.kwargs,
            mm_hashes=mm_info.hashes,
            mm_placeholders=mm_placeholder_ranges,
        )


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multi-modal processor for encoder-decoder models.

    The HF processor is applied to an *encoder* prompt derived from the user
    prompt (see :meth:`create_encoder_prompt`). The *decoder* prompt is built
    separately by :meth:`create_decoder_prompt` and tokenized without special
    tokens. Both sides are then merged into a single
    :class:`MultiModalEncDecInputs`.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """Whether the dummy encoder prompt should be padded.

        Defaults to ``False``; subclasses override this to opt in.
        NOTE(review): presumably consulted by the dummy-input/profiling code
        elsewhere in the package -- confirm the exact padding semantics there.
        """
        return False

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Create input prompt for the decoder.

        The default implementation returns ``prompt`` unchanged; subclasses
        may override it to build a model-specific decoder prompt.
        """
        return prompt

    def _get_enc_dec_inputs(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        encoder_inputs: MultiModalInputs,
    ) -> MultiModalEncDecInputs:
        """Combine processed encoder inputs with a tokenized decoder prompt.

        Args:
            prompt: The original prompt, as text or token IDs.
            mm_data: The raw multi-modal data; forwarded to
                :meth:`create_decoder_prompt`.
            encoder_inputs: The result of the base class's ``apply`` on the
                encoder prompt.

        Returns:
            A :class:`MultiModalEncDecInputs` whose ``encoder_prompt`` /
            ``encoder_prompt_token_ids`` hold the encoder side and whose
            ``prompt`` / ``prompt_token_ids`` hold the decoder side. Every
            other field of ``encoder_inputs`` (e.g. ``mm_kwargs``,
            ``mm_hashes``, ``mm_placeholders``) is copied over unchanged.
        """
        tokenizer = self.info.get_tokenizer()

        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        # Materialize both the text and token-ID forms of the decoder prompt,
        # whichever one the (possibly overridden) hook returned.
        if isinstance(decoder_prompt, str):
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               decoder_prompt,
                                               add_special_tokens=False)
        else:
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        # The "prompt"/"prompt_token_ids" copied from encoder_inputs still
        # describe the encoder; overwrite them with the decoder prompt.
        mm_inputs.update({
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids
        })
        return mm_inputs

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are modified to fit encoder-decoder models:

        1. Create the encoder prompt from the input prompt via
           :meth:`create_encoder_prompt`.
        2. Apply the HF processor on the encoder prompt (through the base
           class's ``apply``).
        3. Build the decoder prompt via :meth:`create_decoder_prompt` (the
           input prompt unchanged by default) and merge it with the encoder
           outputs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
        )

        return self._get_enc_dec_inputs(
            prompt=prompt,
            mm_data=mm_data,
            encoder_inputs=encoder_inputs,
        )