processing.py 59.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import sys
5
from abc import ABC, abstractmethod
6
from collections import defaultdict
7
8
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
9
from dataclasses import dataclass, field
10
from enum import Enum
11
from functools import lru_cache
12
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
13
                    TypeVar, Union, cast)
14

15
import regex as re
16
import torch
17
from typing_extensions import assert_never
18

19
from vllm.inputs import InputProcessingContext
20
from vllm.jsontree import json_map_leaves, json_reduce_leaves
21
from vllm.logger import init_logger
22
23
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
24
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
25

26
from .hasher import MultiModalHasher
27
28
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
29
                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
30
31
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
32
33

if TYPE_CHECKING:
34
35
36
37
    from transformers.configuration_utils import PretrainedConfig
    from transformers.feature_extraction_utils import BatchFeature
    from transformers.processing_utils import ProcessorMixin

38
    from .profiling import BaseDummyInputsBuilder
39

40
# Logger shared by everything in this module.
logger = init_logger(__name__)

# Type variable ranging over the two prompt representations used here:
# plain text (`str`) or token IDs (`list[int]`).
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
46

47

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@dataclass
class PromptIndex:
    """
    Resolves to an index in the prompt.

    `get_match_index` receives the tokenizer and the prompt (text or token
    IDs) and returns the resolved index, or `None` when the target cannot
    be located in that prompt.
    """

    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]


class PromptIndexTargets:
    """
    Factory methods that build commonly used
    [`PromptIndex`][vllm.multimodal.processing.PromptIndex] targets.
    """

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """

        def _at_start(tokenizer, prompt) -> int:
            return 0

        return PromptIndex(_at_start)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        `seq` is first converted to the same representation as the prompt
        (text or token IDs); no match is reported unless the prompt begins
        with it.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            if isinstance(prompt, str):
                # Compare as text.
                if isinstance(seq, str):
                    needle = seq
                else:
                    needle = decode_tokens(tokenizer, seq)
            elif isinstance(seq, str):
                # Compare as token IDs, without the tokenizer's special tokens.
                needle = encode_tokens(tokenizer,
                                       seq,
                                       add_special_tokens=False)
            else:
                needle = seq

            boundary = len(needle)
            if prompt[:boundary] != needle:
                return None
            return boundary

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """

        def _at_end(tokenizer, prompt) -> int:
            return len(prompt)

        return PromptIndex(_at_end)


PromptTarget = Union[PromptSeq, PromptIndex]
"""
The token sequence or text to update, or a
[`PromptIndex`][vllm.multimodal.processing.PromptIndex] that resolves to a
position in the prompt (see
[`PromptIndexTargets`][vllm.multimodal.processing.PromptIndexTargets]).
"""


109
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
    """
    Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
    return a boolean mask of shape `(len(full),)` indicating which positions
    of `full` to assign embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    [`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap `seq` so that embeddings are assigned to all of it."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq` so that only the tokens of `embed_text` get embeddings.

        A position of `full` is marked when its token ID occurs anywhere in
        the tokenization of `embed_text`.
        """

        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
            # Tokenize without special tokens. Otherwise any BOS/EOS that the
            # tokenizer adds by default would join the embed set and wrongly
            # mark those tokens inside `full`. This matches how
            # `_BoundPromptSequence.token_ids` encodes text.
            embed_token_ids = encode_tokens(full.tokenizer,
                                            embed_text,
                                            add_special_tokens=False)

            # Pin the dtype so empty lists do not become float tensors.
            return torch.isin(
                torch.tensor(full.token_ids, dtype=torch.long),
                torch.tensor(embed_token_ids, dtype=torch.long),
            )

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq` so that only positions equal to `embed_token_id` get
        embeddings.
        """
        return PromptUpdateDetails(
            full=seq,
            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
        )
157
158


159
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The token sequence or text that is part of the update.

If only part of the content corresponds to feature placeholders, you can
use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to
specify which part.
"""

PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
                            PromptUpdateInfo]
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""


class UpdateMode(str, Enum):
    """How a prompt update changes the prompt around its matched target."""

    # Keep the target and place the content right after it.
    INSERT = "insert"
    # Drop the target and put the content in its place.
    REPLACE = "replace"


@dataclass
class PromptUpdate(ABC):
    """
    Base description of how to update a prompt with placeholder tokens.

    Subclasses supply the placeholder `content` and the update `mode`.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
        """Attach `tokenizer` so target and content can switch representation."""
        bound = BoundPromptUpdate(_origin=self, tokenizer=tokenizer)
        return bound

215

216
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Defines how to insert placeholder tokens into a prompt.

    Example:

    For each image, insert a number of ``<image>`` feature placeholders
    equal to the feature size of the vision encoder after the ``<s>`` token:

    ```python
    PromptInsertion(
        modality="image",
        target="<s>",
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens at the start of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.start(),
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens after a prefix ``Images:``:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.prefix("Images:"),
        insertion="<image>" * image_feature_size,
    )
    ```

    Insert these tokens at the end of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.end(),
        insertion="<image>" * image_feature_size,
    )
    ```
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to insert right after
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content to insert, i.e. `insertion`."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

    For each image, replace one ``<image>`` input placeholder in the prompt
    with a number of ``<image>`` feature placeholders
    equal to the feature size of the vision encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement="<image>" * image_feature_size,
    )
    ```

    As above, but further pad the feature placeholders with ``<image_bos>``
    and ``<image_eos>``, which are not supposed to be passed to the vision
    encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement=PromptUpdateDetails.select_text(
            "".join([
                "<image_bos>",
                "<image>" * image_feature_size,
                "<image_eos>",
            ]),
            embed_text="<image>",
        ),
    )
    ```

    To avoid unnecessary tokenization during prompt replacement,
    we recommend passing token sequences instead of text:

    ```python
    PromptReplacement(
        modality="image",
        target=[image_token_id],
        replacement=PromptUpdateDetails.select_token_id(
            ([image_bos_id] + [image_token_id] * image_feature_size
             + [image_eos_id]),
            embed_token_id=image_token_id,
        ),
    )
    ```
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to replace
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """The content that replaces the target, i.e. `replacement`."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
357
358


359
360
361
362
363
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Memoized wrapper around `encode_tokens`.

    Every argument (including `tokenizer`) is part of the `lru_cache` key,
    so all of them must be hashable. Cache hits return the very same list
    object, so callers should treat the result as read-only.
    """
    encoded = encode_tokens(tokenizer,
                            text,
                            add_special_tokens=add_special_tokens)
    return encoded
369
370


371
372
373
374
375
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: Optional[bool] = None,
) -> str:
    """
    Memoized wrapper around `decode_tokens`.

    `token_ids` is a tuple because `lru_cache` needs hashable arguments. It
    is turned back into a list before being handed to the decoder.
    """
    id_list = list(token_ids)
    return decode_tokens(tokenizer,
                         id_list,
                         skip_special_tokens=skip_special_tokens)
381
382
383
384
385


class _HasModalityAttr(Protocol):
    modality: str

386

387
class _HasModalityProp(Protocol):
388

389
390
391
392
393
394
395
396
397
    @property
    def modality(self) -> str:
        ...


# Any object exposing `modality: str`, either as an attribute or a property.
_M = TypeVar(
    "_M",
    bound=Union[_HasModalityAttr, _HasModalityProp],
)


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group `values` by their `modality` using
    [`full_groupby`][vllm.utils.full_groupby]."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A [`PromptSeq`][vllm.multimodal.processing.PromptSeq] bound to a
    tokenizer. Whichever of text / token IDs is missing is derived from the
    other on first access and then memoized.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Bind `seq`, storing it in whichever slot matches its type."""
        text: Optional[str] = None
        token_ids: Optional[list[int]] = None

        if isinstance(seq, str):
            text = seq
        if isinstance(seq, list):
            token_ids = seq

        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=text,
            _token_ids=token_ids,
        )

    def __post_init__(self) -> None:
        # At least one representation is needed to derive the other.
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The sequence as text, decoded from `token_ids` if needed."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The sequence as token IDs, encoded from `text` if needed."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        # `add_special_tokens=False`: encode only the text itself.
        self._token_ids = _cached_encode(self.tokenizer,
                                         self._text,
                                         add_special_tokens=False)
        return self._token_ids


450
@dataclass
class _BoundPromptContent:
    """Update content whose `full` sequence is bound to a tokenizer."""

    full: _BoundPromptSequence
    # Optional embedding-mask factory, taken from `PromptUpdateDetails`.
    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
454
455


456
@dataclass
class BoundPromptUpdate:
    """
    A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound
    to a tokenizer to automatically convert
    [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of
    [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content]
    between token sequence and text representations.
    """

    _origin: PromptUpdate
    tokenizer: AnyTokenizer = field(repr=False)

    def __post_init__(self) -> None:
        # Per-item memo; only filled when the content is item-dependent
        # (i.e. a callable).
        self._content_cache = dict[int, _BoundPromptContent]()

    @property
    def modality(self) -> str:
        """Modality of the wrapped update."""
        return self._origin.modality

    @property
    def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
        """The token sequence (or text) to update."""
        raw_target = self._origin.target

        if not isinstance(raw_target, PromptIndex):
            return _BoundPromptSequence.from_seq(self.tokenizer, raw_target)
        return raw_target

    @property
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        return self._origin.content

    @property
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        return self._origin.mode

    def get_content(self, item_idx: int) -> _BoundPromptContent:
        """
        Given the index of the processed item within
        [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
        output the token sequence (or text) to update.
        """
        raw = self.content
        per_item = callable(raw)

        if per_item:
            hit = self._content_cache.get(item_idx)
            if hit is not None:
                return hit
            raw = raw(item_idx)

        details = (raw if isinstance(raw, PromptUpdateDetails) else
                   PromptUpdateDetails.from_seq(raw))

        bound_content = _BoundPromptContent(
            full=_BoundPromptSequence.from_seq(self.tokenizer, details.full),
            is_embed=details.is_embed,
        )

        if per_item:
            self._content_cache[item_idx] = bound_content

        return bound_content
523
524


525
526
527
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
528
529


530
531
532
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each non-overlapping occurrence of `match_ids` in `token_ids`,
    scanning left to right.

    Note that empty matches are ignored.
    """
    width = len(match_ids)
    if width == 0:
        return

    last_start = len(token_ids) - width
    pos = 0
    while pos <= last_start:
        stop = pos + width

        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)
        # Resume after the match so occurrences never overlap.
        pos = stop
556
557


558
559
560
561
562
563
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Replace each occurrence of `match_ids` in `token_ids`
    with `new_ids`.

    Occurrences are found left to right without overlapping (via
    `iter_token_matches`). `token_ids` itself is only sliced, never modified.

    Note that empty matches are ignored.
    """
    out_seqs = list[list[int]]()
    prev_end_idx = 0

    for match in iter_token_matches(token_ids, match_ids):
        # Keep the untouched segment before the match, then the replacement.
        out_seqs.append(token_ids[prev_end_idx:match.start_idx])
        out_seqs.append(new_ids)
        prev_end_idx = match.end_idx

    # Trailing segment after the last match (or the whole list if none).
    out_seqs.append(token_ids[prev_end_idx:])

    return flatten_2d_lists(out_seqs)


585
@dataclass(repr=False)
class PromptTargetMatch(ABC):
    """A located occurrence of a bound update's target within a prompt."""

    _origin: BoundPromptUpdate

    @property
    def modality(self) -> str:
        """Modality of the update that produced this match."""
        return self._origin.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


608
@dataclass(repr=False)
class _PromptTargetIndexMatch(PromptTargetMatch):
    """Zero-width match at a single resolved prompt index."""

    match_idx: int

    @property
    def start_idx(self) -> int:
        return self.match_idx

    @property
    def end_idx(self) -> int:
        # Same as `start_idx`: the match covers no prompt positions.
        return self.match_idx


621
@dataclass(repr=False)
class _PromptTargetTokenMatch(PromptTargetMatch):
    """Match of a target token sequence inside a token-ID prompt."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        span = self.match
        return span.start_idx

    @property
    def end_idx(self) -> int:
        span = self.match
        return span.end_idx


@dataclass(repr=False)
class _PromptTargetTextMatch(PromptTargetMatch):
    """Match of a target string inside a text prompt (character offsets)."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        found = self.match
        return found.start()

    @property
    def end_idx(self) -> int:
        found = self.match
        return found.end()

646

647
@dataclass
class PlaceholderFeaturesInfo:
    """Position and tokens of one item's placeholder run in a prompt."""

    modality: str
    item_idx: int
    start_idx: int
    tokens: list[int]
    is_embed: Optional[torch.Tensor]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a `PlaceholderRange` covering every placeholder token."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        return PlaceholderRange(offset=self.start_idx,
                                length=self.length,
                                is_embed=self.is_embed)
667
668
669
670


def find_token_matches(
    prompt: list[int],
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of `prompt_updates` found in `prompt`, grouped by
    update in the order of `prompt_updates`.
    """
    found: list[PromptTargetMatch] = []

    for update in prompt_updates:
        target = update.target

        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is not None:
                found.append(_PromptTargetIndexMatch(update, match_idx))
            continue

        found.extend(
            _PromptTargetTokenMatch(update, span)
            for span in iter_token_matches(prompt, target.token_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[PromptTargetMatch]:
    """
    Return each target of `prompt_updates` found in `prompt`.

    As in `find_token_matches`, an empty sequence target produces no
    matches. Without that guard, a regex search for ``""`` would report an
    empty match at every position of `prompt`.
    """

    def get_matches(update: BoundPromptUpdate):
        target = update.target

        if isinstance(target, PromptIndex):
            match_idx = target.get_match_index(update.tokenizer, prompt)
            if match_idx is None:
                return []

            return [_PromptTargetIndexMatch(update, match_idx)]

        target_text = target.text
        if not target_text:
            # Ignore empty matches, consistent with `iter_token_matches`.
            return []

        return [
            _PromptTargetTextMatch(update, match)
            for match in re.finditer(re.escape(target_text), prompt)
        ]

    return [
        match for update in prompt_updates for match in get_matches(update)
    ]


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
    """
    Flatten `mm_matches` and return it sorted by start position.

    Raises `ValueError` if any two matches cover the same prompt position.
    """
    all_matches = [
        match for per_modality in mm_matches.values()
        for match in per_modality
    ]

    # owner[i] is the match (if any) that already covers prompt position i.
    owner: list[Optional[PromptTargetMatch]] = [None] * len(prompt)

    for match in all_matches:
        for pos in range(match.start_idx, match.end_idx):
            prev = owner[pos]
            if prev is not None:
                raise ValueError("Found overlapping matches "
                                 f"({prev} and {match}) "
                                 f"at index={pos} of prompt={prompt}")
            owner[pos] = match

    return sorted(all_matches, key=lambda m: m.start_idx)


745
def _apply_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the updates in `mm_matches` to `prompt`.

    Returns the pieces of the updated prompt in order; the callers join
    them. Items of each modality are handed out in index order, and
    matches found after a modality's items are used up are skipped.
    """
    pieces = list[Union[str, list[int]]]()
    cursor = 0
    consumed_by_modality = defaultdict[str, int](lambda: 0)
    as_text = isinstance(prompt, str)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality
        first_item = consumed_by_modality[modality]
        total_items = mm_item_counts.get(modality, 0)

        # All items of this modality have already been placed.
        if first_item >= total_items:
            continue

        origin = match._origin
        mode = origin.mode

        if mode == UpdateMode.INSERT:
            # Keep the target; content goes right after it.
            pieces.append(prompt[cursor:match.end_idx])
            budget = total_items
        elif mode == UpdateMode.REPLACE:
            # Drop the target; content takes its place. A zero-width target
            # (e.g. a `PromptIndex` match) receives all remaining items.
            pieces.append(prompt[cursor:match.start_idx])
            budget = (total_items
                      if match.start_idx == match.end_idx else 1)
        else:
            assert_never(mode)

        last_item = min(first_item + budget, total_items)

        for item_idx in range(first_item, last_item):
            full = origin.get_content(item_idx).full
            pieces.append(full.text if as_text else full.token_ids)

        cursor = match.end_idx
        consumed_by_modality[modality] += last_item - first_item

    pieces.append(prompt[cursor:])

    return cast(list[_S], pieces)
792
793


794
def apply_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """Apply the updates in `mm_matches` to the token-ID `prompt`."""
    # Nothing to apply: hand back the original list unchanged.
    if not mm_matches:
        return prompt

    pieces = _apply_matches(prompt, mm_matches, mm_item_counts)
    return flatten_2d_lists(pieces)
806
807


808
def apply_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """Apply the updates in `mm_matches` to the text `prompt`."""
    # Nothing to apply: hand back the original string unchanged.
    if not mm_matches:
        return prompt

    pieces = _apply_matches(prompt, mm_matches, mm_item_counts)
    return "".join(pieces)
820
821


822
def _iter_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Scan `prompt` left to right and yield each run of placeholder tokens.

    At each position the modalities are tried in the iteration order of
    `mm_prompt_updates`. When several modalities share the same placeholder
    tokens, the earlier one wins, and matches never overlap. A modality
    yields at most `mm_item_counts[modality]` matches, assigned to
    consecutive item indices. Empty contents never match.
    """
    prompt_len = len(prompt)
    next_item_by_modality = defaultdict[str, int](lambda: 0)

    pos = 0
    while pos < prompt_len:
        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = next_item_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for update in modality_updates:
                content = update.get_content(item_idx)
                tokens = content.full.token_ids
                n_tokens = len(tokens)
                stop = pos + n_tokens

                if n_tokens == 0 or stop > prompt_len:
                    continue
                if prompt[pos:stop] != tokens:
                    continue

                mask = None
                if content.is_embed is not None:
                    mask = content.is_embed(content.full)

                yield PlaceholderFeaturesInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=pos,
                    tokens=tokens,
                    is_embed=mask,
                )

                # Skip past the match so placeholders never overlap.
                pos = stop
                next_item_by_modality[modality] += 1
                break
            else:
                # No update of this modality matched here; try the next one.
                continue
            # A match was found; rescan from the new position.
            break
        else:
            # No modality matched at this position.
            pos += 1
881
882


883
def find_mm_placeholders(
    mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """Locate the placeholder tokens in `prompt`, grouped by modality."""
    placeholders = _iter_placeholders(mm_prompt_updates,
                                      prompt,
                                      mm_item_counts)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


892
893
894
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")


895
896
897
898
899
900
901
902
903
904
class ProcessingCacheOptionalItem(NamedTuple):
    """
    A cache key paired with an item that may be absent.

    NOTE(review): `value is None` presumably denotes a cache miss; confirm
    against callers.
    """

    key: str
    value: Optional[MultiModalKwargsItem]


class ProcessingCacheItem(NamedTuple):
    """A cache key paired with a (non-optional) processed item."""

    key: str
    value: MultiModalKwargsItem


905
906
class ProcessingCache:
    """
    A size-bounded LRU cache of processed multi-modal items.

    Each entry is keyed by a hash of the inputs that produced it
    (see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
    """

    @staticmethod
    def get_lru_cache(
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """
        Create an `LRUCache` that holds at most `capacity_gb` GiB of values.

        The size of a value is the sum of the sizes of its leaves:
        `nbytes` for tensors and `sys.getsizeof` for anything else.

        Args:
            capacity_gb: The capacity of the cache, in GiB.
            value_type: The type of the cached values (only used for typing).
            debug: If `True`, log the computed size of each value.
        """

        def get_leaf_size(leaf: object) -> int:
            # MultiModalKwargs is not a subclass of dict
            if isinstance(leaf, MultiModalKwargs):
                return get_item_size(leaf.data)

            # MultiModalKwargsItem is not a subclass of dict
            if isinstance(leaf, MultiModalKwargsItem):
                leaf_data = {k: v.data for k, v in leaf.items()}
                return get_item_size(leaf_data)

            # sys.getsizeof doesn't work for tensors
            if isinstance(leaf, torch.Tensor):
                return leaf.nbytes

            return sys.getsizeof(leaf)

        def get_item_size(
            value: Union[MultiModalKwargs, MultiModalKwargsItem,
                         Mapping[str, NestedTensors]]
        ) -> int:
            size = json_reduce_leaves(
                lambda a, b: a + b,
                json_map_leaves(get_leaf_size, value),
            )

            if debug:
                logger.debug("Calculated size of %s to be %.2f GiB",
                             type(value), size / GiB_bytes)

            return size

        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)

    def __init__(
        self,
        capacity_gb: float,
        *,
        debug_cache_hit_ratio_steps: Optional[int] = None,
    ) -> None:
        """
        Args:
            capacity_gb: The capacity of the cache, in GiB.
            debug_cache_hit_ratio_steps: If set, count cache hits and
                lookups, and log the hit ratio and cache size every
                `debug_cache_hit_ratio_steps` lookups.
        """
        super().__init__()

        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
        self.debug_cache_hits = 0
        self.debug_cache_total = 0

        self._cache = self.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
            debug=bool(debug_cache_hit_ratio_steps),
        )

    @staticmethod
    def _compute_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Hash the dependencies of a processed item into a cache key."""
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        """Log the hit ratio and size every `debug_cache_hit_ratio_steps`."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        total = self.debug_cache_total
        if total > 0 and total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         self.debug_cache_hits / total)
            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
                         self._cache.currsize / GiB_bytes,
                         self._cache.maxsize / GiB_bytes)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor

        Return `None` if the item is not in the cache.
        """
        return self.get_item(model_id, modality, input_item,
                             input_kwargs).value

    def get_item(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> ProcessingCacheOptionalItem:
        """
        Like [`get`][vllm.multimodal.processing.ProcessingCache.get], but
        also return the cache key so that a newly processed item can later
        be stored via `put_item` without hashing the inputs again.
        """
        self._maybe_log_cache_stats()

        cache_key = self._compute_key(model_id, modality, input_item,
                                      input_kwargs)

        # Statistics are tracked here so that they cover every lookup,
        # whether it is made via `get` or `get_item`
        if self.debug_cache_hit_ratio_steps:
            if cache_key in self._cache:
                self.debug_cache_hits += 1

            self.debug_cache_total += 1

        return ProcessingCacheOptionalItem(
            key=cache_key,
            value=self._cache.get(cache_key),
        )

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies
        (see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
        """
        cache_key = self._compute_key(model_id, modality, input_item,
                                      input_kwargs)
        self.put_item(ProcessingCacheItem(key=cache_key, value=output_kwargs))

    def put_item(self, item: ProcessingCacheItem) -> None:
        """Store `item.value` under the precomputed key `item.key`."""
        self._cache[item.key] = item.value

    def reset(self) -> bool:
        """Remove all entries from the cache. Always returns `True`."""
        self._cache.clear()

        return True

1051

1052
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        # Context from which the model config, tokenizer, HF config,
        # HF processor and multi-modal config are obtained
        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The `model` field of the model config."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> "PretrainedConfig":
        """Return the HF config obtained from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin":
        """
        Return the HF processor obtained from the processing context.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    def get_allowed_mm_limits(self) -> Mapping[str, int]:
        """
        Return the maximum allowed number of items for each modality.

        For every supported modality, this is the per-prompt limit from the
        multi-modal config, capped at the model's supported limit (if any).
        """
        supported_by_modality = self.get_supported_mm_limits()
        mm_config = self.ctx.get_mm_config()

        def cap(modality: str, supported_limit: Optional[int]) -> int:
            user_limit = mm_config.get_limit_per_prompt(modality)
            if supported_limit is None:
                return user_limit
            return min(user_limit, supported_limit)

        return {
            modality: cap(modality, supported_limit)
            for modality, supported_limit in supported_by_modality.items()
        }

1103
1104

# Type of the `BaseProcessingInfo` (sub)class held by a processor
_I = TypeVar("_I", bound=BaseProcessingInfo)
1105

1106
1107
MultiModalHashes = dict[str, list[str]]
"""
1108
1109
A collection of hashes with a similar structure as
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
1110
1111
"""

1112
1113

class BaseMultiModalProcessor(ABC, Generic[_I]):
1114
    """
1115
    Abstract base class to process multi-modal inputs to be used in vLLM.
1116

1117
    Not to be confused with `transformers.ProcessorMixin`.
1118
1119
    """

1120
    def __init__(
        self,
        info: _I,
        dummy_inputs: "BaseDummyInputsBuilder[_I]",
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> None:
        """
        Args:
            info: Provides the information needed for processing.
            dummy_inputs: Builds the dummy text that accompanies
                multi-modal data when it is processed on its own.
            cache: If given, processed multi-modal items are stored in
                and reused from this cache.
        """
        super().__init__()

        self.info, self.dummy_inputs, self.cache = info, dummy_inputs, cache

        # Built once and reused by `_to_mm_items`
        self.data_parser = self._get_data_parser()

1133
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Process the inputs by delegating to `apply`."""
        processed = self.apply(prompt, mm_data, hf_processor_mm_kwargs)
        return processed
1140

1141
1142
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        You can support additional modalities by creating a subclass
        of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
        that has additional subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize
        [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
        to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        Raises:
            ValueError: If a modality has more items than the model supports
                or than the user allows.
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        supported_mm_limits = self.info.get_supported_mm_limits()
        allowed_mm_limits = self.info.get_allowed_mm_limits()

        for modality, items in mm_items.items():
            num_items = len(items)

            # A modality missing from the supported limits is unsupported,
            # so it gets a limit of 0; `None` means unlimited
            supported_limit = supported_mm_limits.get(modality, 0)
            if supported_limit is not None and num_items > supported_limit:
                raise ValueError(
                    f"The model only supports at most {supported_limit} "
                    f"{modality} items, but you passed {num_items} "
                    f"{modality} items in the same prompt.")

            allowed_limit = allowed_mm_limits.get(modality, 0)
            if num_items > allowed_limit:
                raise ValueError(
                    "You set or defaulted to "
                    f"'{json.dumps({modality: allowed_limit})}' in "
                    f"`--limit-mm-per-prompt`, but passed {num_items} "
                    f"{modality} items in the same prompt.")

        return mm_items
1187

1188
1189
1190
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: "BatchFeature",
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        The result is passed to `MultiModalKwargs.from_hf_inputs` together
        with the HF-processed data.
        """
        raise NotImplementedError

1197
    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        the HF processor if the HF processor does not apply prompt updates to
        text inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct
        [`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
        for each multi-modal item.
        """
        raise NotImplementedError
1219

1220
    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholders of each multi-modal item in `new_token_ids`.

        Subclasses may override this hook; by default it delegates to the
        module-level `find_mm_placeholders`.
        """
        placeholders = find_mm_placeholders(mm_prompt_updates, new_token_ids,
                                            mm_item_counts)
        return placeholders
1228

1229
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the multi-modal items into two mappings: the data to pass to
        the HF processor, and the data that is passed through without it.
        Later modalities overwrite duplicate keys from earlier ones.
        """
        processor_data: dict[str, object] = {}
        passthrough_data: dict[str, object] = {}

        for modality_items in mm_items.values():
            processor_data |= modality_items.get_processor_data()
            passthrough_data |= modality_items.get_passthrough_data()

        return processor_data, passthrough_data

1242
1243
1244
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> "BatchFeature":
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.

        `mm_kwargs` is used both to obtain the HF processor and as the
        keyword arguments of the call.
        """
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        hf_inputs = dict(text=prompt, **mm_data)

        return self.info.ctx.call_hf_processor(hf_processor, hf_inputs,
                                               mm_kwargs)

1260
    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Return whether the HF processor applies prompt updates.

        For most HF processors, this should be `True` when multi-modal
        data items are passed, but `False` when multi-modal embeddings
        are passed.
        """
        embedding_item_types = (EmbeddingItems, DictEmbeddingItems)

        for modality_items in mm_items.values():
            if isinstance(modality_items, embedding_item_types):
                return False

        return True

1277
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            The prompt token IDs, the processed multi-modal kwargs, and
            whether prompt updates have been applied.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        hf_outputs = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Data that skipped the HF processor is merged back in unchanged
        hf_outputs.update(passthrough_data)

        # A single prompt is processed, so exactly one row of IDs is expected
        (token_ids, ) = hf_outputs.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(hf_outputs,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargs.from_hf_inputs(hf_outputs, fields_config)

        updates_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return token_ids, mm_kwargs, updates_applied
1312

1313
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The HF processor is called with no multi-modal items
        (an empty `MultiModalDataItems`) and no processor kwargs,
        and only the resulting prompt token IDs are returned.

        NOTE(review): if a model's HF processor rejects text that contains
        placeholders without matching multi-modal data, its processor should
        override this method.
        """
        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens, so the
        default implementation returns `prompt_tokens` unchanged (the same
        list object).

        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of
        [`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
        on the corresponding text.
        """
        unchanged_tokens = prompt_tokens
        return unchanged_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        [`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
        to go along with the multi-modal data. The resulting token IDs are
        discarded.
        """
        dummy_text = self.dummy_inputs.get_dummy_text(
            mm_items.get_all_counts())

        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
            prompt_text=dummy_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )
        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        In addition, return whether prompt updates have been applied.

        Note:
            If `enable_hf_prompt_update=True` and `prompt` is text, the HF
            processor is called on the text and multi-modal data together
            so that it can perform prompt updates itself (the returned flag
            then comes from `_hf_processor_applies_updates`, which is `True`
            for most HF processors). This requires that the prompt
            corresponds to the multi-modal items.

            Otherwise, the prompt and the multi-modal data are processed
            separately, and `False` is returned to indicate that prompt
            updates still need to be applied.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, False
1406

1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
    def _get_cache_missing_items(
        self,
        cache: ProcessingCache,
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
            str, list[object]]]:
        """
        Look up every multi-modal item in `cache`.

        Returns:
            The lookup result (key and optional value) of each item, grouped
            by modality, and the data items whose lookup missed, grouped by
            modality in their original order.
        """
        model_id = self.info.model_id

        mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]] = {}
        for modality, items in mm_data_items.items():
            mm_cache_items[modality] = [
                cache.get_item(model_id, modality, item,
                               hf_processor_mm_kwargs) for item in items
            ]

        mm_missing_data: dict[str, list[object]] = {}
        for modality, cache_items in mm_cache_items.items():
            modality_data = mm_data_items[modality]
            mm_missing_data[modality] = [
                modality_data[idx]
                for idx, cache_item in enumerate(cache_items)
                if cache_item.value is None
            ]

        return mm_cache_items, mm_missing_data

    def _hash_mm_items(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalHashes:
        """Create MM hashes to be returned (only used in V1)."""
        model_id = self.info.model_id

        def hash_item(modality: str, item: object) -> str:
            return MultiModalHasher.hash_kwargs(model_id=model_id,
                                                **{modality: item},
                                                **hf_processor_mm_kwargs)

        mm_hashes: MultiModalHashes = {}
        for modality, items in mm_items.items():
            mm_hashes[modality] = [hash_item(modality, item) for item in items]

        return mm_hashes

    def _merge_mm_kwargs(
        self,
        cache: ProcessingCache,
        mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
        mm_missing_data: dict[str, list[object]],
        mm_missing_kwargs: MultiModalKwargs,
    ) -> dict[str, list[ProcessingCacheItem]]:
        """
        Fill each cache miss in `mm_cache_items` with the next newly
        processed item of the same modality from `mm_missing_kwargs`,
        storing that item in `cache`.

        Returns:
            The complete list of (key, value) items for each modality,
            in the original item order.
        """
        next_missing_idx = dict.fromkeys(mm_missing_data, 0)

        merged: dict[str, list[ProcessingCacheItem]] = {}
        for modality, cache_items in mm_cache_items.items():
            for cache_item in cache_items:
                if cache_item.value is not None:
                    resolved = ProcessingCacheItem(
                        key=cache_item.key,
                        value=cache_item.value,
                    )
                else:
                    fresh_value = mm_missing_kwargs.get_item(
                        modality,
                        next_missing_idx[modality],
                    )
                    resolved = ProcessingCacheItem(
                        key=cache_item.key,
                        value=fresh_value,
                    )

                    cache.put_item(resolved)
                    next_missing_idx[modality] += 1

                merged.setdefault(modality, []).append(resolved)

        return merged

    def _apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        return_mm_hashes: bool,
    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
        """
        Apply the HF processor on the full prompt without using the cache.

        The multi-modal hashes are only computed if `return_mm_hashes` is set.
        """
        prompt_ids, mm_kwargs, is_update_applied = (
            self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            ))

        mm_hashes: Optional[MultiModalHashes] = None
        if return_mm_hashes:
            mm_hashes = self._hash_mm_items(mm_data_items,
                                            hf_processor_mm_kwargs)

        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied

1514
1515
    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        return_mm_hashes: bool,
    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Only the multi-modal items that miss the cache are processed; the
        hashes returned (if requested) are the cache keys of the items.
        """
        cache = self.cache

        # Without a cache, or when some data is passed through without the
        # HF processor, everything is processed without caching
        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor(
                prompt=prompt,
                mm_data_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                return_mm_hashes=return_mm_hashes,
            )

        mm_cache_items, mm_missing_data = self._get_cache_missing_items(
            cache=cache,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we can't apply prompt updates until the new multimodal
        # items are combined with the cached multimodal items
        prompt_ids, mm_missing_kwargs, is_update_applied = (
            self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=self._to_mm_items(mm_missing_data),
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=False,
            ))

        merged_by_modality = self._merge_mm_kwargs(
            cache,
            mm_cache_items=mm_cache_items,
            mm_missing_data=mm_missing_data,
            mm_missing_kwargs=mm_missing_kwargs,
        )

        mm_kwargs = MultiModalKwargs.from_items([
            merged_item.value
            for merged_items in merged_by_modality.values()
            for merged_item in merged_items
        ])

        mm_hashes: Optional[MultiModalHashes] = None
        if return_mm_hashes:
            mm_hashes = {
                modality: [merged_item.key for merged_item in merged_items]
                for modality, merged_items in merged_by_modality.items()
            }

        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
1578

1579
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
    ) -> dict[str, Sequence[BoundPromptUpdate]]:
        """Bind each update to the tokenizer, then group them by modality."""
        tokenizer = self.info.get_tokenizer()

        bound_updates = (pu.bind(tokenizer) for pu in prompt_updates)
        grouped = full_groupby_modality(bound_updates)
        return dict(grouped)
1587

1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        """
        Apply the matched updates to the token IDs of the prompt.

        Subclasses may override this hook; by default it delegates to the
        module-level `apply_token_matches`.
        """
        updated_prompt = apply_token_matches(prompt, mm_matches,
                                             mm_item_counts)
        return updated_prompt

    def _apply_text_matches(
        self,
        prompt: str,
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> str:
        """
        Apply the matched updates to the text of the prompt.

        Subclasses may override this hook; by default it delegates to the
        module-level `apply_text_matches`.
        """
        updated_prompt = apply_text_matches(prompt, mm_matches, mm_item_counts)
        return updated_prompt

1604
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt updates to `token_ids`.

        Token-level matching is used when it finds at least as many matches
        as there are items for every modality; otherwise the updates are
        applied to the decoded text, which is then re-encoded.

        Returns:
            The updated token IDs, the corresponding text, and the
            placeholders of each multi-modal item grouped by modality.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, updates)
            for modality, updates in mm_prompt_updates.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        token_matches_suffice = all(
            len(mm_token_matches.get(modality, ())) >= item_count
            for modality, item_count in mm_item_counts.items())

        if token_matches_suffice:
            token_ids = self._apply_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )
            text = decode_tokens(tokenizer, token_ids)
            applied_matches = mm_token_matches
        else:
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, updates)
                for modality, updates in mm_prompt_updates.items()
            }
            text = self._apply_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )
            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            applied_matches = mm_text_matches

        # Recover the updates that were actually matched, per modality
        matched_updates = {
            modality: [match._origin for match in matches]
            for modality, matches in applied_matches.items()
        }

        placeholders = self._find_mm_placeholders(
            matched_updates,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders
1674

1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that `mm_kwargs` has exactly `mm_item_counts[modality]` items
        for every modality.

        Raises:
            RuntimeError: If the number of items differs for any modality
                (either too few or too many).
        """
        for modality, item_count in mm_item_counts.items():
            items = (mm_kwargs.get_items(modality)
                     if modality in mm_kwargs.modalities else [])

            if len(items) != item_count:
                # The count may be too high as well as too low, so the
                # message must not say "only found"
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that each modality has one placeholder per multi-modal item.

        Raises:
            RuntimeError: If the number of placeholders of any modality
                differs from its number of items.
        """
        for modality, item_count in mm_item_counts.items():
            num_placeholders = len(mm_placeholders.get(modality, []))
            if num_placeholders == item_count:
                continue

            # NOTE: If you are a model developer, this can also arise from
            # an inconsistency between `_call_hf_processor` and
            # `_get_mm_fields_config` implementations
            raise RuntimeError(
                f"Expected there to be {item_count} prompt updates "
                f"corresponding to {item_count} {modality} items, but "
                f"instead found {num_placeholders} prompt updates! "
                "This is likely because you forgot to include input "
                "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
                "in the prompt. If the model has a chat template, make "
                "sure you have applied it before calling `LLM.generate`.")
1716

1717
1718
1719
1720
1721
1722
1723
1724
    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargs,
        is_update_applied: bool,
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Ensure the prompt contains the multi-modal placeholder tokens, and
        locate them.

        Args:
            mm_items: The parsed multi-modal data items.
            hf_processor_mm_kwargs: Extra keyword arguments forwarded to
                `_get_prompt_updates`.
            prompt_ids: The prompt token IDs to inspect or update.
            mm_kwargs: The processed multi-modal keyword arguments.
            is_update_applied: If `True`, the placeholder tokens are expected
                to already be present in `prompt_ids`, so they are only
                located. Otherwise, the prompt updates are applied first.

        Returns:
            A tuple of the (possibly updated) prompt token IDs, the prompt
            text, and the placeholders found for each modality.

        Raises:
            RuntimeError: If the number of processed items or located
                placeholders for a modality does not match its number of
                data items.
        """
        unbound_prompt_updates = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if is_update_applied:
            # Placeholders are already in the token IDs: only locate them and
            # recover the prompt text by decoding the tokens.
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_updates,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
        else:
            # Apply the updates here; this yields new token IDs, the prompt
            # text, and the placeholders that were inserted.
            (
                prompt_ids,
                prompt,
                mm_placeholders,
            ) = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        return prompt_ids, prompt, mm_placeholders

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and update sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Args:
            prompt: The prompt, either as text or as token IDs.
            mm_data: The raw multi-modal data, keyed by modality.
            hf_processor_mm_kwargs: Extra keyword arguments for the HF
                processor.
            return_mm_hashes: Forwarded to `_cached_apply_hf_processor`, which
                supplies the `mm_hashes` stored in the result.

        Returns:
            The processed inputs. The placeholders are stored as ranges
            grouped by modality.

        Raises:
            RuntimeError: Propagated from `_maybe_apply_prompt_updates` if the
                processed items or placeholders do not match the data items.
        """
        mm_items = self._to_mm_items(mm_data)

        (
            prompt_ids,
            mm_kwargs,
            mm_hashes,
            is_update_applied,
        ) = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
            return_mm_hashes=return_mm_hashes,
        )

        prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            prompt_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            is_update_applied=is_update_applied,
        )

        # Convert each placeholder into the range representation stored in
        # the returned inputs.
        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    The HF processor is applied to an encoder prompt derived from the input
    prompt. The decoder gets its own prompt, built by
    `create_decoder_prompt` (the input prompt unchanged by default).
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        # NOTE(review): not read anywhere in this chunk; presumably consulted
        # when building dummy inputs for profiling — confirm at call sites.
        return False

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the decoder.

        By default the input prompt is returned unchanged; subclasses may
        override this.
        """
        return prompt

    def _get_enc_dec_inputs(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        encoder_inputs: MultiModalInputs,
    ) -> MultiModalEncDecInputs:
        """
        Combine the processed encoder inputs with the decoder prompt.

        Args:
            prompt: The original input prompt, as text or token IDs.
            mm_data: The raw multi-modal data, keyed by modality.
            encoder_inputs: The result of processing the encoder prompt.

        Returns:
            The encoder inputs, with the encoder prompt and its token IDs
            recorded under `encoder_prompt` and `encoder_prompt_token_ids`.
            `prompt` and `prompt_token_ids` are replaced by the decoder
            prompt.
        """
        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        # Produce both the text and the token-ID form of the decoder prompt,
        # whichever one `create_decoder_prompt` returned.
        if isinstance(decoder_prompt, str):
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               decoder_prompt,
                                               add_special_tokens=False)
        else:
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

        # Keep every encoder field, plus a copy of the encoder prompt under
        # dedicated keys.
        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        # Then overwrite the top-level prompt fields with the decoder prompt.
        mm_inputs.update({
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids
        })
        return mm_inputs

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.
        The main processing steps are modified to fit encoder-decoder model:
        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Create the decoder prompt via `create_decoder_prompt` (by default,
           the input prompt unchanged) and use it as the decoder inputs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            return_mm_hashes=return_mm_hashes,
        )

        return self._get_enc_dec_inputs(
            prompt=prompt,
            mm_data=mm_data,
            encoder_inputs=encoder_inputs,
        )