processing.py 40.2 KB
Newer Older
1
2
import re
from abc import ABC, abstractmethod
3
from collections import defaultdict
4
5
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
6
from dataclasses import dataclass, field
7
from functools import lru_cache
8
9
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union)
10

11
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
12

13
14
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
15
from vllm.logger import init_logger
16
17
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
18
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
19

20
from .hasher import MultiModalHasher
21
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
22
23
                     MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
                     PlaceholderRange)
24
from .parse import MultiModalDataItems, MultiModalDataParser
25
26
27

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
28

29
logger = init_logger(__name__)

# Prompts can be given either as text or as a list of token IDs.
_S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]


@dataclass
class PromptReplacementDetails:
    """
    A replacement for a prompt target, together with the part of it that
    corresponds to the placeholder feature tokens.
    """

    full: _PromptSeq
    """The full replacement."""

    features: _PromptSeq
    """
    The part of the replacement that corresponds to placeholder feature tokens.
    """

    @staticmethod
    def from_seq(seq: _PromptSeq) -> "PromptReplacementDetails":
        """Use :code:`seq` as both the full replacement and its features."""
        return PromptReplacementDetails(full=seq, features=seq)


# A replacement is either a prompt sequence, which then serves as both the
# full replacement and its feature tokens, or an explicit
# PromptReplacementDetails.
_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]


@dataclass
class PromptReplacement:
    """
    Defines how to replace portions of an input prompt with placeholder tokens.
    """

    modality: str
    """The modality for which the replacement is made."""

    target: _PromptSeq
    """The token sequence (or text) to find and replace."""

    replacement: Union[Callable[[int], _PromptRepl],
                       _PromptRepl] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the replacement token sequence (or text).

    For convenience, you can directly pass in the replacement token sequence
    (or text) instead of a function if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
        """
        Bind this replacement to :code:`tokenizer` so that its target and
        replacement can be converted between text and token IDs.
        """
        return BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )


@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Encode :code:`text` with :code:`tokenizer`, memoizing up to 2048 results.

    :code:`tokenizer` is part of the cache key, so it must be hashable.
    The same list object is returned on every cache hit, so callers must
    not mutate it.
    """
    return encode_tokens(tokenizer,
                         text,
                         add_special_tokens=add_special_tokens)


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Decode :code:`token_ids` with :code:`tokenizer`, memoizing up to 2048
    results.

    Token IDs are taken as a tuple because every argument of an
    ``lru_cache``-wrapped function must be hashable.
    """
    return decode_tokens(tokenizer,
                         list(token_ids),
                         skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    modality: str

111

112
class _HasModalityProp(Protocol):
113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    @property
    def modality(self) -> str:
        ...


# Type variable for any object that exposes a ``modality`` string, either as
# a plain attribute or as a read-only property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group :code:`values` by their ``modality`` using :func:`full_groupby`.

    Accepts any object whose ``modality`` is a plain attribute or a property.
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
129
130
    tokenizer: AnyTokenizer = field(repr=False)

131
132
133
    _text: Optional[str]
    _token_ids: Optional[list[int]]

134
    @staticmethod
135
136
137
138
    def from_seq(
        tokenizer: AnyTokenizer,
        seq: _PromptSeq,
    ) -> "_BoundPromptSequence":
139
140
141
142
143
144
        return _BoundPromptSequence(
            tokenizer=tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        if self._text is None:
            assert self._token_ids is not None
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer, self._text)

        return self._token_ids


167
168
169
170
171
172
@dataclass
class _BoundPromptReplacementGroup:
    full: _BoundPromptSequence
    features: _BoundPromptSequence


173
@dataclass
class BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer to automatically
    convert :attr:`target` and the result of :meth:`get_replacement` between
    token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptRepl],
                        _PromptRepl] = field(repr=False)

    def __post_init__(self) -> None:
        # Bound replacements keyed by item index. A constant (non-callable)
        # replacement does not depend on the index, so it is stored once
        # under the key ``None``.
        self._replacement_cache = dict[Optional[int],
                                       _BoundPromptReplacementGroup]()

    @property
    def target(self) -> _BoundPromptSequence:
        """The token sequence (or text) to find and replace."""
        return _BoundPromptSequence.from_seq(self.tokenizer, self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
        """
        Given the index of the processed item within :attr:`modality`,
        output the replacement token sequence (or text).

        Results are cached, so a callable replacement is invoked at most
        once per :code:`item_idx`.
        """
        replacement = self._replacement
        cache_key = item_idx if callable(replacement) else None

        cached = self._replacement_cache.get(cache_key)
        if cached is not None:
            return cached

        if callable(replacement):
            replacement = replacement(item_idx)

        # A plain sequence is used as both the full replacement and features.
        if not isinstance(replacement, PromptReplacementDetails):
            replacement = PromptReplacementDetails.from_seq(replacement)

        bound_replacement = _BoundPromptReplacementGroup(
            full=_BoundPromptSequence.from_seq(self.tokenizer,
                                               replacement.full),
            features=_BoundPromptSequence.from_seq(self.tokenizer,
                                                   replacement.features),
        )
        self._replacement_cache[cache_key] = bound_replacement

        return bound_replacement


class _TokenMatch(NamedTuple):
    """Half-open span ``[start_idx, end_idx)`` of a match in a token list."""

    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Generator[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are found left to right and never overlap: after a match, the
    search resumes right after its end. Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    last_start_idx = prompt_len - match_len
    while start_idx <= last_start_idx:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1


@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
263
    prompt_repl: BoundPromptReplacement
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284

    @property
    def modality(self) -> str:
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A :class:`_PromptReplacementMatch` located in a list of token IDs."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A :class:`_PromptReplacementMatch` located in a text prompt."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()


@dataclass
class PlaceholderFeaturesInfo:
    """
    Location of the placeholder feature tokens of one multi-modal item
    (the :attr:`item_idx`-th item of :attr:`modality`) within a prompt.
    """

    modality: str
    item_idx: int
    start_idx: int
    tokens: list[int]

    @property
    def length(self) -> int:
        """Number of feature tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` (offset and length)."""
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches of the same target never overlap and empty targets are ignored
    (see :func:`iter_token_matches`); matches of different targets may
    overlap each other.
    """
    return [
        _PromptReplacementTokenMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
    ]


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Like :func:`find_token_matches`, empty targets are ignored: an empty
    pattern would otherwise produce a zero-width match at every position
    of the prompt.
    """
    matches = list[_PromptReplacementTextMatch]()
    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            continue

        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(re.escape(target_text), prompt))

    return matches


def _resolve_matches(
353
    prompt: _PromptSeq,
354
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
355
) -> list[_PromptReplacementMatch]:
356
    """
357
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
358
    and sort them such that earlier matches take priority over later ones.
359
    """
360
361
    matches = [m for matches in mm_matches.values() for m in matches]

362
363
    seen_matches: list[Optional[_PromptReplacementMatch]] = [None
                                                             ] * len(prompt)
364

365
    for match in matches:
366
367
368
369
370
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")
371

372
            seen_matches[idx] = match
373
374
375
376
377
378

    return sorted(matches, key=lambda x: x.start_idx)


def _replace_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    Matches are applied in prompt order. The ``i``-th applied match of a
    modality uses the replacement for item index ``i``; matches beyond
    ``mm_item_counts[modality]`` are left unreplaced.

    Returns:
        The pieces of the new prompt, to be concatenated by the caller.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](int)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality

        item_idx = next_idx_by_modality[modality]
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        replacement = match.prompt_repl.get_replacement(item_idx)
        repl_seq = (replacement.full.text
                    if isinstance(prompt, str) else replacement.full.token_ids)

        out_seqs.append(prompt[prev_end_idx:match.start_idx] + repl_seq)

        prev_end_idx = match.end_idx
        next_idx_by_modality[modality] += 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, :code:`prompt` itself (not a copy) is
    returned.
    """
    if not mm_matches:
        return prompt

    token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)


def replace_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, :code:`prompt` is returned unchanged.
    """
    if not mm_matches:
        return prompt

    texts = _replace_matches(prompt, mm_matches, mm_item_counts)

    return "".join(texts)


def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    The prompt is scanned left to right. At each position, the full
    replacement for the next unprocessed item of each modality is compared
    against the prompt; on a match, the location of its feature tokens is
    yielded and scanning resumes after the full replacement.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_repls` takes priority.

    Note that empty matches are ignored.

    Raises:
        AssertionError: If the feature tokens of a matched replacement are
            not a subsequence of its full tokens.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](int)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_repls in mm_prompt_repls.items():
            item_idx = item_idx_by_modality[modality]
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for repl_info in modality_repls:
                replacement = repl_info.get_replacement(item_idx)
                repl_tokens_full = replacement.full.token_ids
                repl_len_full = len(repl_tokens_full)
                end_idx_full = start_idx + repl_len_full

                if repl_len_full == 0 or end_idx_full > prompt_len:
                    continue

                if prompt[start_idx:end_idx_full] != repl_tokens_full:
                    continue

                repl_tokens_feat = replacement.features.token_ids
                # Keep the lookup outside of any try block so exceptions
                # thrown into this generator at `yield` are not masked.
                feat_match = next(
                    iter_token_matches(repl_tokens_full, repl_tokens_feat),
                    None,
                )
                if feat_match is None:
                    raise AssertionError(
                        f"{repl_tokens_feat=} should be a "
                        f"subsequence of {repl_tokens_full=}")

                yield PlaceholderFeaturesInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=start_idx + feat_match.start_idx,
                    tokens=repl_tokens_feat,
                )

                # Exclude overlapping matches
                start_idx = end_idx_full
                item_idx_by_modality[modality] += 1
                found = True
                break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1


def find_mm_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the placeholder feature tokens of each multi-modal item in
    :code:`prompt`, grouped by modality.
    """
    it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
    return dict(full_groupby_modality(it))


class ProcessingCache:
    """
    Cache of processed multi-modal items, backed by an :class:`LRUCache`
    with the given capacity.

    Entries are keyed by a hash of everything the processed output depends
    on (see :meth:`get`).
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    def _maybe_log_cache_stats(self) -> None:
        """
        Log the cache hit ratio at debug level whenever ``stat().total`` is a
        multiple of :attr:`debug_cache_hit_ratio_steps` (disabled when that
        is ``None`` or 0).
        """
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        cache_stats = self._cache.stat()
        if cache_stats.total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         cache_stats.hit_ratio)

    @staticmethod
    def _get_cache_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """
        Hash the dependencies of a processed item into a cache key.

        Shared by :meth:`get` and :meth:`put` so that both always agree.
        """
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = self._get_cache_key(model_id, modality, input_item,
                                        input_kwargs)
        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._get_cache_key(model_id, modality, input_item,
                                        input_kwargs)
        self._cache.put(cache_key, output_kwargs)


class BaseProcessingInfo:
    """
    Base class to provide the information necessary for data processing.

    NOTE: This class does not derive from :class:`abc.ABC`, so
    ``@abstractmethod`` is not enforced at instantiation; the abstract
    methods only raise :class:`NotImplementedError` when called.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` field of the model config, used as the model ID."""
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer of the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# The processing-info type used by a given multi-modal processor.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseMultiModalProcessor(ABC, Generic[_I]):
632
    """
633
    Abstract base class to process multi-modal inputs to be used in vLLM.
634
635

    Not to be confused with :class:`transformers.ProcessorMixin`.
636
637
    """

638
    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Args:
            info: Model-specific processing information.
            dummy_inputs: Builder of dummy processor inputs, used by
                :meth:`_apply_hf_processor_mm_only`.
            cache: Cache of processed multi-modal items; ``None`` disables
                caching.
            enable_sanity_checks: Stored as :attr:`enable_sanity_checks`;
                presumably toggles extra consistency checks during
                processing (TODO confirm against the rest of the class).
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        self.data_parser = self._get_data_parser()

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Shorthand for :meth:`apply`."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.

        Called once from :meth:`__init__`; the result is stored as
        :attr:`data_parser`.
        """
        return MultiModalDataParser()

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than its limit in
                ``--limit-mm-per-prompt`` (1 when not set).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        The result is passed to :meth:`MultiModalKwargs.from_hf_inputs`
        together with the processed data.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: in :meth:`_apply_hf_processor_main` with
              ``enable_hf_prompt_replacement=False``, the HF processor is
              called on the text-only and multimodal-only inputs
              separately, instead of passing them in the same call.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError

    def _find_mm_placeholders(
        self,
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholder feature tokens of each multi-modal item in
        :code:`new_token_ids`, grouped by modality.

        Delegates to :func:`find_mm_placeholders`.
        """
        return find_mm_placeholders(mm_prompt_repls, new_token_ids,
                                    mm_item_counts)

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split :code:`mm_items` into the data to pass to the HF processor and
        the data to pass through without processing.

        Returns:
            A tuple ``(processor_data, passthrough_data)``, each merged
            across all modalities.
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.

        :code:`mm_kwargs` is used both to obtain the HF processor and as the
        kwargs passed along to ``ctx.call_hf_processor``.
        """
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            The token IDs of the processed prompt, and the processed
            multi-modal data (including any passthrough data that was not
            sent to the HF processor).
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        processed_data.update(passthrough_data)

        # The single-element unpacking requires exactly one row of input IDs.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        return prompt_ids, mm_kwargs

    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The HF processor is called with no multi-modal items and no extra
        processor kwargs; only the resulting token IDs are returned.
        """
        prompt_ids, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.

        The default implementation returns :code:`prompt_tokens` as-is (the
        same list object, not a copy).
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, we generate dummy text using
        :attr:`dummy_inputs` (a :class:`BaseDummyInputsBuilder`) to go along
        with the multi-modal data. The token IDs of the dummy text are
        discarded.
        """
        mm_counts = mm_items.get_all_counts()

        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_counts,
        )

        _, mm_kwargs = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_replacement: bool,
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        A text prompt with :code:`enable_hf_prompt_replacement=True` is
        processed together with :code:`mm_items` in a single HF processor
        call. Otherwise (a token prompt, or replacement disabled), the prompt
        and the multi-modal items are processed separately.

        Note:
            If :code:`enable_hf_prompt_replacement=True`, the prompt should
            correspond to the multi-modal items. Pass ``False`` when it does
            not, e.g. when only a subset of the items needs processing.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_replacement:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_missing_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_missing_kwargs

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Each multi-modal item is looked up in :attr:`cache` using the model
        ID, its modality, the item itself and ``hf_processor_mm_kwargs``.
        Only the items that are not found are sent to the HF processor, and
        their outputs are then stored in the cache.

        The cache is bypassed (and the prompt is processed together with all
        of the items) when there is no cache, or when :meth:`_get_hf_mm_data`
        returns passthrough data.

        Returns:
            A tuple of the prompt token IDs and the multi-modal keyword
            arguments for all items, in the same order as ``mm_data_items``.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_replacement=True,
            )

        # A `None` entry means that the item is not in the cache yet
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we need to pass `enable_hf_prompt_replacement=False`
        prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_replacement=False,
        )

        # The outputs for the missing items are consumed in order, so keep a
        # per-modality cursor into `mm_missing_kwargs`
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every output produced for the missing items must have been used
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs

    def _bind_and_group_repls(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> dict[str, list[BoundPromptReplacement]]:
        """
        Bind each prompt replacement to the tokenizer and group the bound
        replacements by modality.
        """
        tokenizer = self.info.get_tokenizer()

        bound_repls = (repl.bind(tokenizer) for repl in prompt_repls)
        return dict(full_groupby_modality(bound_repls))

    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called even if we
984
985
        detect that HF has performed processing via
        :meth:`_find_placeholders_by_modality`.
986

987
988
        This is useful in cases where :meth:`_find_placeholders_by_modality`
        cannot be reliably used to detect whether HF has performed processing.
989
990
991
        """
        return False

992
993
994
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the bound prompt replacements to the prompt, then locate the
        resulting placeholders.

        Args:
            token_ids: The prompt token IDs to update.
            mm_prompt_repls: The prompt replacements to apply, per modality.
            mm_item_counts: The number of items per modality. This is passed
                on to the replacement and placeholder-search helpers.

        Returns:
            A tuple of the updated token IDs, the decoded text of those token
            IDs, and the placeholders found in the updated token IDs, per
            modality.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, prompt_repls)
            for modality, prompt_repls in mm_prompt_repls.items()
        }
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in mm_token_matches.items()
        }

        # When every modality has at least as many token-level matches as it
        # has items, replace directly on the token IDs; this avoids the
        # decode/encode round trip of the fallback below.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )

            text = decode_tokens(tokenizer, token_ids)
            matched_repls = {
                modality: [match.prompt_repl for match in token_matches]
                for modality, token_matches in mm_token_matches.items()
            }
        else:
            # If the search text does not represent a special token,
            # it may have different token IDs in the prompt, because
            # the tokens may go across the boundaries of the search text.
            # ----
            # e.g. when searching for "foo" in "food", if "food" itself makes
            # up a token, then the token ID of "foo" will not appear at all
            # ----
            # Since it is inefficient to search for all possible tokenizations
            # of the search text in the prompt, we instead perform string
            # replacement on the decoded token IDs, then encode them back.
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, prompt_repls)
                for modality, prompt_repls in mm_prompt_repls.items()
            }
            text = replace_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )

            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)
            matched_repls = {
                modality: [match.prompt_repl for match in text_matches]
                for modality, text_matches in mm_text_matches.items()
            }

        placeholders = self._find_mm_placeholders(
            matched_repls,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Ensure that ``mm_kwargs`` contains exactly ``mm_item_counts[m]``
        keyword-argument items for every modality ``m``.

        A modality that is absent from ``mm_kwargs.modalities`` counts as
        having zero items.

        Raises:
            RuntimeError: If the number of keyword-argument items for a
                modality differs from its number of data items.
        """
        for modality, expected_count in mm_item_counts.items():
            found_count = (len(mm_kwargs.get_items(modality))
                           if modality in mm_kwargs.modalities else 0)

            if found_count == expected_count:
                continue

            raise RuntimeError(
                f"Expected there to be {expected_count} {modality} items in "
                f"keyword arguments corresponding to {expected_count} "
                f"{modality} data items, but only found {found_count}! "
                "There is likely a problem with your "
                "implementation of merged multi-modal processor for this "
                "model (usually arising from an inconsistency between "
                "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
        *,
        allow_missing: bool = False,
    ) -> Mapping[str, int]:
        """
        Check that each modality has one placeholder per multi-modal item.

        Args:
            mm_placeholders: The placeholders found in the prompt, per
                modality. A modality that is absent counts as having none.
            mm_item_counts: The number of items per modality.
            allow_missing: If :code:`True`, a mismatch is not an error and is
                only reported through the return value.

        Returns:
            For each modality in ``mm_item_counts``, the number of items
            minus the number of placeholders found. This is negative if more
            placeholders than items were found.

        Raises:
            RuntimeError: If ``allow_missing`` is :code:`False` and the
                number of placeholders differs from the number of items for
                some modality.
        """
        missing_repl_counts = dict[str, int]()

        for modality, item_count in mm_item_counts.items():
            placeholders = mm_placeholders.get(modality, [])

            if len(placeholders) != item_count and not allow_missing:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt replacements "
                    f"corresponding to {item_count} {modality} items, but only "
                    f"found {len(placeholders)} prompt replacements! Either "
                    "the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_replacements`).")

            missing_repl_counts[modality] = item_count - len(placeholders)

        return missing_repl_counts

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Step 2 is skipped for a modality whose placeholders are already found
        in the HF processor output, unless
        :meth:`_always_apply_prompt_replacements` returns :code:`True`.

        Raises:
            ValueError: If placeholders were found for only some of the items
                of a modality (partial prompt replacement).
            RuntimeError: If the processed keyword arguments or the final
                placeholders do not match the number of multi-modal items.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        # Placeholders that are already present in the HF processor output
        hf_mm_placeholders = self._find_mm_placeholders(
            mm_prompt_repls,
            prompt_ids,
            mm_item_counts,
        )

        if self._always_apply_prompt_replacements():
            mm_missing_repl_counts = mm_item_counts
            mm_missing_repls = dict(mm_prompt_repls)
        else:
            mm_missing_repl_counts = self._validate_mm_placeholders(
                hf_mm_placeholders,
                mm_item_counts,
                allow_missing=True,
            )

            # For each modality, either all of its items already have
            # placeholders (nothing to replace) or none of them do (replace
            # every item); anything in between is rejected.
            mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
            for modality, missing_repl_count in mm_missing_repl_counts.items():
                if missing_repl_count == 0:
                    mm_missing_repls[modality] = []
                elif missing_repl_count == mm_item_counts.get(modality, 0):
                    mm_missing_repls[modality] = mm_prompt_repls[modality]
                else:
                    raise ValueError("Partial prompt replacement within "
                                     f"{modality=} is not supported")

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        if all(len(repls) == 0 for repls in mm_missing_repls.values()):
            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
            mm_placeholders = hf_mm_placeholders
        else:
            (
                prompt_ids,
                prompt,
                missing_mm_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                mm_missing_repls,
                mm_missing_repl_counts,
            )

            # NOTE(review): `hf_mm_placeholders` were located in the token IDs
            # *before* `_apply_prompt_replacements` rewrote them. If that call
            # changes tokens ahead of an HF-inserted placeholder, the merged
            # offsets for those modalities may be stale — confirm.
            mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )