processing.py 69.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import time
4
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
7
from dataclasses import dataclass, field, replace
8
from enum import Enum
9
from functools import lru_cache
10
11
12
13
14
15
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    NamedTuple,
    Protocol,
16
    TypeAlias,
17
18
19
    cast,
    overload,
)
20

21
import regex as re
22
import torch
23
from typing_extensions import TypeVar, assert_never
24

25
from vllm.logger import init_logger
26
from vllm.transformers_utils.processor import cached_processor_from_config
27
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
28
29
from vllm.utils import flatten_2d_lists, full_groupby
from vllm.utils.func import get_allowed_kwarg_only_overrides
30
from vllm.utils.jsontree import JSONTree, json_map_leaves
31

32
from .hasher import MultiModalHasher
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from .inputs import (
    MultiModalDataDict,
    MultiModalEncDecInputs,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalKwargsOptionalItems,
    MultiModalUUIDDict,
    PlaceholderRange,
)
from .parse import (
    DictEmbeddingItems,
    EmbeddingItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
50
51

if TYPE_CHECKING:
52
53
54
55
    from transformers.configuration_utils import PretrainedConfig
    from transformers.feature_extraction_utils import BatchFeature
    from transformers.processing_utils import ProcessorMixin

56
57
    from vllm.config import ModelConfig

58
    from .cache import BaseMultiModalProcessorCache
59
    from .profiling import BaseDummyInputsBuilder
60
61
62
63
64
65
66
67
else:
    PretrainedConfig = object
    BatchFeature = object
    ProcessorMixin = object

    ModelConfig = object

    BaseMultiModalProcessorCache = object
68

69
logger = init_logger(__name__)

# A prompt is either text or token IDs. `_S` is *constrained* (not bound) to
# these two types so generic helpers return the same concrete kind they got.
_S = TypeVar("_S", str, list[int])

PromptSeq: TypeAlias = str | list[int]
"""A token sequence (list of token IDs) or text."""
75

76

77
78
79
80
81
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool | None = None,
) -> list[int]:
    """
    Memoized wrapper around `encode_tokens`.

    Results are kept in an LRU cache (2048 entries) keyed on the tokenizer,
    the text and `add_special_tokens`, so all of them must be hashable.
    The cached list object is handed back to every caller with the same key;
    treat it as read-only.
    """
    token_ids = encode_tokens(tokenizer, text, add_special_tokens=add_special_tokens)
    return token_ids
85
86
87
88
89
90
91


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool | None = None,
) -> str:
    """
    Memoized wrapper around `decode_tokens`.

    `token_ids` is taken as a tuple so it can be part of the LRU cache key;
    it is turned back into a list before decoding.
    """
    as_list = list(token_ids)
    return decode_tokens(tokenizer, as_list, skip_special_tokens=skip_special_tokens)
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112


def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
    """Return `seq` as text, decoding it (with caching) if it is a token list."""
    return seq if isinstance(seq, str) else _cached_decode(tokenizer, tuple(seq))


def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
    """
    Return `seq` as token IDs.

    Text is encoded (with caching) without special tokens; a token list is
    returned unchanged (same object).
    """
    if not isinstance(seq, str):
        return seq

    return _cached_encode(tokenizer, seq, add_special_tokens=False)


113
114
115
116
117
118
class _GetMatchIndex(Protocol):
    """Call signature used by `PromptIndex` to locate a position in a prompt."""

    def __call__(
        self,
        tokenizer: AnyTokenizer,
        prompt: PromptSeq,
        start_idx: int = 0,
    ) -> int | None:
        """Return the matched index in `prompt`, or `None` if there is none."""
        ...
120
121


122
123
124
@dataclass
class PromptIndex:
    """
    Resolves to an index in the prompt.

    The index is computed lazily by `get_match_index`, which is given the
    tokenizer, the prompt (text or token IDs) and the position to search
    from, and returns `None` when there is no match.
    """

    get_match_index: _GetMatchIndex
127
128
129
130
131
132
133
134
135
136


class PromptIndexTargets:
    """Factories for `PromptIndex` targets that resolve to prompt positions."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        Only matches when searching from the very start of the prompt
        (`start_idx == 0`). Before comparing, `seq` is converted to the same
        representation (text or token IDs) as the prompt.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
            start_idx: int = 0,
        ) -> int | None:
            if start_idx != 0:
                return None

            # Bring the prefix into the prompt's representation. The cached
            # helpers are used (as in `ResolvedPromptUpdate`) because this
            # function is re-evaluated on every matching pass.
            if isinstance(prompt, str):
                prefix: PromptSeq = _seq2text(tokenizer, seq)
            else:
                prefix = _seq2tokens(tokenizer, seq)

            match_idx = len(prefix)
            return match_idx if prompt[:match_idx] == prefix else None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """
        return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt))
177
178


179
UpdateTarget: TypeAlias = PromptSeq | PromptIndex
"""
The token sequence or text to update, or a
[`PromptIndex`][vllm.multimodal.processing.PromptIndex] that resolves to a
position in the prompt.
"""

PromptUpdateTarget: TypeAlias = Callable[[int], UpdateTarget] | UpdateTarget
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).

If the target does not depend on the item, the token sequence (or text)
may be passed in directly instead of a function.
"""

194

195
@dataclass
class PromptUpdateDetails(Generic[_S]):
    """Details about the token sequence or text that are part of the update."""

    full: _S
    """The full content."""

    is_embed: Callable[[AnyTokenizer, PromptSeq], torch.Tensor] | None = None
    """
    Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
    return a boolean mask with one entry per token of `full` (after
    tokenization, if `full` is text) indicating which positions to assign
    embeddings to.

    `None` (default) means to assign embeddings to all positions of `full`.

    The embeddings are obtained by calling
    [`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
    """

    @staticmethod
    def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
        """Wrap `seq` so that every position receives embeddings."""
        return PromptUpdateDetails(full=seq)

    @staticmethod
    def select_text(
        seq: _S,
        embed_text: str,
    ) -> "PromptUpdateDetails[_S]":
        """
        Wrap `seq` so that only tokens appearing in `embed_text` receive
        embeddings.

        `embed_text` is tokenized without special tokens, the same way
        `full` is (via `_seq2tokens`). This keeps special tokens the
        tokenizer would add (e.g. BOS) out of the embedding set.
        """

        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
            embed_token_ids = _cached_encode(
                tokenizer, embed_text, add_special_tokens=False
            )
            token_ids = _seq2tokens(tokenizer, full)

            # Explicit dtype so empty sequences do not default to float.
            return torch.isin(
                torch.tensor(token_ids, dtype=torch.long),
                torch.tensor(embed_token_ids, dtype=torch.long),
            )

        return PromptUpdateDetails(full=seq, is_embed=is_embed)

    @staticmethod
    def select_token_id(
        seq: _S,
        embed_token_id: int,
    ) -> "PromptUpdateDetails[_S]":
        """Wrap `seq` so that only `embed_token_id` positions get embeddings."""

        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
            token_ids = _seq2tokens(tokenizer, full)

            return torch.tensor(token_ids, dtype=torch.long) == embed_token_id

        return PromptUpdateDetails(full=seq, is_embed=is_embed)
245
246


247
PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails
"""
The token sequence or text that are part of the update.

When only part of the content corresponds to feature placeholders, wrap it
in [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails]
to specify which part.
"""

PromptUpdateContent: TypeAlias = Callable[[int], PromptUpdateInfo] | PromptUpdateInfo
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).

If the content does not depend on the item, the token sequence (or text)
may be passed in directly instead of a function.
"""


class UpdateMode(str, Enum):
    """How a prompt update is applied relative to its matched target."""

    INSERT = "insert"  # place content right after the target
    REPLACE = "replace"  # substitute the target with the content


@dataclass
class PromptUpdate(ABC):
    """
    Defines how to update a prompt with placeholder tokens.

    Subclasses provide the update `mode` and the `content`. Both `target`
    and `content` may be given lazily as functions of the item index, and
    are evaluated by `resolve`.
    """

    modality: str
    """The modality for which the update is made."""

    target: PromptUpdateTarget
    """The token sequence (or text) to update."""

    @property
    @abstractmethod
    def content(self) -> PromptUpdateContent:
        """The placeholder tokens that are part of the update."""
        raise NotImplementedError

    @property
    @abstractmethod
    def mode(self) -> UpdateMode:
        """Defines how to update the prompt."""
        raise NotImplementedError

    def _resolve_target(self, item_idx: int) -> UpdateTarget:
        """Evaluate `target` for `item_idx` when it was given as a function."""
        target = self.target
        return target(item_idx) if callable(target) else target

    def _resolve_content(self, item_idx: int) -> PromptUpdateDetails:
        """Evaluate `content` for `item_idx` and wrap it in `PromptUpdateDetails`."""
        content = self.content
        resolved = content(item_idx) if callable(content) else content

        if isinstance(resolved, PromptUpdateDetails):
            return resolved

        return PromptUpdateDetails.from_seq(resolved)

    def resolve(self, item_idx: int) -> "ResolvedPromptUpdate":
        """
        Given the index of the processed item within
        [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
        return a copy of this update with its lazy attributes evaluated.
        """
        return ResolvedPromptUpdate(
            modality=self.modality,
            item_idx=item_idx,
            mode=self.mode,
            target=self._resolve_target(item_idx),
            content=self._resolve_content(item_idx),
        )

327

328
@dataclass
class PromptInsertion(PromptUpdate):
    """
    Describes placeholder tokens to be inserted into a prompt.

    Example:

    For every image, put as many ``<image>`` feature placeholders as the
    vision encoder's feature size right after the ``<s>`` token:

    ```python
    PromptInsertion(
        modality="image",
        target="<s>",
        insertion="<image>" * image_feature_size,
    )
    ```

    Put them at the beginning of the prompt instead:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.start(),
        insertion="<image>" * image_feature_size,
    )
    ```

    Put them right after the prefix ``Images:``:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.prefix("Images:"),
        insertion="<image>" * image_feature_size,
    )
    ```

    Put them at the very end of the prompt:

    ```python
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.end(),
        insertion="<image>" * image_feature_size,
    )
    ```
    """

    insertion: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to insert right after
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    If the insertion does not depend on the item, the token sequence (or
    text) may be passed in directly instead of a function.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """Alias of `insertion`."""
        return self.insertion

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.INSERT`."""
        return UpdateMode.INSERT


@dataclass
class PromptReplacement(PromptUpdate):
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

    For each image, replace one ``<image>`` input placeholder in the prompt
    with a number of ``<image>`` feature placeholders
    equal to the feature size of the vision encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement="<image>" * image_feature_size,
    )
    ```

    As above, but further pad the feature placeholders with ``<image_bos>``
    and ``<image_eos>``, which are not supposed to be passed to the vision
    encoder:

    ```python
    PromptReplacement(
        modality="image",
        target="<image>",
        replacement=PromptUpdateDetails.select_text(
            "".join(
                [
                    "<image_bos>",
                    "<image>" * image_feature_size,
                    "<image_eos>",
                ]
            ),
            embed_text="<image>",
        ),
    )
    ```

    To avoid unnecessary tokenization during prompt replacement,
    we recommend passing token sequences instead of text:

    ```python
    PromptReplacement(
        modality="image",
        target=[image_token_id],
        replacement=PromptUpdateDetails.select_token_id(
            [image_bos_id] + [image_token_id] * image_feature_size + [image_eos_id],
            embed_token_id=image_token_id,
        ),
    )
    ```
    """

    replacement: PromptUpdateContent = field(repr=False)
    """
    Given the index of the processed item within
    [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
    output the token sequence (or text) to replace
    [`target`][vllm.multimodal.processing.PromptUpdate.target].

    For convenience, you can directly pass in the token sequence (or text)
    instead of a function if it does not depend on the input.
    """

    @property
    def content(self) -> PromptUpdateContent:
        """Alias of `replacement`."""
        return self.replacement

    @property
    def mode(self) -> UpdateMode:
        """Always `UpdateMode.REPLACE`."""
        return UpdateMode.REPLACE
472
473


474
475
476
class _HasModalityAttr(Protocol):
    modality: str

477

478
479
class _HasModalityProp(Protocol):
    @property
480
    def modality(self) -> str: ...
481
482


483
# Anything that exposes a `modality` string, as an attribute or a property.
_M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp)


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group `values` by their `modality` using
    [`full_groupby`][vllm.utils.full_groupby].
    """

    def _modality_of(value):
        return value.modality

    return full_groupby(values, key=_modality_of)


492
493
494
495
496
497
498
class PromptTargetMatch(NamedTuple):
    """Half-open span `[start_idx, end_idx)` where a target was found."""

    start_idx: int
    end_idx: int


@dataclass(frozen=True)
class ResolvedPromptUpdate:
    """
    A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its
    lazy attributes resolved, apart from those related to tokenization.
    """

    modality: str
    """The modality for which the update is made."""

    item_idx: int
    """The index within `modality` of the item this update pertains to."""

    mode: UpdateMode
    """Defines how to update the prompt."""

    target: UpdateTarget
    """The token sequence (or text) to update."""

    content: PromptUpdateDetails = field(repr=False)
    """The placeholder tokens that are part of the update."""

    def iter_token_matches(
        self,
        prompt: list[int],
        tokenizer: AnyTokenizer,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield each instance of `self.target` found in the token prompt."""
        target = self.target

        if not isinstance(target, PromptIndex):
            # `iter_token_matches` here is the module-level helper.
            found = iter_token_matches(
                prompt, _seq2tokens(tokenizer, target), start_idx=start_idx
            )
            yield from (PromptTargetMatch(m.start_idx, m.end_idx) for m in found)
            return

        match_idx = target.get_match_index(tokenizer, prompt, start_idx)
        if match_idx is not None:
            yield PromptTargetMatch(match_idx, match_idx)

    def iter_text_matches(
        self,
        prompt: str,
        tokenizer: AnyTokenizer,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield each instance of `self.target` found in the text prompt."""
        target = self.target

        if not isinstance(target, PromptIndex):
            pattern = re.escape(_seq2text(tokenizer, target))
            found = re.finditer(pattern, prompt, pos=start_idx)
            yield from (PromptTargetMatch(m.start(), m.end()) for m in found)
            return

        match_idx = target.get_match_index(tokenizer, prompt, start_idx)
        if match_idx is not None:
            yield PromptTargetMatch(match_idx, match_idx)

    def iter_matches(
        self,
        prompt: list[int] | str,
        tokenizer: AnyTokenizer,
        *,
        start_idx: int = 0,
    ) -> Generator[PromptTargetMatch]:
        """Yield each instance of `self.target` found in `prompt`."""
        finder = (
            self.iter_text_matches
            if isinstance(prompt, str)
            else self.iter_token_matches
        )
        return finder(prompt, tokenizer, start_idx=start_idx)

    def with_target(self, target: UpdateTarget):
        """Return a copy of this update with `target` replaced."""
        return replace(self, target=target)

    def with_content(self, content: PromptUpdateInfo):
        """Return a copy of this update with `content` replaced (wrapped if needed)."""
        details = (
            content
            if isinstance(content, PromptUpdateDetails)
            else PromptUpdateDetails.from_seq(content)
        )
        return replace(self, content=details)

585

586
587
588
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
589
590


591
592
593
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    *,
    start_idx: int = 0,
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of `match_ids` in `token_ids`.

    Scanning goes left to right from `start_idx`. Matches never overlap:
    after a hit, the search resumes at the end of that hit.

    Note that empty matches are ignored.
    """
    width = len(match_ids)
    if not width:
        return

    last_start = len(token_ids) - width
    pos = start_idx
    while pos <= last_start:
        stop = pos + width

        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)
        pos = stop  # Exclude overlapping matches
618
619


620
621
622
623
624
625
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Return a copy of `token_ids` where every non-overlapping occurrence of
    `match_ids` is replaced by `new_ids`.

    Note that empty matches are ignored.
    """
    pieces: list[list[int]] = []
    cursor = 0

    for start, end in iter_token_matches(token_ids, match_ids):
        pieces.append(token_ids[cursor:start])  # untouched run before the match
        pieces.append(new_ids)
        cursor = end

    pieces.append(token_ids[cursor:])  # trailing run after the last match

    return flatten_2d_lists(pieces)


647
@dataclass
class PlaceholderFeaturesInfo:
    """A run of placeholder tokens located in a prompt for one multi-modal item."""

    modality: str
    item_idx: int
    start_idx: int  # position of the first placeholder token in the prompt
    tokens: list[int]
    # Per-token embedding mask over `tokens`; None means every position.
    is_embed: torch.Tensor | None

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a `PlaceholderRange` spanning all placeholder tokens."""
        # TODO: Is it worth it to optimize this by stripping the
        # leading and ending positions where `is_embed=False`?
        offset, length = self.start_idx, self.length
        return PlaceholderRange(offset=offset, length=length, is_embed=self.is_embed)
667
668


669
# One pending update: ((modality, item_idx), (match, update_idx)).
_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
670
671


672
673
674
675
676
677
678
def _find_matches(
    prompt: _S,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
    *,
    prev_end_idx: int = 0,
    current_result: "MultiModalPromptUpdatesApplyResult",
) -> tuple[UpdateMode | None, list[_MatchToApply]]:
    """
    Find the next batch of prompt updates to apply to `prompt`.

    For every item that has not been updated yet (according to
    `current_result`), its candidate updates are tried in order, searching
    from `prev_end_idx`, and the first usable match is recorded. All
    recorded matches share one `UpdateMode`: the mode of the first match
    found wins, and candidates with another mode are skipped this round.

    Returns:
        The shared mode (`None` if nothing matched) and the matches to apply,
        sorted by position. In `REPLACE` mode, at most one non-empty match
        is kept so that replacements cannot conflict.
    """
    mode: UpdateMode | None = None
    mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]()

    for modality, modality_updates in mm_prompt_updates.items():
        for item_idx, item_updates in enumerate(modality_updates):
            if current_result[modality][item_idx] is not None:
                continue  # Updates have already been applied for this item

            for update_idx, update in enumerate(item_updates):
                if (modality, item_idx) in mm_matches:
                    break  # Already found a match for this item

                for match in update.iter_matches(
                    prompt,
                    tokenizer,
                    start_idx=prev_end_idx,
                ):
                    # All matches should share the same mode
                    if mode is None:
                        mode = update.mode
                    elif mode != update.mode:
                        # `update.mode` is the same for every match of this
                        # update, so none of its remaining matches is usable
                        # either; stop scanning them.
                        break

                    mm_matches[(modality, item_idx)] = match, update_idx
                    break  # Get only the first valid match per item

    # Prioritize earlier matches
    matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0])

    # To avoid conflicts, only replace one non-empty item at a time
    if mode == UpdateMode.REPLACE:
        matches_to_apply_ = list[_MatchToApply]()
        has_non_empty_matches = False

        for item in matches_to_apply:
            _, (match, _) = item
            if match.start_idx == match.end_idx:
                matches_to_apply_.append(item)
            elif not has_non_empty_matches:
                has_non_empty_matches = True
                matches_to_apply_.append(item)

        matches_to_apply = matches_to_apply_

    return mode, matches_to_apply
725
726


727
def _apply_matches(
    prompt: _S,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to `prompt`.

    Returns the pieces of the updated prompt, to be concatenated by the
    caller, and a per-modality list holding, for each item, the index of
    the update that was applied (`None` if no update matched).
    """
    prompt_len = len(prompt)

    out_seqs = list[str | list[int]]()
    out_result: MultiModalPromptUpdatesApplyResult = {
        m: [None] * len(items) for m, items in mm_prompt_updates.items()
    }

    start_idx = prev_end_idx = 0
    while start_idx < max(prompt_len, 1):  # Allow inserts into empty prompt
        mode, matches_to_apply = _find_matches(
            prompt,
            mm_prompt_updates,
            tokenizer,
            prev_end_idx=prev_end_idx,
            current_result=out_result,
        )

        # The search only depends on `prev_end_idx` and `out_result`, and
        # neither changes when nothing is applied. Once a search comes up
        # empty, every later one would too. Stop, rather than rescanning the
        # same suffix once per remaining position.
        if mode is None or not matches_to_apply:
            break

        for (modality, item_idx), (match, update_idx) in matches_to_apply:
            matched_update = mm_prompt_updates[modality][item_idx][update_idx]
            matched_content = matched_update.content.full

            if mode == UpdateMode.INSERT:
                end_idx_to_insert = match.end_idx
            elif mode == UpdateMode.REPLACE:
                end_idx_to_insert = match.start_idx
            else:
                assert_never(mode)

            out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
            out_seqs.append(
                _seq2text(tokenizer, matched_content)
                if isinstance(prompt, str)
                else _seq2tokens(tokenizer, matched_content)
            )
            out_result[modality][item_idx] = update_idx

            # Exclude overlapping matches
            start_idx = prev_end_idx = match.end_idx

    out_seqs.append(prompt[prev_end_idx:])

    return cast(list[_S], out_seqs), out_result
782
783


784
def apply_token_matches(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to the token prompt `prompt`.

    Matches are exclusive even when several modalities share the same
    placeholder tokens; the modality listed earlier in `mm_prompt_updates`
    takes priority.
    """
    pieces, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    merged = flatten_2d_lists(pieces)
    return merged, result
799
800


801
def apply_text_matches(
    prompt: str,
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
    """
    Apply the updates in `mm_prompt_updates` to the text prompt `prompt`.

    Matches are exclusive even when several modalities share the same
    placeholder tokens; the modality listed earlier in `mm_prompt_updates`
    takes priority.
    """
    pieces, result = _apply_matches(prompt, mm_prompt_updates, tokenizer)
    merged = "".join(pieces)
    return merged, result
816
817


818
def _iter_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in `prompt`.

    Items of each modality are located in order: at each position, the next
    not-yet-found item of every modality is tried. Matches are exclusive
    even when several modalities share the same placeholder tokens; the
    modality listed earlier in `mm_prompt_updates` takes priority.

    Note that empty matches are ignored.
    """
    n_tokens = len(prompt)
    n_items = {m: len(items) for m, items in mm_prompt_updates.items()}
    next_item = defaultdict[str, int](lambda: 0)

    pos = 0
    while pos < n_tokens:
        hit: PlaceholderFeaturesInfo | None = None

        for modality, modality_updates in mm_prompt_updates.items():
            item_idx = next_item[modality]
            if item_idx >= n_items.get(modality, 0):
                continue  # every item of this modality is already located

            for update in modality_updates[item_idx]:
                content = update.content
                tokens = _seq2tokens(tokenizer, content.full)
                stop = pos + len(tokens)

                if not tokens or stop > n_tokens:
                    continue
                if prompt[pos:stop] != tokens:
                    continue

                mask_fn = content.is_embed
                hit = PlaceholderFeaturesInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=pos,
                    tokens=tokens,
                    is_embed=None if mask_fn is None else mask_fn(tokenizer, content.full),
                )
                break

            if hit is not None:
                break  # restart the scan from the end of this hit

        if hit is None:
            pos += 1
            continue

        yield hit

        # Exclude overlapping matches
        pos += hit.length
        next_item[hit.modality] += 1
879
880


881
882
def find_mm_placeholders(
    prompt: list[int],
    mm_prompt_updates: "MultiModalPromptUpdates",
    tokenizer: AnyTokenizer,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """Locate the placeholder tokens in `prompt`, grouped by modality."""
    placeholders = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


890
_T = TypeVar("_T")

# Bound to, and defaulting to, the HF config / processor base classes. Outside
# TYPE_CHECKING these names are aliased to `object`, so the runtime bound is
# `object`.
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
893
894
895
896
897
898
899
900
901


@dataclass(frozen=True)
class InputProcessingContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    @overload
    def get_hf_config(self, /) -> PretrainedConfig: ...

    @overload
    def get_hf_config(
        self,
        typ: type[_C] | tuple[type[_C], ...],
        /,
    ) -> _C: ...

    def get_hf_config(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
    ) -> Any:
        """
        Get the HuggingFace configuration
        (`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Args:
            typ: The expected config class (or tuple of classes).
                Defaults to `transformers.PretrainedConfig`.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        if typ is None:
            # Import lazily: outside of type checking, the module-level
            # `PretrainedConfig` name is only a placeholder.
            from transformers.configuration_utils import PretrainedConfig

            typ = PretrainedConfig

        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError(
                "Invalid type of HuggingFace config. "
                f"Expected type: {typ}, but "
                f"found type: {type(hf_config)}"
            )

        return hf_config

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    @overload
    def get_hf_processor(self, /, **kwargs: object) -> ProcessorMixin: ...

    @overload
    def get_hf_processor(
        self,
        typ: type[_P] | tuple[type[_P], ...],
        /,
        **kwargs: object,
    ) -> _P: ...

    def get_hf_processor(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
        **kwargs: object,
    ) -> Any:
        """
        Get the HuggingFace processor
        (`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        Args:
            typ: The expected processor class (or tuple of classes).
                Defaults to `transformers.ProcessorMixin`.
            **kwargs: Forwarded to `cached_processor_from_config`.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        if typ is None:
            # Import lazily so the real class is used for the check.
            from transformers.processing_utils import ProcessorMixin

            typ = ProcessorMixin

        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Initialize a HuggingFace-like processor class, merging the
        keyword arguments with those in the model's configuration.

        Explicit `kwargs` take precedence over `mm_processor_kwargs`
        from the multimodal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = {**base_kwargs, **kwargs}

        return typ(**merged_kwargs)

    def _postprocess_output(
        self,
        output: JSONTree,
    ) -> JSONTree:
        """Cast every floating-point tensor leaf of `output` to the model dtype."""

        def _postprocess_one(x: object):
            if isinstance(x, torch.Tensor):  # noqa: SIM102
                # This mimics the behavior of transformers.BatchFeature
                if x.is_floating_point():
                    x = x.to(dtype=self.model_config.dtype)

            return x

        return json_map_leaves(_postprocess_one, output)

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] | None = None,
        *,
        num_tries: int = 1,
        max_tries: int = 5,
    ) -> BatchFeature | JSONTree:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        Args:
            hf_processor: The (callable) HF processor.
            data: The inputs passed to the processor as keyword arguments.
            kwargs: Extra processor options; merged with the multimodal
                config's processor kwargs and filtered to those the processor
                accepts. `None` means no extra options.
            num_tries: The current attempt number (used for retries).
            max_tries: Maximum number of attempts when the tokenizer reports
                `Already borrowed`.

        Returns:
            A `BatchFeature` whose floating-point tensors are cast to the
            model dtype, or the post-processed raw output if the processor
            did not return a `BatchFeature`.

        Raises:
            ValueError: If the processor fails (including when it still
                reports `Already borrowed` after `max_tries` attempts).
        """
        assert callable(hf_processor)

        if kwargs is None:
            kwargs = {}

        mm_config = self.model_config.get_multimodal_config()
        merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)

        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            merged_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            output = hf_processor(**data, **allowed_kwargs, return_tensors="pt")
        except Exception as exc:
            # See https://github.com/huggingface/tokenizers/issues/537
            # `exc.args` may be empty (e.g. `raise RuntimeError()`); check it
            # before indexing so we never mask the real error with an
            # IndexError raised from inside this handler.
            is_borrow_conflict = (
                isinstance(exc, RuntimeError)
                and len(exc.args) > 0
                and exc.args[0] == "Already borrowed"
            )
            if is_borrow_conflict and num_tries < max_tries:
                logger.warning(
                    "Failed to acquire tokenizer in current thread. "
                    "Retrying (%d/%d)...",
                    num_tries,
                    max_tries,
                )
                time.sleep(0.5)
                return self.call_hf_processor(
                    hf_processor,
                    data,
                    kwargs,
                    num_tries=num_tries + 1,
                    max_tries=max_tries,
                )

            msg = (
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={allowed_kwargs}"
            )

            raise ValueError(msg) from exc

        # this emulates output.to(dtype=self.model_config.dtype)
        from transformers.feature_extraction_utils import BatchFeature

        if isinstance(output, BatchFeature):
            output_ = self._postprocess_output(output.data)
            return BatchFeature(output_)

        logger.warning_once(
            "%s did not return `BatchFeature`. "
            "Make sure to match the behaviour of `ProcessorMixin` when "
            "implementing custom processors.",
            type(hf_processor).__name__,
        )

        return self._postprocess_output(output)


1109
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        # Shared context holding the model config and the tokenizer.
        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The model identifier stored in `ModelConfig.model`."""
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer of the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HuggingFace config of the model."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HuggingFace processor of the model.

        Subclasses can override this method to handle specific kwargs
        coming from the model config or from user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    # NOTE: this class does not derive from `ABC`, so `@abstractmethod` is not
    # enforced at instantiation time; unimplemented subclasses only fail
    # when this method is called.
    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """
        Return the maximum number of items the model supports per modality.

        `None` means the number of items is unlimited.

        Modalities missing from the returned mapping are not supported.
        """
        raise NotImplementedError

    def get_allowed_mm_limits(self) -> Mapping[str, int]:
        """
        Return the maximum allowed number of items for each modality.

        The limit is the per-prompt limit from the multimodal config, capped
        by the model's supported limit unless the latter is unlimited.
        """
        supported_mm_limits = self.get_supported_mm_limits()
        mm_config = self.ctx.get_mm_config()

        limits: dict[str, int] = {}
        for modality, model_limit in supported_mm_limits.items():
            configured_limit = mm_config.get_limit_per_prompt(modality)

            if model_limit is None:
                limits[modality] = configured_limit
            else:
                limits[modality] = min(configured_limit, model_limit)

        return limits

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int] | None:
        """
        Return the maximum number of tokens per item for each modality.

        Returning `None` (the default) makes vLLM generate dummy inputs
        (images/videos) at the maximum possible sizes and process them to
        determine the maximum token count per modality.

        That works, but can be very slow for some models (e.g. Qwen2.5-VL),
        which leads to long startup times. Models can override this method to
        return pre-computed maximum token counts instead, skipping dummy input
        generation and processing.

        Note:
            The returned per-item maximum for each modality should respect the
            model's maximum sequence length and the maximum number of items
            allowed per modality. It should also agree with dummy inputs
            (images/videos) at the maximum possible sizes.
        """
        return None

1189
1190

# The `BaseProcessingInfo` subclass a multi-modal processor is bound to.
_I = TypeVar("_I", bound=BaseProcessingInfo)

MultiModalHashes = dict[str, list[str]]
"""
Item hashes grouped by modality, laid out like
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]:
`hashes[modality][i]` belongs to the `i`-th item of that modality.
"""

MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]]
"""
Resolved prompt updates grouped by modality, laid out like
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]:
`updates[modality][i]` holds the updates bound to the `i`-th item.
"""

MultiModalPromptUpdatesApplyResult = Mapping[str, list[int | None]]
"""
For each item `MultiModalPromptUpdates[k][i]`, the entry
`MultiModalPromptUpdatesApplyResult[k][i]` is the index of the
`ResolvedPromptUpdate` instance that was applied, or `None` when none of the
`ResolvedPromptUpdate` instances were applied.
"""

1212
1213

class MultiModalProcessingInfo(NamedTuple):
    """Per-item outputs collected while processing multi-modal data."""

    # Processed keyword arguments of each item, grouped by modality
    # (the "optional" container type allows `None` entries).
    kwargs: MultiModalKwargsOptionalItems
    # Hash of each item, grouped by modality.
    hashes: MultiModalHashes
    # Resolved prompt updates of each item, grouped by modality.
    prompt_updates: MultiModalPromptUpdates

1218
1219

class BaseMultiModalProcessor(ABC, Generic[_I]):
1220
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with `transformers.ProcessorMixin`.
    """

1226
1227
1228
1229
1230
    def __init__(
        self,
        info: _I,
        dummy_inputs: "BaseDummyInputsBuilder[_I]",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> None:
        """
        Args:
            info: Model-specific processing information.
            dummy_inputs: Builder of dummy inputs (e.g. the dummy text used
                by `_apply_hf_processor_mm_only`).
            cache: Optional processor cache (may be `None`).
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache

        self.data_parser = self._get_data_parser()

        # Query the limits once; they are served by the
        # `supported_mm_limits` / `allowed_mm_limits` properties.
        self._supported_mm_limits = info.get_supported_mm_limits()
        self._allowed_mm_limits = info.get_allowed_mm_limits()

    @property
    def supported_mm_limits(self):
        """Per-modality item limits supported by the model (computed once)."""
        limits = self._supported_mm_limits
        return limits

    @property
    def allowed_mm_limits(self):
        """Per-modality item limits allowed for each prompt (computed once)."""
        limits = self._allowed_mm_limits
        return limits

1253
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Shorthand for `apply`, so the processor can be called directly."""
        return self.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            mm_uuids=mm_uuids,
        )
1262

1263
1264
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that preprocesses multi-modal data items
        before they are handed to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        To support more modalities, return a subclass of
        [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
        with additional subparsers.
        """
        parser = MultiModalDataParser()
        return parser

1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
    def validate_num_items(
        self,
        modality: str,
        num_items: int,
    ) -> None:
        """
        Check that `num_items` items of `modality` fit within the limits.

        Raises:
            ValueError: If the number of items exceeds the effective limit.
        """
        supported = self.supported_mm_limits.get(modality, 0)
        allowed = self.allowed_mm_limits.get(modality, 0)

        # An unlimited model limit defers entirely to the allowed limit.
        if supported is None:
            supported = allowed

        limit = min(supported, allowed)
        if num_items <= limit:
            return

        msg = f"At most {limit} {modality}(s) may be provided in one prompt."

        # The model itself could handle this many items; only the
        # configured limit is in the way, so point the user at it.
        if num_items <= supported:
            msg += " Set `--limit-mm-per-prompt` to increase this limit."

        raise ValueError(msg)

1296
    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Convert
        [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
        into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
        before they are handed to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        Raises:
            ValueError: If any modality exceeds its item limit.
        """
        parsed_items = self.data_parser.parse_mm_data(mm_data)

        # Enforce the per-modality item limits on the parsed data.
        for modality, modality_items in parsed_items.items():
            self.validate_num_items(modality, len(modality_items))

        return parsed_items
1312

1313
1314
1315
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe the fields of the HF-processed data.

        Subclasses map each field of `hf_inputs` to its
        `MultiModalFieldConfig` metadata.
        """
        raise NotImplementedError()

1322
    @abstractmethod
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Return the prompt updates to perform for the multi-modal items.

        `mm_items` holds the original multi-modal items and `out_mm_kwargs`
        the HF-processed data derived from them.

        The returned updates are applied to token inputs that bypass the HF
        processor. They are also applied to the HF processor output when the
        HF processor does not apply prompt updates to text inputs itself.

        The updates are also essential for locating the token positions used
        to build a
        [`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
        for each multi-modal item.
        """
        raise NotImplementedError()
1344

1345
1346
1347
1348
1349
1350
    def _bind_and_group_updates(
        self,
        prompt_updates: Sequence[PromptUpdate],
        mm_item_counts: Mapping[str, int],
    ) -> MultiModalPromptUpdates:
        """
        Resolve each prompt update against every item of its modality.

        Returns a mapping from modality to one list per item. Each list holds
        every update of that modality resolved for that item index.
        """
        bound: dict[str, list[Sequence[ResolvedPromptUpdate]]] = {}

        for modality, updates in full_groupby_modality(prompt_updates):
            num_items = mm_item_counts.get(modality, 0)
            bound[modality] = [
                [update.resolve(item_idx) for update in updates]
                for item_idx in range(num_items)
            ]

        return bound

    def _get_mm_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> MultiModalPromptUpdates:
        """
        Get the prompt updates for `mm_items`, resolved per item and grouped
        by modality.

        Logs a one-time deprecation warning when an item has more than one
        prompt update.
        """
        unbound_prompt_updates = self._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        mm_prompt_updates = self._bind_and_group_updates(
            unbound_prompt_updates,
            mm_items.get_all_counts(),
        )

        for modality, prompt_updates in mm_prompt_updates.items():
            for item_idx, item_prompt_updates in enumerate(prompt_updates):
                # Count the updates of *this item*. `prompt_updates` is the
                # per-item list of the whole modality, so its length is the
                # number of items, not the number of updates.
                num_item_updates = len(item_prompt_updates)
                if num_item_updates > 1:
                    logger.warning_once(
                        "Detected %d prompt updates for `mm_items[%r][%s]`. "
                        "Multiple prompt updates per item is now "
                        "deprecated and may be removed in v0.13. "
                        "Instead, please specify dynamic update targets "
                        "in the same prompt update definition by passing "
                        "a function to `PromptUpdate.target`.",
                        num_item_updates,
                        modality,
                        item_idx,
                    )

        return mm_prompt_updates

1392
    def _find_mm_placeholders(
        self,
        new_token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """Locate the placeholders of each item using the model's tokenizer."""
        return find_mm_placeholders(
            new_token_ids,
            mm_prompt_updates,
            self.info.get_tokenizer(),
        )
1400

1401
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the items into data for the HF processor and data that
        bypasses it.

        Returns:
            A `(processor_data, passthrough_data)` pair.
        """
        to_processor: dict[str, object] = {}
        to_passthrough: dict[str, object] = {}

        for modality_items in mm_items.values():
            to_processor.update(modality_items.get_processor_data())
            to_passthrough.update(modality_items.get_passthrough_data())

        return to_processor, to_passthrough

1414
1415
1416
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on the prompt text together with its
        multi-modal data.
        """
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        # `dict(...)` (rather than `{**...}`) rejects duplicate keys.
        processor_inputs = dict(text=prompt, **mm_data)
        processor_options = dict(**mm_kwargs, **tok_kwargs)

        return self.info.ctx.call_hf_processor(
            hf_processor,
            processor_inputs,
            processor_options,
        )

1433
    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Tell whether the HF processor applies the prompt updates itself.

        The default assumes it does, unless any of the items are
        pre-computed embeddings (`EmbeddingItems` / `DictEmbeddingItems`).
        """
        has_embedding_inputs = any(
            isinstance(modality_items, (EmbeddingItems, DictEmbeddingItems))
            for modality_items in mm_items.values()
        )

        return not has_embedding_inputs
1451

1452
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Run the HF processor on the prompt text and multi-modal data at once.

        Returns:
            The prompt token IDs, the remaining processed data, and whether
            the prompt updates have already been applied.
        """
        hf_mm_data, bypass_data = self._get_hf_mm_data(mm_items)

        hf_outputs = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=hf_mm_data,
            mm_kwargs=hf_processor_mm_kwargs,
            tok_kwargs=tokenization_kwargs,
        )
        # Data that skipped the HF processor is merged back into its output.
        hf_outputs.update(bypass_data)

        # The output batch holds exactly one sequence; unpack its token IDs.
        input_ids = hf_outputs.pop("input_ids")
        (prompt_ids,) = input_ids.tolist()

        updates_applied = self._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, hf_outputs, updates_applied
1485

1486
    def _apply_hf_processor_text_only(
        self,
        prompt_text: str,
        tokenization_kwargs: Mapping[str, object],
    ) -> list[int]:
        """
        Run the HF processor on the prompt text only.

        The text is processed together with an empty set of multi-modal
        items, and only the resulting token IDs are kept.
        """
        processed = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
            tokenization_kwargs=tokenization_kwargs,
        )

        return processed[0]

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Run the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens, so the
        default returns the tokens unchanged. Override this method if the HF
        processor adds or removes tokens unrelated to multi-modal data. The
        result should stay consistent with the output of
        [`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
        on the corresponding text.
        """
        # Default: forward the token inputs as-is.
        unchanged_tokens = prompt_tokens
        return unchanged_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on the multi-modal data only.

        The HF processor expects text that matches the multi-modal items, so
        dummy text from
        [`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
        is generated to accompany them. Only the processed data is returned.
        """
        item_counts = mm_items.get_all_counts()
        dummy_text = self.dummy_inputs.get_dummy_text(item_counts)

        processed = self._apply_hf_processor_text_mm(
            prompt_text=dummy_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return processed[1]
1548
1549
1550

    def _apply_hf_processor_main(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Run the HF processor on the prompt and the multi-modal data.

        Returns:
            The prompt token IDs, the processed multi-modal data, and whether
            prompt updates have already been applied.

        Note:
            Only a text prompt with `enable_hf_prompt_update=True` is
            processed together with the multi-modal items. The HF processor
            needs the prompt to correspond to the items, and in this case it
            may perform the prompt updates itself. Otherwise, the prompt and
            the multi-modal data are processed separately, and the returned
            flag is `False`.
        """
        is_text = isinstance(prompt, str)

        if is_text and enable_hf_prompt_update:
            return self._apply_hf_processor_text_mm(
                prompt_text=prompt,
                mm_items=mm_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
            )

        if is_text:
            prompt_ids = self._apply_hf_processor_text_only(prompt, tokenization_kwargs)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, mm_processed_data, False
1589

1590
    def _hash_mm_items(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalHashes:
        """Create the multi-modal hashes to be returned.

        A user-provided UUID is used verbatim only when there are no
        processor or tokenization kwargs. Otherwise, the UUID (or the raw item
        when no UUID is given) is hashed together with those kwargs.

        Note: When overrides are provided via callers of `apply`,
        `_hash_mm_items` will be bypassed and the overrides will be used.
        """
        model_id = self.info.model_id
        mm_uuids = mm_uuids or {}

        # Processor kwargs can change the processed inputs, so their
        # presence forces hashing even when a UUID is supplied.
        has_extra_kwargs = bool(hf_processor_mm_kwargs or tokenization_kwargs)

        def _hash(modality: str, payload: object) -> str:
            return MultiModalHasher.hash_kwargs(
                model_id=model_id,
                **{modality: payload},
                **hf_processor_mm_kwargs,
                **tokenization_kwargs,
            )

        hashes: MultiModalHashes = {}

        for modality, items in mm_items.items():
            if modality not in mm_uuids:
                hashes[modality] = [_hash(modality, item) for item in items]
                continue

            modality_uuids = mm_uuids[modality]
            if isinstance(modality_uuids, str):
                modality_uuids = [modality_uuids]

            modality_hashes: list[str] = []
            for idx, item in enumerate(items):
                item_uuid = modality_uuids[idx]

                if item_uuid is not None and not has_extra_kwargs:
                    modality_hashes.append(item_uuid)
                else:
                    # Hashing the short UUID string (when available) is
                    # cheaper than hashing the raw item.
                    payload = item if item_uuid is None else item_uuid
                    modality_hashes.append(_hash(modality, payload))

            hashes[modality] = modality_hashes

        return hashes
1655

1656
1657
    def _get_cache_missing_items(
        self,
        cache: BaseMultiModalProcessorCache,
        mm_data_items: MultiModalDataItems,
        mm_hashes: MultiModalHashes,
    ) -> MultiModalDataItems:
        """
        Collect the items whose hashes are not yet in `cache`.

        Raises:
            ValueError: If an uncached item has no data to process.
        """
        # Query the cache for every modality before touching any item data.
        cached_flags = {
            modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
        }

        missing_data = {}
        for modality, flags in cached_flags.items():
            uncached_idxs = [idx for idx, hit in enumerate(flags) if not hit]

            collected = []
            for idx in uncached_idxs:
                data = mm_data_items[modality][idx]
                if data is None:
                    raise ValueError(
                        f"Cache miss for {modality} at index {idx} "
                        "but data is not provided."
                    )
                collected.append(data)

            missing_data[modality] = collected

        return self._to_mm_items(missing_data)

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """
        Rebind a prompt update retrieved from the cache to `new_item_idx`.

        Override this when other attributes of `ResolvedPromptUpdate` must
        also be recomputed after a cache lookup.
        """
        rebound = replace(cached_update, item_idx=new_item_idx)
        return rebound

1701
1702
    def _merge_mm_kwargs(
        self,
        cache: BaseMultiModalProcessorCache,
        mm_hashes: MultiModalHashes,
        mm_missing_kwargs: MultiModalKwargsItems,
        mm_missing_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
        """
        Combine freshly processed items with cached ones, in the item order
        given by `mm_hashes`.

        Each uncached item consumes the next entry of `mm_missing_kwargs` /
        `mm_missing_prompt_updates` for its modality. Every item, cached or
        not, goes through `cache.get_and_update_item`, and its prompt updates
        are rebound to the item's index.
        """
        # Snapshot the cache state up front, so that items repeated later in
        # the same modality do not skip the cache logic.
        cached_flags = {
            modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
        }

        merged_kwargs = defaultdict[str, list[MultiModalKwargsItem | None]](list)
        merged_prompt_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](
            list
        )

        for modality, hashes in mm_hashes.items():
            fresh_kwargs = mm_missing_kwargs.get(modality, [])
            fresh_updates = mm_missing_prompt_updates.get(modality, [])
            flags = cached_flags[modality]
            next_fresh = 0  # index of the next unconsumed fresh entry

            for item_idx, item_hash in enumerate(hashes):
                new_item = None
                if not flags[item_idx]:
                    new_item = (fresh_kwargs[next_fresh], fresh_updates[next_fresh])
                    next_fresh += 1

                item_kwargs, item_updates = cache.get_and_update_item(
                    new_item, item_hash
                )

                merged_kwargs[modality].append(item_kwargs)
                merged_prompt_updates[modality].append(
                    [
                        self._recompute_cached_prompt_update(update, item_idx)
                        for update in item_updates
                    ]
                )

        return MultiModalKwargsItems(merged_kwargs), dict(merged_prompt_updates)
1751
1752
1753

    def _apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Run the HF processor on the whole prompt without consulting a cache.

        Returns:
            A tuple of (prompt token IDs, processing info holding the kwargs,
            hashes and prompt updates of every item, and the flag reported by
            `_apply_hf_processor_main` telling whether prompt updates were
            already applied).
        """
        # Unlike the cached path, the prompt here corresponds to all of
        # `mm_data_items`, so HF-side prompt updates can be enabled.
        prompt_ids, processed, hf_applied_updates = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            enable_hf_prompt_update=True,
        )

        fields_config = self._get_mm_fields_config(processed, hf_processor_mm_kwargs)
        kwargs_items = MultiModalKwargsItems.from_hf_inputs(processed, fields_config)

        # `mm_uuids`, when provided, override the data-dependent hashes.
        hashes = self._hash_mm_items(
            mm_data_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        updates = self._get_mm_prompt_updates(
            mm_data_items,
            hf_processor_mm_kwargs,
            kwargs_items,
        )

        info = MultiModalProcessingInfo(
            kwargs=kwargs_items,
            hashes=hashes,
            prompt_updates=updates,
        )
        return prompt_ids, info, hf_applied_updates

    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Apply the HF processor on the full prompt text, reusing cached
        per-item results and handing newly computed ones to the cache.

        Falls back to the uncached `_apply_hf_processor` when `self.cache`
        is None or when `_get_hf_mm_data` reports passthrough data.

        Returns:
            Same tuple as `_apply_hf_processor`: (prompt token IDs,
            processing info, whether prompt updates were already applied).
        """
        cache = self.cache
        _, passthrough = self._get_hf_mm_data(mm_data_items)

        if cache is None or passthrough:
            return self._apply_hf_processor(
                prompt=prompt,
                mm_data_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        hashes = self._hash_mm_items(
            mm_data_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        missing_items = self._get_cache_missing_items(
            cache=cache,
            mm_data_items=mm_data_items,
            mm_hashes=hashes,
        )

        # `prompt` still refers to every item, not just `missing_items`, so
        # HF-side prompt updates must stay disabled; they can only be applied
        # once the new items have been merged with the cached ones.
        prompt_ids, fresh_data, applied = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=missing_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            enable_hf_prompt_update=False,
        )

        fresh_kwargs = MultiModalKwargsItems.from_hf_inputs(
            fresh_data,
            self._get_mm_fields_config(fresh_data, hf_processor_mm_kwargs),
        )

        fresh_updates = self._get_mm_prompt_updates(
            missing_items,
            hf_processor_mm_kwargs,
            fresh_kwargs,
        )

        merged_kwargs, merged_updates = self._merge_mm_kwargs(
            cache,
            mm_hashes=hashes,
            mm_missing_kwargs=fresh_kwargs,
            mm_missing_prompt_updates=fresh_updates,
        )

        return (
            prompt_ids,
            MultiModalProcessingInfo(
                kwargs=merged_kwargs,
                hashes=hashes,
                prompt_updates=merged_updates,
            ),
            applied,
        )

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
        """
        Apply `mm_prompt_updates` to `prompt` by matching on token IDs.

        Thin wrapper around the module-level `apply_token_matches` that
        supplies this processor's tokenizer.
        """
        tok = self.info.get_tokenizer()
        updated = apply_token_matches(prompt, mm_prompt_updates, tok)
        return updated

    def _apply_text_matches(
        self,
        prompt: str,
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[str, MultiModalPromptUpdatesApplyResult]:
        """
        Apply `mm_prompt_updates` to the decoded `prompt` by string matching.

        Thin wrapper around the module-level `apply_text_matches` that
        supplies this processor's tokenizer.
        """
        tok = self.info.get_tokenizer()
        updated = apply_text_matches(prompt, mm_prompt_updates, tok)
        return updated

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply `mm_prompt_updates` to `token_ids` and locate the placeholders.

        Token-level matching is tried first. If it leaves any item without a
        matched update, the prompt is decoded, the updates are applied to the
        text instead, and the result is re-encoded without special tokens.
        In that case the text-level match result replaces the token-level one.

        Args:
            token_ids: Prompt token IDs to update.
            mm_prompt_updates: For each modality, the candidate updates of
                each item.

        Returns:
            The updated token IDs and, per modality, the placeholders found
            in them for the update chosen for each item.

        Raises:
            RuntimeError: If some item still has no applicable update after
                the text-based fallback.
        """
        tokenizer = self.info.get_tokenizer()

        new_token_ids, match_result = self._apply_token_matches(
            token_ids,
            mm_prompt_updates,
        )

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string-based
        # updates on the decoded token IDs, then encode them back.
        if not all(
            all(update_idx is not None for update_idx in update_idxs)
            for update_idxs in match_result.values()
        ):
            new_text, match_result = self._apply_text_matches(
                decode_tokens(tokenizer, token_ids),
                mm_prompt_updates,
            )

            new_token_ids = encode_tokens(
                tokenizer,
                new_text,
                add_special_tokens=False,
            )

        matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list)
        for modality, update_idxs in match_result.items():
            for item_idx, update_idx in enumerate(update_idxs):
                # Explicit check rather than `assert`: assertions are removed
                # under `python -O`, which would let `None` reach the indexing
                # below and fail with a far less helpful error.
                if update_idx is None:
                    raise RuntimeError(
                        "Failed to apply prompt replacement for "
                        f"mm_items[{modality!r}][{item_idx}]"
                    )

                matched_updates[modality].append(
                    [mm_prompt_updates[modality][item_idx][update_idx]]
                )

        placeholders = self._find_mm_placeholders(
            new_token_ids,
            dict(matched_updates),
        )

        return new_token_ids, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargsOptionalItems,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that each modality has exactly one kwargs entry per data item.

        A modality absent from `mm_kwargs` counts as having zero entries.

        Raises:
            RuntimeError: If, for any modality in `mm_item_counts`, the
                number of entries in `mm_kwargs` differs from the count.
        """
        for modality, expected in mm_item_counts.items():
            found = len(mm_kwargs.get(modality, []))
            if found == expected:
                continue

            raise RuntimeError(
                f"Expected there to be {expected} {modality} items in "
                f"keyword arguments corresponding to {expected} "
                f"{modality} data items, but only found {found}! "
                "There is likely a problem with your "
                "implementation of merged multi-modal processor for this "
                "model (usually arising from an inconsistency between "
                "`_call_hf_processor` and `_get_mm_fields_config`)."
            )

    def _validate_mm_updates(
        self,
        mm_updates: MultiModalPromptUpdates,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that each modality has exactly one prompt-update entry per item.

        A modality absent from `mm_updates` counts as having zero entries.

        Raises:
            RuntimeError: If, for any modality in `mm_item_counts`, the
                number of entries in `mm_updates` differs from the count.
        """
        for modality, expected in mm_item_counts.items():
            found = len(mm_updates.get(modality, []))
            if found == expected:
                continue

            raise RuntimeError(
                f"Expected there to be {expected} prompt updates "
                f"corresponding to {expected} {modality} items, but "
                f"instead found {found} prompt updates! "
                "This is likely because you forgot to include input "
                "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
                "in the prompt. If the model has a chat template, make "
                "sure you have applied it before calling `LLM.generate`."
            )

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that each modality has exactly one placeholder per data item.

        A modality absent from `mm_placeholders` counts as having zero
        placeholders.

        Raises:
            RuntimeError: If, for any modality in `mm_item_counts`, the
                number of placeholders differs from the count.
        """
        for modality, expected in mm_item_counts.items():
            found = len(mm_placeholders.get(modality, []))
            if found == expected:
                continue

            raise RuntimeError(
                f"Expected there to be {expected} prompt placeholders "
                f"corresponding to {expected} {modality} items, but "
                f"instead found {found} prompt placeholders! "
                "Make sure the implementation of `_call_hf_processor` and "
                "`_get_mm_fields_config` are consistent with each other."
            )

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsOptionalItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Validate the processed outputs and ensure prompt updates are present.

        If `is_update_applied` is true, the updates are already in
        `prompt_ids` and the placeholders are only located. Otherwise the
        updates are applied here first. In both cases the number of kwargs
        entries, prompt updates and placeholders is checked against the
        per-modality item counts of `mm_items`.

        Returns:
            The (possibly updated) token IDs and the placeholders found in
            them, keyed by modality.
        """
        counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, counts)
        self._validate_mm_updates(mm_prompt_updates, counts)

        if not is_update_applied:
            prompt_ids, placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )
        else:
            placeholders = self._find_mm_placeholders(prompt_ids, mm_prompt_updates)

        self._validate_mm_placeholders(placeholders, counts)
        return prompt_ids, placeholders

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply the HF processor (through the cache when one is configured)
           on prompt text and multi-modal data together, producing token IDs
           and processed tensors.
        2. Unless the HF processor already did so, find and update sequences
           in the token IDs with placeholder tokens. The number of
           placeholder tokens equals the feature size of the multi-modal
           data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs and convert it to ranges.

        `tokenization_kwargs` defaults to an empty mapping when None.
        """
        mm_items = self._to_mm_items(mm_data)
        tok_kwargs = {} if tokenization_kwargs is None else tokenization_kwargs

        prompt_ids, mm_info, is_update_applied = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs=tok_kwargs,
            mm_uuids=mm_uuids,
        )

        # Applying prompt updates does not need `tokenization_kwargs`.
        prompt_ids, placeholders = self._maybe_apply_prompt_updates(
            mm_items=mm_items,
            prompt_ids=prompt_ids,
            mm_kwargs=mm_info.kwargs,
            mm_prompt_updates=mm_info.prompt_updates,
            is_update_applied=is_update_applied,
        )

        ranges = {
            modality: [info.to_range() for info in found]
            for modality, found in placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_info.kwargs,
            mm_hashes=mm_info.hashes,
            mm_placeholders=ranges,
        )


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    The HF processor runs on an encoder prompt built by
    `create_encoder_prompt`. The decoder prompt comes from
    `create_decoder_prompt`, which returns the original prompt by default.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        # NOTE(review): not read within this class; presumably consulted by
        # the dummy-input / profiling code. Subclasses may override.
        return False

    def create_decoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        """
        Create input prompt for the decoder (by default, `prompt` unchanged).
        """
        return prompt

    def _get_enc_dec_inputs(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        encoder_inputs: MultiModalInputs,
    ):
        """
        Combine processed encoder inputs with the decoder prompt's token IDs.

        The encoder's `prompt_token_ids` are copied to
        `encoder_prompt_token_ids`, then `prompt_token_ids` is overwritten
        with the decoder prompt. A string decoder prompt is encoded without
        special tokens.
        """
        tokenizer = self.info.get_tokenizer()
        raw_decoder = self.create_decoder_prompt(prompt, mm_data)
        decoder_ids = (
            encode_tokens(tokenizer, raw_decoder, add_special_tokens=False)
            if isinstance(raw_decoder, str)
            else raw_decoder
        )

        combined = MultiModalEncDecInputs(
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs,
        )
        combined["prompt_token_ids"] = decoder_ids
        return combined

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main processing steps are modified to fit encoder-decoder models:

        1. Create the encoder prompt from the input prompt.
        2. Apply the base multi-modal processing on the encoder prompt.
        3. Use `create_decoder_prompt` (the input prompt by default) as the
           decoder prompt.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        processed_encoder = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        return self._get_enc_dec_inputs(
            prompt=prompt,
            mm_data=mm_data,
            encoder_inputs=processed_encoder,
        )