processing.py 40.1 KB
Newer Older
1
2
import re
from abc import ABC, abstractmethod
3
from collections import defaultdict
4
5
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
                             Sequence)
6
from dataclasses import dataclass, field
7
from functools import lru_cache
8
9
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union)
10

11
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
12

13
14
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
15
from vllm.logger import init_logger
16
17
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
18
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
19

20
from .hasher import MultiModalHasher
21
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
22
23
                     MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
                     PlaceholderRange)
24
from .parse import MultiModalDataItems, MultiModalDataParser
25
26
27

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
28

29
# Module-level logger for this file, created via vLLM's `init_logger`.
logger = init_logger(__name__)
30
31

_S = TypeVar("_S", str, list[int])
32
_PromptSeq = Union[str, list[int]]
33

34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
@dataclass
class PromptReplacementDetails:
    """
    Describes a replacement sequence, separating the complete replacement
    from the portion of it that holds the placeholder feature tokens.
    """

    full: _PromptSeq
    """The full replacement."""

    features: _PromptSeq
    """
    The part of the replacement that corresponds to placeholder feature tokens.
    """

    @staticmethod
    def from_seq(seq: _PromptSeq):
        """Build details in which the whole sequence is feature tokens."""
        return PromptReplacementDetails(features=seq, full=seq)


# A replacement is either a bare sequence or a `PromptReplacementDetails`.
_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]


53
@dataclass
class PromptReplacement:
    """
    Defines how to replace portions of an input prompt with placeholder tokens.
    """

    modality: str
    """The modality for which the replacement is made."""

    target: _PromptSeq
    """The token sequence (or text) to find and replace."""

    replacement: Union[Callable[[int], _PromptRepl],
                       _PromptRepl] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`,
    output the replacement token sequence (or text).

    For convenience, you can directly pass in the replacement token sequence
    (or text) instead of a function if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
        """
        Attach ``tokenizer`` so the target and replacement can be converted
        between text and token IDs on demand.
        """
        return BoundPromptReplacement(
            _replacement=self.replacement,
            _target=self.target,
            modality=self.modality,
            tokenizer=tokenizer,
        )
82
83


84
85
86
87
88
89
90
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    ``tokenizer`` is part of the cache key, so it must be hashable. Because
    ``lru_cache`` hands back the stored object, repeated calls with equal
    arguments return the *same* list; callers must not mutate it.
    """
    token_ids = encode_tokens(tokenizer, text,
                              add_special_tokens=add_special_tokens)
    return token_ids
94
95


96
97
98
99
100
101
102
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    ``token_ids`` is taken as a tuple because ``lru_cache`` requires hashable
    arguments; it is converted back to a list before decoding.
    """
    id_list = list(token_ids)
    text = decode_tokens(tokenizer, id_list,
                         skip_special_tokens=skip_special_tokens)
    return text
106
107
108
109
110


class _HasModalityAttr(Protocol):
    modality: str

111

112
class _HasModalityProp(Protocol):
113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    @property
    def modality(self) -> str:
        ...


# Anything carrying a `modality` string, as an attribute or as a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Convenience function to apply :func:`full_groupby` based on modality."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence tied to a tokenizer. It holds the text form, the
    token-ID form, or both, and fills in the missing form lazily on first
    access. At least one form must be provided.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    @staticmethod
    def from_seq(tokenizer: AnyTokenizer, seq: _PromptSeq):
        """Wrap ``seq`` (text or token IDs) without converting it yet."""
        text = seq if isinstance(seq, str) else None
        token_ids = seq if isinstance(seq, list) else None
        return _BoundPromptSequence(tokenizer=tokenizer,
                                    _text=text,
                                    _token_ids=token_ids)

    def __post_init__(self) -> None:
        if self._token_ids is None and self._text is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token-ID form, encoded from the text on first access."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer, self._text)
        return self._token_ids


164
165
166
167
168
169
@dataclass
class _BoundPromptReplacementGroup:
    """Bound sequences produced by :meth:`BoundPromptReplacement.get_replacement`."""

    # The complete replacement sequence.
    full: _BoundPromptSequence
    # The part of `full` that holds the placeholder feature tokens.
    features: _BoundPromptSequence


170
@dataclass
class BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer to automatically
    convert :attr:`target` and the result of :meth:`get_replacement` between
    token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptRepl],
                        _PromptRepl] = field(repr=False)

    def __post_init__(self) -> None:
        # Bound results of a *callable* `_replacement`, keyed by item index.
        self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()

    @property
    def target(self) -> _BoundPromptSequence:
        """The token sequence (or text) to find and replace."""
        return _BoundPromptSequence.from_seq(self.tokenizer, self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
        """
        Given the index of the processed item within :attr:`modality`,
        output the replacement token sequence (or text).
        """
        source = self._replacement
        is_dynamic = callable(source)

        if is_dynamic:
            # Only per-item (callable) replacements are memoized.
            if item_idx in self._replacement_cache:
                return self._replacement_cache[item_idx]
            resolved = source(item_idx)
        else:
            resolved = source

        if isinstance(resolved, PromptReplacementDetails):
            details = resolved
        else:
            details = PromptReplacementDetails.from_seq(resolved)

        group = _BoundPromptReplacementGroup(
            full=_BoundPromptSequence.from_seq(self.tokenizer, details.full),
            features=_BoundPromptSequence.from_seq(self.tokenizer,
                                                   details.features),
        )

        if is_dynamic:
            self._replacement_cache[item_idx] = group

        return group


225
226
227
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
228
229


230
231
232
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
233
) -> Generator[_TokenMatch]:
234
235
236
237
238
239
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
240
    match_len = len(match_ids)
241

242
243
    if match_len == 0:
        return
244

245
246
    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
247
        end_idx = start_idx + match_len
248

249
250
        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
251
252
253
254
255

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
256
257


258
259
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    A location in the prompt where the target of :attr:`prompt_repl` was
    found. Subclasses decide how the span indices are obtained.
    """

    prompt_repl: BoundPromptReplacement

    @property
    def modality(self) -> str:
        """Modality of the replacement that produced this match."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Inclusive start index of the match in the prompt."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Exclusive end index of the match in the prompt."""
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match located by comparing token IDs (see :func:`iter_token_matches`)."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        start, _ = self.match
        return start

    @property
    def end_idx(self) -> int:
        _, end = self.match
        return end


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match located by searching the prompt text with a regex."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.span()[0]

    @property
    def end_idx(self) -> int:
        return self.match.span()[1]

306

307
@dataclass
class PlaceholderFeaturesInfo:
    """
    Position of the placeholder feature tokens for one multi-modal item
    within a token prompt.
    """

    modality: str
    item_idx: int
    start_idx: int
    tokens: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder feature tokens."""
        return len(self.tokens)

    def to_range(self) -> PlaceholderRange:
        """Express this placeholder as an offset/length range."""
        return PlaceholderRange(offset=self.start_idx,
                                length=len(self.tokens))
323
324
325
326


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches are grouped by replacement, in the order of
    :code:`prompt_repls`; empty targets produce no matches.
    """
    found = list[_PromptReplacementTokenMatch]()
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        for span in iter_token_matches(prompt, target_ids):
            found.append(_PromptReplacementTokenMatch(prompt_repl, span))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Note that empty matches are ignored, consistent with
    :func:`find_token_matches`.
    """
    matches = list[_PromptReplacementTextMatch]()
    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text

        # An empty pattern yields an empty match at every position of the
        # prompt (see `re.finditer`). Each of those would otherwise be treated
        # as a replacement site, so skip empty targets entirely.
        if not target_text:
            continue

        pattern = re.escape(target_text)
        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return matches


def _resolve_matches(
    prompt: _PromptSeq,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If any two matches cover the same prompt position.
    """
    all_matches: list[_PromptReplacementMatch] = []
    for modality_matches in mm_matches.values():
        all_matches.extend(modality_matches)

    # For each prompt position, the match that has already claimed it.
    owner: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in all_matches:
        for idx in range(match.start_idx, match.end_idx):
            previous = owner[idx]
            if previous is not None:
                raise ValueError("Found overlapping matches "
                                 f"({previous} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            owner[idx] = match

    all_matches.sort(key=lambda m: m.start_idx)
    return all_matches


def _replace_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    Returns the pieces of the new prompt, in order, with the same type as
    :code:`prompt`. Matches beyond the number of items given for their
    modality in :code:`mm_item_counts` are left in place.
    """
    use_text = isinstance(prompt, str)
    pieces = list[_S]()
    cursor = 0
    consumed = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality
        item_idx = consumed[modality]

        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        group = match.prompt_repl.get_replacement(item_idx)
        repl_seq = group.full.text if use_text else group.full.token_ids

        pieces.append(prompt[cursor:match.start_idx] + repl_seq)

        cursor = match.end_idx
        consumed[modality] += 1

    pieces.append(prompt[cursor:])

    return pieces


def replace_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    The input list itself is returned when there is nothing to replace.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt
424
425


426
427
def replace_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    The input string is returned unchanged when there is nothing to replace.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return "".join(pieces)

    return prompt
438
439


440
441
def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Matches are exclusive even when multiple modalities share
    the same placeholder tokens. In that case, the modality that
    appears earlier in `mm_prompt_repls` takes priority.

    Note that empty matches are ignored.

    Raises:
        AssertionError: If the feature tokens of a matched replacement are
            not a contiguous subsequence of its full tokens.
    """
    prompt_len = len(prompt)
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for modality, modality_repls in mm_prompt_repls.items():
            item_idx = item_idx_by_modality[modality]
            # Every item of this modality has already been located.
            if item_idx >= mm_item_counts.get(modality, 0):
                continue

            for repl_info in modality_repls:
                replacement = repl_info.get_replacement(item_idx)

                repl_tokens_full = replacement.full.token_ids
                repl_len_full = len(repl_tokens_full)
                end_idx_full = start_idx + repl_len_full

                if repl_len_full == 0 or end_idx_full > prompt_len:
                    continue
                if prompt[start_idx:end_idx_full] != repl_tokens_full:
                    continue

                repl_tokens_feat = replacement.features.token_ids

                # Locate the feature tokens inside the full replacement.
                # Only this lookup may fail. Keeping the `yield` below out of
                # any StopIteration handler ensures an exception thrown into
                # the suspended generator is not misreported as this error.
                feat_match = next(
                    iter_token_matches(repl_tokens_full, repl_tokens_feat),
                    None,
                )
                if feat_match is None:
                    raise AssertionError(
                        f"{repl_tokens_feat=} should be a "
                        f"subsequence of {repl_tokens_full=}")

                yield PlaceholderFeaturesInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=start_idx + feat_match.start_idx,
                    tokens=repl_tokens_feat,
                )

                # Exclude overlapping matches
                start_idx = end_idx_full
                item_idx_by_modality[modality] += 1
                found = True
                break

            if found:
                break  # Go back to the outer while loop

        if not found:
            start_idx += 1
504
505


506
def find_mm_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Locate the placeholder feature tokens of each multi-modal item in
    :code:`prompt` and group the results by modality.
    """
    placeholders = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


515
516
517
518
519
520
521
522
class ProcessingCache:
    """
    LRU cache of processed multi-modal items. Entries are keyed by a hash of
    the model ID, the item's modality and data, and the HF processor kwargs.
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    @staticmethod
    def _hash_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        # Shared by `get` and `put` so both always derive identical keys.
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        """Log the hit ratio every `debug_cache_hit_ratio_steps` lookups."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        stats = self._cache.stat()
        if stats.total % steps != 0:
            return

        logger.debug("ProcessingCache: hit_ratio = %.2f", stats.hit_ratio)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        key = self._hash_key(model_id, modality, input_item, input_kwargs)
        return self._cache.get(key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        key = self._hash_key(model_id, modality, input_item, input_kwargs)
        self._cache.put(key, output_kwargs)
574
575


576
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    # NOTE(review): `@abstractmethod` is only enforced by `ABCMeta`. This class
    # does not derive from `ABC`, so a subclass that forgets an override is
    # not rejected at instantiation; the base method raises
    # NotImplementedError only when it is called.

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` field of the context's model config."""
        return self.ctx.model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config obtained from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# Type of the processing-info object that `BaseMultiModalProcessor` is
# parameterized with (see its `info` constructor argument).
_I = TypeVar("_I", bound=BaseProcessingInfo)
626

627
628

class BaseMultiModalProcessor(ABC, Generic[_I]):
629
    """
630
    Abstract base class to process multi-modal inputs to be used in vLLM.
631
632

    Not to be confused with :class:`transformers.ProcessorMixin`.
633
634
    """

635
    def __init__(
        self,
        info: _I,
        dummy_inputs: "BaseDummyInputsBuilder[_I]",
        *,
        cache: Optional[ProcessingCache] = None,
        enable_sanity_checks: bool = True,
    ) -> None:
        """
        Args:
            info: Information necessary for data processing.
            dummy_inputs: Builder of dummy inputs; it supplies the text that
                accompanies multi-modal data processed on its own.
            cache: Optional cache of processed multi-modal items. When
                ``None``, processed items are not cached.
            enable_sanity_checks: Stored as ``self.enable_sanity_checks``.
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        # Built once at construction; customize via `_get_data_parser`.
        self.data_parser = self._get_data_parser()

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Shorthand that forwards all arguments to :meth:`apply`."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than its limit in
                ``--limit-mm-per-prompt`` (1 when not configured).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)
        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt

        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            num_items = len(items)
            if num_items <= limit:
                continue

            raise ValueError(
                f"You set {modality}={limit} (or defaulted to 1) in "
                f"`--limit-mm-per-prompt`, but passed {num_items} "
                f"{modality} items in the same prompt.")

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        The result is passed to :meth:`MultiModalKwargs.from_hf_inputs`
        together with the processed data.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: in :meth:`_apply_hf_processor_main` with
              ``enable_hf_prompt_replacement=False`` (and for token prompts),
              the HF processor is called on text-only and multimodal-only
              inputs separately, instead of passing them in the same call.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError

    def _find_mm_placeholders(
        self,
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Locate the placeholder tokens of each multi-modal item in
        ``new_token_ids``; delegates to :func:`find_mm_placeholders`.
        """
        return find_mm_placeholders(
            mm_prompt_repls,
            new_token_ids,
            mm_item_counts,
        )

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the multi-modal items into data for the HF processor and
        data that is merged into its outputs without going through it.

        Returns:
            A ``(processor_data, passthrough_data)`` tuple.
        """
        to_processor = dict[str, object]()
        to_passthrough = dict[str, object]()

        for modality_items in mm_items.values():
            proc_part, pass_part = (modality_items.get_processor_data(),
                                    modality_items.get_passthrough_data())
            to_processor.update(proc_part)
            to_passthrough.update(pass_part)

        return to_processor, to_passthrough

    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        hf_inputs = dict(text=prompt, **mm_data)

        return self.info.ctx.call_hf_processor(hf_processor, hf_inputs,
                                               mm_kwargs)

    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            The token IDs output by the HF processor, and the multi-modal
            kwargs built from its remaining outputs plus any passthrough data.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Passthrough data skipped the HF processor; merge it in afterwards.
        processed_data.update(passthrough_data)

        # A single prompt was processed, so exactly one row is unpacked.
        (prompt_ids, ) = processed_data.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(processed_data,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargs.from_hf_inputs(processed_data,
                                                    fields_config)

        return prompt_ids, mm_kwargs

    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The HF processor is called with an empty set of multi-modal items
        and empty ``hf_processor_mm_kwargs``; only the resulting token IDs
        are returned. Any multi-modal outputs are discarded.
        """
        prompt_ids, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors accept prompt text but not prompt tokens.
        If the HF processor adds or removes tokens that are not related to
        multi-modal data, you should override this method so it is consistent
        with the output of :meth:`_apply_hf_processor_text_only` on the
        corresponding text.

        The default implementation returns ``prompt_tokens`` as-is.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, dummy text is generated with
        :attr:`dummy_inputs` to go along with the multi-modal data.
        The token IDs produced for that dummy text are discarded.
        """
        mm_counts = mm_items.get_all_counts()
        max_model_len = self.info.ctx.model_config.max_model_len

        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            max_model_len,
            mm_counts,
        )

        _, mm_kwargs = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )
        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_replacement: bool,
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        Note:
            If :code:`enable_hf_prompt_replacement=False`, the prompt should
            correspond to the multi-modal items.
        """
        is_text = isinstance(prompt, str)

        if is_text and enable_hf_prompt_replacement:
            # Text and items go through the HF processor in a single call.
            return self._apply_hf_processor_text_mm(
                prompt_text=prompt,
                mm_items=mm_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            )

        # Otherwise, process the prompt and the multi-modal data separately.
        if is_text:
            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_missing_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_missing_kwargs

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt, serving multi-modal items
        from ``self.cache`` whenever possible.

        Each item is looked up in the cache individually. Only the items that
        miss are run through the HF processor (separately from the prompt),
        and their outputs are written back into the cache. The cache is
        skipped altogether when it is ``None`` or when
        :meth:`_get_hf_mm_data` reports passthrough data.

        Returns:
            The prompt token IDs, and the keyword arguments of all items
            merged in modality order and then in item order.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_replacement=True,
            )

        # Per-modality cache lookups; a `None` entry marks a cache miss.
        cached_kw_items: dict[str, list[Optional[MultiModalKwargsItem]]] = {}
        for modality, items in mm_data_items.items():
            cached_kw_items[modality] = [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]

        # Raw data of every item that missed the cache, grouped by modality.
        missing_data = {
            modality: [
                mm_data_items[modality][idx]
                for idx, kw_item in enumerate(kw_items) if kw_item is None
            ]
            for modality, kw_items in cached_kw_items.items()
        }
        missing_items = self._to_mm_items(missing_data)

        # NOTE: `prompt` does not correspond to `missing_items`, so the
        # prompt and the data have to be processed separately.
        prompt_ids, missing_kwargs = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=missing_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_replacement=False,
        )

        # Index of the next unconsumed freshly-processed item, per modality.
        next_missing_idx = dict.fromkeys(missing_items, 0)

        merged_kw_items: list[MultiModalKwargsItem] = []
        for modality, kw_items in cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    # Cache miss: take the next processed item of this
                    # modality and remember it for later requests.
                    kw_item = missing_kwargs.get_item(
                        modality,
                        next_missing_idx[modality],
                    )
                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )
                    next_missing_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly-processed item must have been consumed exactly.
            missing_counts = missing_items.get_all_counts()
            assert all(
                next_missing_idx[modality] == missing_counts[modality]
                for modality in next_missing_idx
            ), dict(mm_missing_next_idx=next_missing_idx,
                    mm_missing_counts=missing_counts)

        return prompt_ids, MultiModalKwargs.from_items(merged_kw_items)

    def _bind_and_group_repls(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> dict[str, list[BoundPromptReplacement]]:
        """
        Bind every prompt replacement to this processor's tokenizer and
        group the bound replacements by modality.
        """
        tokenizer = self.info.get_tokenizer()
        bound_repls = [repl.bind(tokenizer) for repl in prompt_repls]

        return dict(full_groupby_modality(bound_repls))

    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called even if we
981
982
        detect that HF has performed processing via
        :meth:`_find_placeholders_by_modality`.
983

984
985
        This is useful in cases where :meth:`_find_placeholders_by_modality`
        cannot be reliably used to detect whether HF has performed processing.
986
987
988
        """
        return False

989
990
991
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Apply the bound prompt replacements to ``token_ids``.

        The replacement is done directly on the token IDs when every modality
        has at least as many token-level matches as it has items. Otherwise
        the token IDs are decoded, the replacement is done on the text, and
        the result is re-encoded without special tokens.

        Returns:
            The updated token IDs, the corresponding text, and the
            placeholders located in the updated token IDs, by modality.
        """
        tokenizer = self.info.get_tokenizer()

        mm_token_matches = {
            modality: find_token_matches(token_ids, repls)
            for modality, repls in mm_prompt_repls.items()
        }
        enough_token_matches = all(
            len(mm_token_matches.get(modality, ())) >= item_count
            for modality, item_count in mm_item_counts.items())

        if enough_token_matches:
            token_ids = replace_token_matches(
                token_ids,
                mm_token_matches,
                mm_item_counts,
            )
            text = decode_tokens(tokenizer, token_ids)

            matched_repls = {
                modality: [match.prompt_repl for match in matches]
                for modality, matches in mm_token_matches.items()
            }
        else:
            # If the search text does not represent a special token, it may
            # have different token IDs in the prompt, because the tokens may
            # go across the boundaries of the search text.
            # e.g. when searching for "foo" in "food", if "food" itself makes
            # up a token, then the token ID of "foo" will not appear at all.
            # Searching for every possible tokenization is inefficient, so
            # the replacement is done on the decoded text instead, which is
            # then encoded back.
            text = decode_tokens(tokenizer, token_ids)

            mm_text_matches = {
                modality: find_text_matches(text, repls)
                for modality, repls in mm_prompt_repls.items()
            }
            text = replace_text_matches(
                text,
                mm_text_matches,
                mm_item_counts,
            )
            token_ids = encode_tokens(tokenizer,
                                      text,
                                      add_special_tokens=False)

            matched_repls = {
                modality: [match.prompt_repl for match in text_matches]
                for modality, text_matches in mm_text_matches.items()
            }

        placeholders = self._find_mm_placeholders(
            matched_repls,
            token_ids,
            mm_item_counts,
        )

        return token_ids, text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that ``mm_kwargs`` holds exactly ``mm_item_counts[modality]``
        items for every modality in ``mm_item_counts``.

        Raises:
            RuntimeError: If, for some modality, the number of items in
                ``mm_kwargs`` differs from the expected count (either too
                few or too many).
        """
        for modality, item_count in mm_item_counts.items():
            if modality in mm_kwargs.modalities:
                items = mm_kwargs.get_items(modality)
            else:
                items = []

            if len(items) != item_count:
                # The count may be larger or smaller than expected, so the
                # message must not claim that items are only missing.
                raise RuntimeError(
                    f"Expected there to be {item_count} {modality} items in "
                    f"keyword arguments corresponding to {item_count} "
                    f"{modality} data items, but found {len(items)}! "
                    "There is likely a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
        *,
        allow_missing: bool = False,
    ) -> Mapping[str, int]:
        """
        Compare the number of placeholders found per modality against the
        number of multi-modal items.

        Args:
            mm_placeholders: The placeholders found in the prompt, by
                modality.
            mm_item_counts: The number of multi-modal items, by modality.
            allow_missing: If ``True``, a mismatch is reported through the
                return value instead of raising.

        Returns:
            For each modality in ``mm_item_counts``, the item count minus
            the number of placeholders found.

        Raises:
            RuntimeError: If a count does not match and ``allow_missing`` is
                ``False``.
        """
        missing_repl_counts: dict[str, int] = {}

        for modality, item_count in mm_item_counts.items():
            num_found = len(mm_placeholders.get(modality, []))

            if not allow_missing and num_found != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt replacements "
                    f"corresponding to {item_count} {modality} items, but only "
                    f"found {num_found} prompt replacements! Either "
                    "the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_replacements`).")

            missing_repl_counts[modality] = item_count - num_found

        return missing_repl_counts

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Raises:
            ValueError: If, within one modality, the HF processor inserted
                placeholders for some but not all of the items (partial
                prompt replacement is not supported).
            RuntimeError: If the processed keyword arguments or the final
                placeholders do not match the number of multi-modal items.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.

        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        # Placeholders that the HF processor may already have inserted.
        hf_mm_placeholders = self._find_mm_placeholders(
            mm_prompt_repls,
            prompt_ids,
            mm_item_counts,
        )

        if self._always_apply_prompt_replacements():
            mm_missing_repl_counts = mm_item_counts
            mm_missing_repls = dict(mm_prompt_repls)
        else:
            mm_missing_repl_counts = self._validate_mm_placeholders(
                hf_mm_placeholders,
                mm_item_counts,
                allow_missing=True,
            )

            mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
            for modality, missing_repl_count in mm_missing_repl_counts.items():
                if missing_repl_count == 0:
                    mm_missing_repls[modality] = []
                elif missing_repl_count == mm_item_counts.get(modality, 0):
                    mm_missing_repls[modality] = mm_prompt_repls[modality]
                else:
                    raise ValueError("Partial prompt replacement within "
                                     f"{modality=} is not supported")

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        if all(len(repls) == 0 for repls in mm_missing_repls.values()):
            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
            mm_placeholders = hf_mm_placeholders
        else:
            (
                prompt_ids,
                prompt,
                missing_mm_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                mm_missing_repls,
                mm_missing_repl_counts,
            )

            # `hf_mm_placeholders` was located in the token IDs *before* the
            # missing replacements were inserted, so its offsets are stale
            # whenever an insertion happened earlier in the prompt. Re-locate
            # the placeholders of the modalities that HF already handled in
            # the updated token IDs.
            hf_done_modalities = {
                modality
                for modality, repls in mm_missing_repls.items()
                if not repls and modality in hf_mm_placeholders
            }
            if hf_done_modalities:
                relocated_placeholders = self._find_mm_placeholders(
                    mm_prompt_repls,
                    prompt_ids,
                    mm_item_counts,
                )
                hf_mm_placeholders = {
                    **hf_mm_placeholders,
                    **{
                        modality: relocated_placeholders[modality]
                        for modality in hf_done_modalities
                        if modality in relocated_placeholders
                    },
                }

            mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )