processing.py 44.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import re
from abc import ABC, abstractmethod
5
from collections import defaultdict
6
7
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
8
from dataclasses import dataclass, field
9
from functools import lru_cache
10
11
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union)
12

13
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
14

15
16
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
17
from vllm.logger import init_logger
18
19
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
20
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
21

22
from .hasher import MultiModalHasher
23
24
25
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
26
27
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)
28
29
30

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
31

32
logger = init_logger(__name__)

# Either a text prompt or a token-ID prompt; lets helpers such as
# `_replace_matches` be generic over both representations.
_S = TypeVar("_S", str, list[int])

PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""


@dataclass
class PromptReplacementDetails:
    """
    Details about the replacement token sequence or text.

    ``full`` is what gets inserted into the prompt, while ``features`` marks
    the portion of ``full`` that corresponds to feature placeholders.
    """

    full: PromptSeq
    """The full replacement."""

    features: PromptSeq
    """
    The part of the replacement that corresponds to feature placeholders;
    this will be replaced by the output of the vision encoder during model
    inference.
    """

    @staticmethod
    def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
        """Treat the whole of ``seq`` as feature placeholders."""
        return PromptReplacementDetails(seq, seq)


PromptRepl = Union[PromptSeq, PromptReplacementDetails]
"""
The replacement token sequence or text.

If only part of the replacement corresponds to feature placeholders, you can
use :class:`PromptReplacementDetails` to specify which part. A plain
:data:`PromptSeq` is treated as consisting entirely of feature placeholders.
"""


@dataclass
class PromptReplacement:
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Example:

        For each image, replace one ``<image>`` input placeholder in the prompt
        with a number of ``<image>`` feature placeholders
        equal to the feature size of the vision encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * image_feature_size,
            )

        As above, but further pad the feature placeholders with ``<image_bos>``
        and ``<image_eos>``, which are not supposed to be passed to the vision
        encoder:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=PromptReplacementDetails(
                    full="".join([
                        "<image_bos>",
                        "<image>" * image_feature_size,
                        "<image_eos>",
                    ]),
                    features="<image>" * image_feature_size,
                ),
            )

        To avoid unnecessary tokenization during prompt replacement,
        we recommend passing token sequences instead of text:

        .. code-block:: python

            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptReplacementDetails(
                    full=([image_bos_id] + [image_token_id] * image_feature_size
                          + [image_eos_id]),
                    features=[image_token_id] * image_feature_size,
                ),
            )
    """

    modality: str
    """The modality for which the replacement is made."""

    target: PromptSeq
    """The token sequence (or text) to find and replace."""

    replacement: Union[Callable[[int], PromptRepl],
                       PromptRepl] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the replacement token sequence (or text).

    For convenience, you can directly pass in the replacement token sequence
    (or text) instead of a function if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
        """
        Attach ``tokenizer`` to this replacement, producing a
        :class:`BoundPromptReplacement` that can convert the target and
        replacement between text and token IDs on demand.
        """
        return BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )


@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Encode ``text`` into token IDs via :func:`encode_tokens`, memoizing the
    result per ``(tokenizer, text, add_special_tokens)``.

    NOTE: The returned list is the very object stored in the cache, so
    callers must treat it as read-only; mutating it would corrupt later hits.
    """
    token_ids = encode_tokens(
        tokenizer,
        text,
        add_special_tokens=add_special_tokens,
    )
    return token_ids


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Decode ``token_ids`` back into text via :func:`decode_tokens`,
    memoizing the result.

    ``token_ids`` is taken as a tuple (rather than a list) so that it is
    hashable and can be used as part of the cache key.
    """
    as_list = list(token_ids)
    text = decode_tokens(
        tokenizer,
        as_list,
        skip_special_tokens=skip_special_tokens,
    )
    return text


class _HasModalityAttr(Protocol):
    modality: str

174

175
class _HasModalityProp(Protocol):
176

177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
    @property
    def modality(self) -> str:
        ...


# Anything carrying a `modality` string, whether as a field or a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group ``values`` by their ``modality`` using :func:`full_groupby`."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A :data:`PromptSeq` bound to a tokenizer so that its text and token-ID
    representations can be used interchangeably.

    Whichever representation was not supplied is computed lazily on first
    access (through the memoized encode/decode helpers) and then stored on
    the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: PromptSeq,
    ) -> "_BoundPromptSequence":
        """Wrap ``seq``, recording it under whichever form it already is."""
        is_text = isinstance(seq, str)
        is_ids = isinstance(seq, list)
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if is_text else None,
            _token_ids=seq if is_ids else None,
        )

    def __post_init__(self) -> None:
        if all(rep is None for rep in (self._text, self._token_ids)):
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first use."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token-ID form, encoded from the text on first use."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer, self._text)
        return self._token_ids


@dataclass
class _BoundPromptReplacementGroup:
    """
    Bound counterpart of :class:`PromptReplacementDetails`: the full
    replacement and its feature-placeholder portion, each tied to a tokenizer.
    """

    full: _BoundPromptSequence
    features: _BoundPromptSequence


@dataclass
class BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer to automatically
    convert :attr:`target` and the result of :meth:`get_replacement` between
    token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: PromptSeq
    _replacement: Union[Callable[[int], PromptRepl],
                        PromptRepl] = field(repr=False)

    def __post_init__(self) -> None:
        # Bound replacements produced by a callable `_replacement`,
        # keyed by item index.
        self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
        # Bound replacement for a constant (non-callable) `_replacement`.
        # It does not depend on the item index, so one instance suffices.
        self._constant_replacement: Optional[
            _BoundPromptReplacementGroup] = None

    @property
    def target(self) -> _BoundPromptSequence:
        """The token sequence (or text) to find and replace."""
        return _BoundPromptSequence.from_seq(self.tokenizer, self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
        """
        Given the index of the processed item within :attr:`modality`,
        output the replacement token sequence (or text).

        Results are cached on this instance: callable replacements per
        ``item_idx``, constant replacements once. Repeated calls (for
        example, once per prompt position when locating placeholders)
        therefore return the same bound object and reuse any text/token
        conversion already performed on it.
        """
        replacement = self._replacement

        if callable(replacement):
            bound = self._replacement_cache.get(item_idx)
            if bound is None:
                bound = self._bind_replacement(replacement(item_idx))
                self._replacement_cache[item_idx] = bound
            return bound

        if self._constant_replacement is None:
            self._constant_replacement = self._bind_replacement(replacement)
        return self._constant_replacement

    def _bind_replacement(
        self,
        replacement: PromptRepl,
    ) -> _BoundPromptReplacementGroup:
        """Normalize ``replacement`` and bind both parts to the tokenizer."""
        if not isinstance(replacement, PromptReplacementDetails):
            replacement = PromptReplacementDetails.from_seq(replacement)

        return _BoundPromptReplacementGroup(
            full=_BoundPromptSequence.from_seq(self.tokenizer,
                                               replacement.full),
            features=_BoundPromptSequence.from_seq(self.tokenizer,
                                                   replacement.features),
        )


class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
298
299


300
301
302
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
303
) -> Generator[_TokenMatch]:
304
305
306
307
308
309
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
310
    match_len = len(match_ids)
311

312
313
    if match_len == 0:
        return
314

315
316
    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
317
        end_idx = start_idx + match_len
318

319
320
        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
321
322
323
324
325

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
326
327


328
329
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    A location in the prompt where the target of :attr:`prompt_repl` was
    found. Subclasses derive the span from their concrete match object.
    """

    prompt_repl: BoundPromptReplacement

    @property
    def modality(self) -> str:
        """The modality of the underlying prompt replacement."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Inclusive start index of the match within the prompt."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Exclusive end index of the match within the prompt."""
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match of a replacement target within a token-ID prompt."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        token_match = self.match
        return token_match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match of a replacement target within a text prompt."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        begin, _ = self.match.span()
        return begin

    @property
    def end_idx(self) -> int:
        _, finish = self.match.span()
        return finish


@dataclass
class PlaceholderFeaturesInfo:
    """
    Where the feature-placeholder tokens of one multi-modal item sit in the
    processed prompt.
    """

    modality: str  # modality of the item
    item_idx: int  # index of the item within its modality
    start_idx: int  # prompt index of the first feature-placeholder token
    tokens: list[int]  # the feature-placeholder token IDs

    @property
    def length(self) -> int:
        """Number of feature-placeholder tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Express the feature placeholders as a :class:`PlaceholderRange`."""
        offset, length = self.start_idx, self.length
        return PlaceholderRange(offset=offset, length=length)


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
    found = list[_PromptReplacementTokenMatch]()
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        for span in iter_token_matches(prompt, target_ids):
            found.append(_PromptReplacementTokenMatch(prompt_repl, span))
    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Note that empty matches are ignored, mirroring :func:`iter_token_matches`
    on the token-ID path. An empty target text would otherwise produce a
    zero-width match at every position of the prompt, and a replacement
    would be inserted at each one.
    """
    matches = list[_PromptReplacementTextMatch]()
    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            continue

        pattern = re.escape(target_text)
        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))
    return matches


def _resolve_matches(
    prompt: PromptSeq,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same prompt position.
    """
    all_matches = [
        match for modality_matches in mm_matches.values()
        for match in modality_matches
    ]

    # The match (if any) that has already claimed each prompt position.
    owners: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in all_matches:
        for idx in range(match.start_idx, match.end_idx):
            owner = owners[idx]
            if owner is not None:
                raise ValueError("Found overlapping matches "
                                 f"({owner} and {match}) "
                                 f"at index={idx} of prompt={prompt}")
            owners[idx] = match

    all_matches.sort(key=lambda m: m.start_idx)
    return all_matches


def _replace_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    Returns the pieces of the new prompt in order: unchanged spans with the
    replacement appended. The caller joins or flattens them. Matches beyond
    the number of items available for their modality are left untouched.
    """
    use_text = isinstance(prompt, str)
    pieces = list[_S]()
    cursor = 0
    items_used = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality
        item_idx = items_used[modality]
        if item_idx >= mm_item_counts.get(modality, 0):
            # No item left for this placeholder; keep the original content.
            continue

        bound = match.prompt_repl.get_replacement(item_idx)
        full = bound.full
        repl_seq = full.text if use_text else full.token_ids

        pieces.append(prompt[cursor:match.start_idx] + repl_seq)
        cursor = match.end_idx
        items_used[modality] = item_idx + 1

    pieces.append(prompt[cursor:])
    return pieces


def replace_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    When there are no matches, ``prompt`` itself (not a copy) is returned.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt


def replace_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    When there are no matches, ``prompt`` is returned unchanged.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return "".join(pieces)

    return prompt


def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_repls` takes priority.

    Note that empty matches are ignored.

    Raises:
        AssertionError: If the feature tokens of a matched replacement are
            not a subsequence of its full replacement tokens.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    def _first_match_at(
        pos: int,
    ) -> Optional[tuple[str, int, _BoundPromptReplacementGroup, list[int]]]:
        # Try each modality (in mapping order) that still has unmatched
        # items, and each of its replacements, against position `pos`.
        for modality, modality_repls in mm_prompt_repls.items():
            item_idx = item_idx_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for repl_info in modality_repls:
                replacement = repl_info.get_replacement(item_idx)
                full_ids = replacement.full.token_ids
                end_idx = pos + len(full_ids)

                if not full_ids or end_idx > prompt_len:
                    continue
                if prompt[pos:end_idx] == full_ids:
                    return modality, item_idx, replacement, full_ids

        return None

    start_idx = 0
    while start_idx < prompt_len:
        hit = _first_match_at(start_idx)
        if hit is None:
            start_idx += 1
            continue

        modality, item_idx, replacement, repl_tokens_full = hit
        repl_tokens_feat = replacement.features.token_ids

        # Locate the feature tokens inside the full replacement so that the
        # reported start index points at the features rather than at the
        # start of the full replacement.
        match = next(iter_token_matches(repl_tokens_full, repl_tokens_feat),
                     None)
        if match is None:
            raise AssertionError(
                f"{repl_tokens_feat=} should be a "
                f"subsequence of {repl_tokens_full=}") from None

        yield PlaceholderFeaturesInfo(
            modality=modality,
            item_idx=item_idx,
            start_idx=start_idx + match.start_idx,
            tokens=repl_tokens_feat,
        )

        # Exclude overlapping matches
        start_idx += len(repl_tokens_full)
        item_idx_by_modality[modality] += 1


def find_mm_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the feature placeholders of each multi-modal item in ``prompt``
    and group them by modality.
    """
    placeholders = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


class ProcessingCache:
    """
    Cache of processed multi-modal items (:class:`MultiModalKwargsItem`),
    keyed by a hash of everything the processing result depends on
    (see :meth:`get`).
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    @staticmethod
    def _make_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Hash all dependencies of a processed item into a cache key."""
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        """
        When :attr:`debug_cache_hit_ratio_steps` is set, log the hit ratio
        whenever the cache's ``stat().total`` is a multiple of it.
        """
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        stats = self._cache.stat()
        if stats.total % steps != 0:
            return

        logger.debug("ProcessingCache: hit_ratio = %.2f", stats.hit_ratio)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        key = self._make_key(model_id, modality, input_item, input_kwargs)
        return self._cache.get(key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        key = self._make_key(model_id, modality, input_item, input_kwargs)
        self._cache.put(key, output_kwargs)


class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    # NOTE(review): This class does not derive from `ABC`, so the
    # `@abstractmethod` markers below are not enforced at instantiation;
    # a missing override only surfaces as `NotImplementedError` when called.

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` field of the model config."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF model config from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HF processor from the processing context.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# Type of the processing-info object a processor is parameterized by.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseMultiModalProcessor(ABC, Generic[_I]):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Not to be confused with :class:`transformers.ProcessorMixin`.
    """

    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Args:
            info: Information needed for data processing.
            dummy_inputs: Builder of dummy inputs; used, e.g., to produce the
                dummy text in :meth:`_apply_hf_processor_mm_only`.
            cache: Optional cache of processed multi-modal items.
            enable_sanity_checks: Stored as :attr:`enable_sanity_checks`.
        """
        super().__init__()

        self.info, self.dummy_inputs = info, dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Shorthand for :meth:`apply` with the same arguments."""
        return self.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
        )

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than its limit in
                ``--limit-mm-per-prompt`` (a modality without an explicit
                limit is allowed 1 item).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)
        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt

        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            num_items = len(items)
            if num_items <= limit:
                continue

            raise ValueError(
                f"You set {modality}={limit} (or defaulted to 1) in "
                f"`--limit-mm-per-prompt`, but passed {num_items} "
                f"{modality} items in the same prompt.")

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        The result is passed to :meth:`MultiModalKwargs.from_hf_inputs`
        together with ``hf_inputs``.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: in :meth:`_apply_hf_processor_missing`, this method
              is called on text-only and multimodal-only inputs separately,
              instead of passing them in the same call.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError

    def _find_mm_placeholders(
        self,
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholder tokens of each multi-modal item in
        ``new_token_ids`` by delegating to :func:`find_mm_placeholders`.
        """
        return find_mm_placeholders(
            mm_prompt_repls,
            new_token_ids,
            mm_item_counts,
        )

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the multi-modal items into data for the HF processor and data
        that bypasses it.

        Returns:
            A ``(processor_data, passthrough_data)`` pair, each merged across
            all modalities (later modalities win on key collisions).
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for modality_items in mm_items.values():
            processor_data |= modality_items.get_processor_data()
            passthrough_data |= modality_items.get_passthrough_data()

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.

        ``mm_kwargs`` is used both to obtain the HF processor and as the
        processing options passed to it.
        """
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        hf_inputs = dict(text=prompt, **mm_data)
        return self.info.ctx.call_hf_processor(hf_processor, hf_inputs,
                                               mm_kwargs)

    def _hf_processor_applies_repl(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """
        Return whether the HF processor applies prompt replacements.

        For most HF processors, this should be :code:`True` when multi-modal
        data items are passed, but :code:`False` when multi-modal embeddings
        are passed.
        """
        embedding_types = (EmbeddingItems, DictEmbeddingItems)
        has_embeddings = any(
            isinstance(items, embedding_types) for items in mm_items.values())
        return not has_embeddings

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        In addition, return whether prompt replacements have been applied.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        processed_data.update(passthrough_data)

        # A single prompt is processed, so `input_ids` must hold one row.
        [prompt_ids] = processed_data.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(processed_data,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargs.from_hf_inputs(processed_data,
                                                    fields_config)

        is_repl_applied = self._hf_processor_applies_repl(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_kwargs, is_repl_applied

    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        No multi-modal items are passed (an empty
        :class:`MultiModalDataItems` is used) and no processor kwargs are
        given, so only the token IDs of the text are returned.
        """
        prompt_ids, _, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using the dummy
        inputs builder (``self.dummy_inputs``, a
        :class:`BaseDummyInputsBuilder`) to go along with the multi-modal
        data. The token IDs produced for the dummy text are discarded.

        Args:
            mm_items: The multi-modal items to process.
            hf_processor_mm_kwargs: Extra keyword arguments forwarded to the
                HF processor.

        Returns:
            The processed multi-modal keyword arguments.
        """
        mm_counts = mm_items.get_all_counts()

        # The model's `max_model_len` and the per-modality item counts are
        # passed to the builder to obtain a dummy prompt for these items.
        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_counts,
        )

        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_replacement: bool,
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        In addition, return whether prompt replacements have been applied
        (for most HF processors, this should be :code:`True`).

        Note:
            If :code:`enable_hf_prompt_replacement=True` and ``prompt`` is
            text, we use HF processor to perform prompt replacement if
            available; HF processor requires that the prompt corresponds to
            multi-modal items. Otherwise, the prompt and the multi-modal
            data are processed separately and the returned flag is always
            :code:`False`.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_replacement:
                # Process text and multi-modal data together in one call.
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            # Most HF processors only accept text, so token prompts are
            # handled separately from the multi-modal data.
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        # Prompt and multi-modal data were processed separately, so the HF
        # processor has not applied any prompt replacement.
        return prompt_ids, mm_kwargs, False

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Processed multi-modal items are looked up in ``self.cache`` one item
        at a time. Only the items that miss the cache are sent to the HF
        processor, and their outputs are stored back into the cache. The
        cache is bypassed entirely when it is disabled (``None``) or when
        any passthrough data (data not handed to the HF processor) is
        present.

        Returns:
            A tuple ``(prompt_ids, mm_kwargs, is_repl_applied)`` in the same
            format as :meth:`_apply_hf_processor_main`.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_replacement=True,
            )

        # For each modality: the cached processed item for every input item,
        # or `None` on a cache miss.
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we can't apply prompt replacements until the new multimodal
        # items are combined with the cached multimodal items
        (
            prompt_ids,
            mm_missing_kwargs,
            is_repl_applied,
        ) = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_replacement=False,
        )

        # Per-modality cursor into `mm_missing_kwargs`. This assumes the
        # processed items come back in the same order as `mm_missing_data`.
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    # Cache miss: take the freshly processed item and store
                    # it in the cache for later requests.
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly processed item should have been consumed.
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs, is_repl_applied

    def _bind_and_group_repls(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> dict[str, list[BoundPromptReplacement]]:
        """
        Bind each prompt replacement to the tokenizer and group the bound
        replacements by modality.

        Args:
            prompt_repls: The unbound prompt replacements.

        Returns:
            A dict mapping each modality to its bound prompt replacements.
        """
        tokenizer = self.info.get_tokenizer()

        it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
        return dict(full_groupby_modality(it))

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the prompt replacements of each modality to ``token_ids``.

        The replacement targets are first searched for directly in the
        token IDs. If every modality has at least as many token matches as
        it has items, the replacement is done on the token IDs. Otherwise,
        the token IDs are decoded, string replacement is applied to the
        text, and the result is re-encoded.

        Args:
            token_ids: The prompt token IDs.
            mm_prompt_repls: Bound prompt replacements, keyed by modality.
            mm_item_counts: The number of multi-modal items per modality.

        Returns:
            A tuple ``(token_ids, text, placeholders)``: the updated token
            IDs, the corresponding text, and the placeholder information per
            modality located in the updated token IDs.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, prompt_repls)
            for modality, prompt_repls in mm_prompt_repls.items()
        }
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in mm_token_matches.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            # Enough direct token matches for every modality: replace on the
            # token IDs.
            token_ids = replace_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
            matched_repls = {
                modality: [match.prompt_repl for match in token_matches]
                for modality, token_matches in mm_token_matches.items()
            }
        else:
            # Fall back to string replacement on the decoded text.
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, prompt_repls)
                for modality, prompt_repls in mm_prompt_repls.items()
            }
            text = replace_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            matched_repls = {
                modality: [match.prompt_repl for match in text_matches]
                for modality, text_matches in mm_text_matches.items()
            }

        placeholders = self._find_mm_placeholders(
            matched_repls,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Ensure ``mm_kwargs`` holds exactly one entry per input item.

        Raises:
            RuntimeError: If, for some modality in ``mm_item_counts``, the
                number of items in ``mm_kwargs`` differs from the expected
                count. A modality missing from ``mm_kwargs`` counts as zero.
        """
        for modality, item_count in mm_item_counts.items():
            items = (mm_kwargs.get_items(modality)
                     if modality in mm_kwargs.modalities else [])

            if len(items) == item_count:
                continue

            raise RuntimeError(
                f"Expected there to be {item_count} {modality} items in "
                f"keyword arguments corresponding to {item_count} "
                f"{modality} data items, but only found {len(items)}! "
                "There is likely a problem with your "
                "implementation of merged multi-modal processor for this "
                "model (usually arising from an inconsistency between "
                "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Ensure one placeholder was found for each multi-modal item.

        Raises:
            RuntimeError: If, for some modality in ``mm_item_counts``, the
                number of placeholders differs from the number of items.
                A modality missing from ``mm_placeholders`` counts as zero.
        """
        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt replacements "
                    f"corresponding to {item_count} {modality} items, but "
                    f"instead found {len(placeholders)} prompt replacements! "
                    "Either the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_replacements`).")

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Raises:
            RuntimeError: If the processed keyword arguments or the located
                placeholders do not match the number of multi-modal items.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.

        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        (
            prompt_ids,
            mm_kwargs,
            is_repl_applied,
        ) = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        if is_repl_applied:
            # The HF processor already applied the prompt replacements, so
            # only the placeholder positions need to be located.
            mm_placeholders = self._find_mm_placeholders(
                mm_prompt_repls,
                prompt_ids,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
        else:
            # Apply the prompt replacements here; this also returns the
            # updated prompt text and the placeholder positions.
            (
                prompt_ids,
                prompt,
                mm_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                mm_prompt_repls,
                mm_item_counts,
            )
            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """
    Multi-modal processor for encoder-decoder models.

    The full multi-modal processing (HF processor, prompt replacements) is
    applied to the encoder prompt, while the decoder prompt is only
    tokenized or detokenized so that both its text and token IDs are
    available.
    """

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the encoder. HF processor will be applied on
        this prompt during profiling and generation.
        """
        raise NotImplementedError

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """
        Create input prompt for the decoder.

        The default implementation returns ``prompt`` unchanged.
        """
        return prompt

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.
        The main processing steps are modified to fit encoder-decoder model:
        1. Create encoder prompt from input prompt text.
        2. Apply the HF processor on encoder prompt.
        3. Create the decoder prompt via :meth:`create_decoder_prompt`
           (by default, the input prompt) and derive both its text and its
           token IDs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
        )

        tokenizer = self.info.get_tokenizer()
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        if isinstance(decoder_prompt, str):
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               decoder_prompt,
                                               add_special_tokens=False)
        else:
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        # The `prompt`/`prompt_token_ids` entries copied from the encoder
        # inputs above are overwritten with the decoder prompt.
        mm_inputs.update({
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids
        })
        return mm_inputs