processing.py 44.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import re
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
7
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
8
from dataclasses import dataclass, field
9
from functools import lru_cache
10
11
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union)
12

13
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
14

15
16
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
17
from vllm.logger import init_logger
18
19
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
20
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
21

22
from .hasher import MultiModalHasher
23
24
25
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
26
from .parse import MultiModalDataItems, MultiModalDataParser
27
28
29

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
30

31
logger = init_logger(__name__)

# Constrained type variable: a prompt given either as text or as token IDs.
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""Either prompt text or a list of token IDs."""
37

38

39
40
@dataclass
class PromptReplacementDetails:
    """Details about the replacement token sequence or text."""

    full: PromptSeq
    """The full replacement."""

    features: PromptSeq
    """
    The part of the replacement that corresponds to feature placeholders;
    this will be replaced by the output of the vision encoder during model
    inference.
    """

    @staticmethod
    def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
        """
        Build details in which the whole of ``seq`` is treated as feature
        placeholders, i.e. ``full`` and ``features`` are both ``seq``.
        """
        return PromptReplacementDetails(full=seq, features=seq)


58
59
60
61
62
63
64
PromptRepl = Union[PromptSeq, PromptReplacementDetails]
"""
The replacement token sequence or text.

Pass a :class:`PromptReplacementDetails` instead of a plain
:data:`PromptSeq` when only part of the replacement corresponds to
feature placeholders, in order to specify which part.
"""
65
66


67
@dataclass
class PromptReplacement:
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptReplacementDetails(
                    full="".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    features="<image>" * image_feature_size,
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptReplacementDetails(
                    full=([image_bos_id] + [image_token_id] * image_feature_size
                          + [image_eos_id]),
                    features=[image_token_id] * image_feature_size,
                ),
            )
    """

    modality: str
    """The modality for which the replacement is made."""

    target: PromptSeq
    """The token sequence (or text) to find and replace."""

    replacement: Union[Callable[[int], PromptRepl],
                       PromptRepl] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the replacement token sequence (or text).

    For convenience, you can directly pass in the replacement token sequence
    (or text) instead of a function if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
        """
        Attach ``tokenizer`` so that the target and replacement can be
        converted between text and token IDs on demand.
        """
        return BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )
144
145


146
147
148
149
150
151
152
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Results are cached per ``(tokenizer, text, add_special_tokens)`` (so
    ``tokenizer`` must be hashable), keeping up to 2048 entries. Callers
    receive the cached list object itself and must not mutate it.
    """
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)
156
157


158
159
160
161
162
163
164
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    ``token_ids`` is a tuple rather than a list because ``lru_cache``
    requires hashable arguments; it is converted back to a list before
    decoding. Up to 2048 results are kept.
    """
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)
168
169
170
171
172


class _HasModalityAttr(Protocol):
    """Structural type for objects with a plain ``modality`` attribute."""

    modality: str

173

174
class _HasModalityProp(Protocol):
    """Structural type for objects exposing ``modality`` as a property."""

    @property
    def modality(self) -> str:
        ...


# Anything exposing a ``modality`` string, as an attribute or a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group ``values`` by their ``modality`` using :func:`full_groupby`."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer to automatically
    convert between token sequence and text representations.

    Whichever representation was not supplied is computed on first access
    (via :func:`_cached_decode` / :func:`_cached_encode`) and then stored
    on the instance.
    """
    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """
        Wrap ``seq``, storing it as text if it is a ``str`` and as token IDs
        if it is a ``list``.
        """
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs if necessary."""
        if self._text is None:
            assert self._token_ids is not None
            # Tuple because the lru_cache'd decoder needs hashable arguments.
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token ID form, encoded from the text if necessary."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer, self._text)

        return self._token_ids


233
234
235
236
237
238
@dataclass
class _BoundPromptReplacementGroup:
    """
    Bound form of a :class:`PromptReplacementDetails`: the full replacement
    sequence together with the part of it that holds feature placeholders.
    """

    full: _BoundPromptSequence
    features: _BoundPromptSequence


239
@dataclass
class BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer to automatically
    convert :attr:`target` and the result of :meth:`get_replacement` between
    token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: PromptSeq
    _replacement: Union[Callable[[int], PromptRepl],
                        PromptRepl] = field(repr=False)

    def __post_init__(self) -> None:
        # Bound replacements by item index. A non-callable replacement is
        # the same for every item, so it is cached once under the key None.
        self._replacement_cache = dict[Optional[int],
                                       _BoundPromptReplacementGroup]()

    @property
    def target(self) -> _BoundPromptSequence:
        """The token sequence (or text) to find and replace."""
        return _BoundPromptSequence.from_seq(self.tokenizer, self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
        """
        Given the index of the processed item within :attr:`modality`,
        output the replacement token sequence (or text).

        Results are cached, so repeated calls (e.g. while scanning a prompt
        for placeholders) neither re-evaluate the replacement function nor
        rebuild the bound sequences.
        """
        replacement = self._replacement
        is_dynamic = callable(replacement)
        cache_key = item_idx if is_dynamic else None

        cached = self._replacement_cache.get(cache_key)
        if cached is not None:
            return cached

        if is_dynamic:
            replacement = replacement(item_idx)

        if not isinstance(replacement, PromptReplacementDetails):
            replacement = PromptReplacementDetails.from_seq(replacement)

        bound_replacement = _BoundPromptReplacementGroup(
            full=_BoundPromptSequence.from_seq(self.tokenizer,
                                               replacement.full),
            features=_BoundPromptSequence.from_seq(self.tokenizer,
                                                   replacement.features),
        )

        self._replacement_cache[cache_key] = bound_replacement
        return bound_replacement


294
295
296
class _TokenMatch(NamedTuple):
    """Half-open span ``[start_idx, end_idx)`` of a match in a token list."""

    start_idx: int
    end_idx: int
297
298


299
300
301
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch, None, None]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are reported left to right and never overlap: after a match,
    scanning resumes right after its end.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
325
326


327
328
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    Base class for an occurrence of a :class:`BoundPromptReplacement`
    target found in a prompt.

    Subclasses provide :attr:`start_idx` and :attr:`end_idx`, which index
    into the searched prompt (token IDs or text, depending on the subclass).
    """
    prompt_repl: BoundPromptReplacement

    @property
    def modality(self) -> str:
        """The modality of the replacement that produced this match."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Start index (inclusive) of the match in the prompt."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """End index (exclusive) of the match in the prompt."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A replacement target found in a list of token IDs."""
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A replacement target found in the prompt text."""
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

375

376
@dataclass
class PlaceholderFeaturesInfo:
    """
    Location of the feature placeholder tokens of one multi-modal item
    within a token prompt.
    """
    modality: str
    # Index of the item among the items of its modality.
    item_idx: int
    # Index in the prompt of the first feature placeholder token.
    start_idx: int
    # The feature placeholder tokens themselves.
    tokens: list[int]

    @property
    def length(self) -> int:
        """Number of feature placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` (offset and length)."""
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )
392
393
394
395


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches are grouped per replacement, in the order of
    :code:`prompt_repls`. Within one replacement they are ordered left to
    right and do not overlap (see :func:`iter_token_matches`), but matches
    of different replacements may overlap each other.
    """
    return [
        _PromptReplacementTokenMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
    ]


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Empty targets are ignored, consistent with :func:`iter_token_matches`
    used by :func:`find_token_matches`; otherwise an empty pattern would
    produce a zero-width match at every position of the prompt.
    """
    matches = list[_PromptReplacementTextMatch]()
    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            continue

        pattern = re.escape(target_text)
        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return matches


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same position of
            :code:`prompt`.
    """
    matches = [m for ms in mm_matches.values() for m in ms]

    # For each position of the prompt, the match (if any) that covers it,
    # so that overlaps are detected in a single pass.
    seen_matches: list[Optional[_PromptReplacementMatch]]
    seen_matches = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


def _replace_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    Matches are applied in prompt order. For each modality, the ``i``-th
    applied match receives the replacement for item ``i``; once
    :code:`mm_item_counts` items of a modality have been placed, further
    matches of that modality are left unchanged.

    Returns:
        The pieces of the new prompt in order; concatenating them yields
        the full result.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        item_idx = next_idx_by_modality[modality]
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        replacement = match.prompt_repl.get_replacement(item_idx)
        # Use the representation matching the prompt (text vs token IDs).
        repl_seq = (replacement.full.text if isinstance(prompt, str) else
                    replacement.full.token_ids)
        out_seqs.append(prompt[prev_end_idx:match.start_idx] + repl_seq)

        prev_end_idx = match.end_idx
        next_idx_by_modality[modality] += 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, :code:`prompt` itself (not a copy) is
    returned.
    """
    if not mm_matches:
        return prompt

    token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)
493
494


495
496
def replace_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, :code:`prompt` is returned unchanged.
    """
    if not mm_matches:
        return prompt

    texts = _replace_matches(prompt, mm_matches, mm_item_counts)

    return "".join(texts)
507
508


509
510
def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_repls` takes priority.

    Note that empty matches are ignored.

    Raises:
        AssertionError: If the feature tokens of a replacement are not
            found inside its full replacement tokens.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_repls in mm_prompt_repls.items():
            item_idx = item_idx_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for repl_info in modality_repls:
                replacement = repl_info.get_replacement(item_idx)
                repl_tokens_full = replacement.full.token_ids
                repl_len_full = len(repl_tokens_full)
                end_idx_full = start_idx + repl_len_full

                if repl_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] == repl_tokens_full:
                    repl_tokens_feat = replacement.features.token_ids

                    # Locate the feature tokens inside the full replacement.
                    match = next(
                        iter_token_matches(repl_tokens_full, repl_tokens_feat),
                        None,
                    )
                    if match is None:
                        raise AssertionError(
                            f"{repl_tokens_feat=} should be a "
                            f"subsequence of {repl_tokens_full=}")

                    yield PlaceholderFeaturesInfo(
                        modality=modality,
                        item_idx=item_idx,
                        start_idx=start_idx + match.start_idx,
                        tokens=repl_tokens_feat,
                    )

                    # Exclude overlapping matches
                    start_idx = end_idx_full
                    item_idx_by_modality[modality] += 1
                    found = True
                    break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1
573
574


575
def find_mm_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the feature placeholder tokens of each multi-modal item in
    :code:`prompt`, grouped by modality (see :func:`_iter_placeholders`).
    """
    it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
    return dict(full_groupby_modality(it))


584
585
586
587
588
589
590
591
class ProcessingCache:
    """
    LRU cache of processed multi-modal items, keyed by a hash of the
    inputs that determine the processed output (see :meth:`get`).
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    def _maybe_log_cache_stats(self) -> None:
        """
        Log the hit ratio whenever the total count reported by the cache
        stats is a multiple of ``debug_cache_hit_ratio_steps``.
        """
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        cache_stats = self._cache.stat()
        if cache_stats.total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         cache_stats.hit_ratio)

    @staticmethod
    def _compute_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Hash the dependencies of a processed item into a cache key."""
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = self._compute_key(model_id, modality, input_item,
                                      input_kwargs)
        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._compute_key(model_id, modality, input_item,
                                      input_kwargs)
        self._cache.put(cache_key, output_kwargs)
643
644


645
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    # NOTE(review): this class does not inherit from ABC, so the
    # @abstractmethod decorators below are not enforced at instantiation
    # time. Confirm whether that is intentional before changing the base.

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` attribute of the model config."""
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer of the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# Type of the processing-info object that a multi-modal processor is
# parametrized by; must be a subclass of BaseProcessingInfo.
_I = TypeVar("_I", bound=BaseProcessingInfo)
699

700
701

class BaseMultiModalProcessor(ABC, Generic[_I]):
702
    """
703
    Abstract base class to process multi-modal inputs to be used in vLLM.
704
705

    Not to be confused with :class:`transformers.ProcessorMixin`.
706
707
    """

708
    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Args:
            info: Model and processor information
                (see :class:`BaseProcessingInfo`).
            dummy_inputs: Builds dummy inputs; used by
                :meth:`_apply_hf_processor_mm_only` to generate text that
                accompanies the multi-modal data.
            cache: Optional cache of processed multi-modal items.
            enable_sanity_checks: Stored on the instance; presumably
                consulted by consistency checks defined elsewhere.
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()

723
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Equivalent to calling :meth:`apply` with the same arguments."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
730

731
732
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Create the parser that preprocesses multi-modal data items before
        they are passed to :meth:`_get_hf_mm_data`.

        To support additional modalities, override this method and return
        a subclass of :class:`MultiModalDataParser` that has additional
        subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than its
                ``--limit-mm-per-prompt`` value (1 when not set).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)
        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt

        for modality, items in mm_items.items():
            # Modalities without an explicit limit allow a single item.
            limit = mm_limits.get(modality, 1)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items
761

762
763
764
765
766
767
768
769
770
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe each field of the HF-processed data.

        The returned mapping (field name to :class:`MultiModalFieldConfig`)
        is used together with :code:`hf_inputs` to build the
        :class:`MultiModalKwargs` of the processed data.
        """
        raise NotImplementedError

771
772
    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: in :meth:`_apply_hf_processor_main`, the HF
              processor may be called on text-only and multimodal-only inputs
              separately, instead of passing them in the same call.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError
792

793
    def _find_mm_placeholders(
        self,
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the feature placeholder tokens of each multi-modal item in
        :code:`new_token_ids`, grouped by modality.

        Delegates to :func:`find_mm_placeholders`.
        """
        return find_mm_placeholders(mm_prompt_repls, new_token_ids,
                                    mm_item_counts)
801

802
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the data of :code:`mm_items` into two mappings.

        Returns:
            A tuple ``(processor_data, passthrough_data)``: the data to pass
            to the HF processor, and the data to add to the processor
            outputs as-is.
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())

        return processor_data, passthrough_data

815
816
817
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.

        :code:`mm_kwargs` is used to construct the HF processor and is also
        forwarded to the context's ``call_hf_processor``.
        """
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )

833
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            The token IDs of the processed prompt, and the multi-modal
            keyword arguments built from the remaining processor outputs.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        processed_data.update(passthrough_data)

        # A single prompt was processed, so "input_ids" holds exactly one
        # sequence; the unpacking fails loudly otherwise.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        return prompt_ids, mm_kwargs

861
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The HF processor is called with an empty set of multi-modal items
        and no processor kwargs; only the resulting token IDs are returned.
        """
        prompt_ids, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens, so by
        default :code:`prompt_tokens` is returned unchanged. Override this
        method when the HF processor adds or removes tokens unrelated to
        multi-modal data, so that the result stays consistent with
        :meth:`_apply_hf_processor_text_only` on the corresponding text.
        """
        unchanged_tokens = prompt_tokens
        return unchanged_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :class:`BaseDummyInputsBuilder` to go along with the multi-modal data.
        """
        mm_counts = mm_items.get_all_counts()

        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_counts,
        )

        # Only the multi-modal outputs are kept; the token IDs of the dummy
        # prompt are discarded.
        _, mm_kwargs = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_replacement: bool,
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt and multi-modal data.

        Returns:
            The prompt token IDs and the processed multi-modal kwargs.

        Note:
            Only when :code:`prompt` is text and
            :code:`enable_hf_prompt_replacement=True` are the text and
            :code:`mm_items` passed to the HF processor together, so in that
            case the prompt should correspond to the multi-modal items.
            Otherwise (a token-ID prompt, or
            :code:`enable_hf_prompt_replacement=False`) the prompt and the
            multi-modal data are processed separately and need not
            correspond to each other.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_replacement:
                # Joint processing: the HF processor sees text and data
                # together and may perform its own prompt replacement.
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        # Separate processing: the multi-modal data is run on its own.
        mm_missing_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_missing_kwargs

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Run the HF processor on the full prompt, reusing cached per-item
        multi-modal outputs where available.

        If there is no cache, or :meth:`_get_hf_mm_data` reports passthrough
        data, everything is processed in one uncached call. Otherwise each
        item is looked up in :attr:`cache`. Only the items that miss are sent
        to the HF processor, and their outputs are stored back into the
        cache. Cached and fresh outputs are then merged in the original item
        order.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_replacement=True,
            )

        # Look up every item; a `None` entry marks a cache miss.
        cached_by_modality = dict[str, list]()
        for modality, items in mm_data_items.items():
            cached_by_modality[modality] = [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]

        # Collect the raw data of the cache misses, per modality.
        missing_data = dict[str, list]()
        for modality, kw_items in cached_by_modality.items():
            missing_data[modality] = [
                mm_data_items[modality][idx]
                for idx, kw_item in enumerate(kw_items) if kw_item is None
            ]
        missing_data_items = self._to_mm_items(missing_data)

        # NOTE: `prompt` does not correspond to `missing_data_items`,
        # so HF-side prompt replacement has to be disabled here.
        prompt_ids, missing_kwargs = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_replacement=False,
        )

        # Per modality, the index of the next freshly processed item to use.
        next_missing = dict.fromkeys(missing_data_items, 0)

        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in cached_by_modality.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is not None:
                    merged_kw_items.append(kw_item)
                    continue

                fresh_idx = next_missing[modality]
                fresh_item = missing_kwargs.get_item(modality, fresh_idx)
                cache.put(
                    model_id,
                    modality,
                    mm_data_items[modality][idx],
                    hf_processor_mm_kwargs,
                    fresh_item,
                )
                next_missing[modality] = fresh_idx + 1
                merged_kw_items.append(fresh_item)

        if self.enable_sanity_checks:
            missing_counts = missing_data_items.get_all_counts()
            assert all(
                consumed == missing_counts[modality]
                for modality, consumed in next_missing.items()), dict(
                    mm_missing_next_idx=next_missing,
                    mm_missing_counts=missing_counts)

        return prompt_ids, MultiModalKwargs.from_items(merged_kw_items)
1040

1041
    def _bind_and_group_repls(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> dict[str, list[BoundPromptReplacement]]:
        """
        Bind each prompt replacement to this processor's tokenizer and
        group the bound replacements with :func:`full_groupby_modality`.
        """
        tokenizer = self.info.get_tokenizer()
        bound_repls = [repl.bind(tokenizer) for repl in prompt_repls]
        return dict(full_groupby_modality(bound_repls))
1049

1050
1051
1052
1053
    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called, even if
        :meth:`_find_mm_placeholders` detects that the HF processor has
        already inserted the placeholders.

        This is useful in cases where :meth:`_find_mm_placeholders` cannot
        be reliably used to detect whether HF has performed processing.

        Returns:
            :code:`False` by default.
        """
        return False

1062
1063
1064
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the bound prompt replacements in :code:`mm_prompt_repls` to
        :code:`token_ids`.

        Matches are first searched for directly in the token IDs. If every
        modality has at least as many token matches as it has items (per
        :code:`mm_item_counts`), the replacement is done on the token IDs.
        Otherwise the token IDs are decoded, the replacement is done on the
        text, and the result is re-encoded without special tokens.

        Returns:
            A tuple of the updated token IDs, the corresponding text, and
            the placeholders located by :meth:`_find_mm_placeholders` in the
            updated token IDs.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, prompt_repls)
            for modality, prompt_repls in mm_prompt_repls.items()
        }
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in mm_token_matches.items()
        }

        # Fast path: enough matches were found on the token IDs themselves.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
            matched_repls = {
                modality: [match.prompt_repl for match in token_matches]
                for modality, token_matches in mm_token_matches.items()
            }
        else:
            # If the search text does not represent a special token,
            # it may have different token IDs in the prompt, because
            # the tokens may go across the boundaries of the search text.
            # ----
            # e.g. when searching for "foo" in "food", if "food" itself makes
            # up a token, then the token ID of "foo" will not appear at all
            # ----
            # Since it is inefficient to search for all possible tokenizations
            # of the search text in the prompt, we instead perform string
            # replacement on the decoded token IDs, then encode them back.
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, prompt_repls)
                for modality, prompt_repls in mm_prompt_repls.items()
            }
            text = replace_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            matched_repls = {
                modality: [match.prompt_repl for match in text_matches]
                for modality, text_matches in mm_text_matches.items()
            }

        placeholders = self._find_mm_placeholders(
            matched_repls,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders
1132

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that :code:`mm_kwargs` holds one item per multi-modal data item.

        Args:
            mm_kwargs: The processed multi-modal keyword arguments.
            mm_item_counts: The number of data items for each modality.

        Raises:
            RuntimeError: If, for some modality, the number of items in
                :code:`mm_kwargs` differs from :code:`mm_item_counts`.
        """
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                # The count may be too high as well as too low, so the
                # message reports what was found without implying a shortfall.
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
        *,
        allow_missing: bool = False,
    ) -> Mapping[str, int]:
        """
        Compare the number of placeholders found for each modality with the
        number of multi-modal items of that modality.

        Args:
            mm_placeholders: The placeholders found for each modality.
            mm_item_counts: The number of multi-modal items per modality.
            allow_missing: If :code:`True`, mismatched counts are not an
                error and are only reported in the returned mapping.

        Returns:
            For each modality in :code:`mm_item_counts`, the item count minus
            the number of placeholders found. The value is negative when more
            placeholders than items were found.

        Raises:
            RuntimeError: If :code:`allow_missing` is :code:`False` and the
                counts differ for some modality.
        """
        missing_repl_counts = dict[str, int]()

        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count and not allow_missing:
                # The count may be too high as well as too low, so the
                # message reports what was found without implying a shortfall.
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt replacements "
                    f"corresponding to {item_count} {modality} items, but "
                    f"found {len(placeholders)} prompt replacements! Either "
                    "the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_replacements`).")

            missing_repl_counts[modality] = item_count - len(placeholders)

        return missing_repl_counts

1181
1182
    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Raises:
            RuntimeError: If the processed keyword arguments or the final
                placeholders do not match the number of multi-modal items.
            ValueError: If, for some modality, the HF-processed prompt
                contains placeholders for some but not all of its items.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.

        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        # Placeholders already present in the HF-processed token IDs.
        hf_mm_placeholders = self._find_mm_placeholders(
            mm_prompt_repls,
            prompt_ids,
            mm_item_counts,
        )

        if self._always_apply_prompt_replacements():
            mm_missing_repl_counts = mm_item_counts
            mm_missing_repls = dict(mm_prompt_repls)
        else:
            mm_missing_repl_counts = self._validate_mm_placeholders(
                hf_mm_placeholders,
                mm_item_counts,
                allow_missing=True,
            )

            mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
            for modality, missing_repl_count in mm_missing_repl_counts.items():
                if missing_repl_count == 0:
                    mm_missing_repls[modality] = []
                elif missing_repl_count == mm_item_counts.get(modality, 0):
                    # Use `.get` so that a modality without any prompt
                    # replacements is expected to be reported by the
                    # descriptive `_validate_mm_placeholders` check below,
                    # rather than failing here with a bare KeyError.
                    mm_missing_repls[modality] = mm_prompt_repls.get(
                        modality, [])
                else:
                    # Reached when only some items of this modality have
                    # placeholders, or when more placeholders than items
                    # were found (negative missing count).
                    raise ValueError("Partial prompt replacement within "
                                     f"{modality=} is not supported")

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        if all(len(repls) == 0 for repls in mm_missing_repls.values()):
            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
            mm_placeholders = hf_mm_placeholders
        else:
            (
                prompt_ids,
                prompt,
                missing_mm_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                mm_missing_repls,
                mm_missing_repl_counts,
            )

            mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    :meth:`apply` runs the base-class pipeline on an encoder prompt built by
    :meth:`create_encoder_prompt`, and uses the caller's original prompt as
    the decoder prompt.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Build the encoder input prompt from the user prompt and the raw
        multi-modal data; subclasses must implement this.
        """
        raise NotImplementedError

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs for an encoder-decoder model.

        Compared to the base implementation, the steps are:

        1. Derive the encoder prompt from :code:`prompt` via
           :meth:`create_encoder_prompt`.
        2. Run the base-class processing on the encoder prompt.
        3. Use :code:`prompt` itself as the decoder prompt, deriving whichever
           of its text / token-ID forms was not supplied.
        """
        encoder_inputs = super().apply(
            self.create_encoder_prompt(prompt, mm_data),
            mm_data,
            hf_processor_mm_kwargs,
        )

        # The decoder prompt is the original prompt with no further
        # processing; only the missing text or token-ID form is computed.
        tokenizer = self.info.get_tokenizer()
        if not isinstance(prompt, str):
            decoder_prompt_ids = prompt
            decoder_prompt = decode_tokens(tokenizer, prompt)
        else:
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               prompt,
                                               add_special_tokens=False)
            decoder_prompt = prompt

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs,
        )
        # Overwrite the prompt fields copied from `encoder_inputs` with the
        # decoder prompt.
        mm_inputs.update({
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids,
        })
        return mm_inputs