processing.py 36.1 KB
Newer Older
1
2
import re
from abc import ABC, abstractmethod
3
from collections import defaultdict
4
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
8
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union)
9

10
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
11

12
13
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
14
from vllm.logger import init_logger
15
16
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
17
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
18

19
from .hasher import MultiModalHasher
20
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
21
22
                     MultiModalInputsV2, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
23
from .parse import MultiModalDataItems, MultiModalDataParser
24
25
26

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
27

28
# Module-level logger; `ProcessingCache` uses it to report debug-level
# cache hit-ratio statistics.
logger = init_logger(__name__)
29
30

# A prompt is represented either as raw text or as a list of token IDs.
_PromptSeq = Union[str, list[int]]

# Constrained type variable over the two prompt representations, used when
# a function's output has the same representation as its input prompt.
_S = TypeVar("_S", str, list[int])
32

33
34

@dataclass
class PromptReplacement:
    """
    Describes how occurrences of :attr:`target` in a prompt are replaced
    for the items of :attr:`modality`.

    Use :meth:`bind` to attach a tokenizer, producing a
    :class:`BoundPromptReplacement`.
    """

    modality: str
    """The modality for which the replacement is made."""

    target: _PromptSeq
    """The text or token sequence to find and replace."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`, output the
    replacement text or token sequence.

    For convenience, you can pass in the replacement instead of a function
    if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
        """Return a copy of this replacement bound to ``tokenizer``."""
        bound_fields = dict(
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )
        return BoundPromptReplacement(tokenizer=tokenizer, **bound_fields)
59
60


61
62
63
64
65
66
67
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Up to 2048 results are kept in an LRU cache keyed on the call
    arguments, so ``tokenizer`` must be hashable. A cache hit returns the
    same list object as before, so callers should not mutate the result.
    """
    token_ids = encode_tokens(tokenizer, text,
                              add_special_tokens=add_special_tokens)
    return token_ids
71
72


73
74
75
76
77
78
79
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    ``token_ids`` is taken as a tuple so that it can be part of the LRU
    cache key (up to 2048 entries); it is converted back to a list before
    decoding.
    """
    id_list = list(token_ids)
    text = decode_tokens(tokenizer, id_list,
                         skip_special_tokens=skip_special_tokens)
    return text
83
84
85
86
87


class _HasModalityAttr(Protocol):
    """Structural type for objects with a ``modality`` attribute."""

    modality: str

88

89
class _HasModalityProp(Protocol):
    """
    Structural type for objects exposing ``modality`` as a read-only
    property.
    """

    @property
    def modality(self) -> str:
        ...


_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group ``values`` by their ``modality`` using :func:`full_groupby`.

    Returns the ``(modality, items)`` pairs produced by :func:`full_groupby`.
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence that can be viewed as text or as token IDs.

    At least one representation must be given. The other is derived on first
    access using the cached tokenizer helpers and then stored on the
    instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        has_text = self._text is not None
        has_token_ids = self._token_ids is not None
        if not (has_text or has_token_ids):
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The sequence as a string, decoded from the token IDs if needed."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The sequence as token IDs, encoded from the text if needed."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer, self._text)
        return self._token_ids


@dataclass
class BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer, so that its target
    and replacements can be viewed as either text or token IDs.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Memoizes the results of a callable `_replacement` by item index
        self._replacement_cache = dict[int, _BoundPromptSequence]()

    def _to_bound_seq(self, seq: _PromptSeq) -> _BoundPromptSequence:
        # A `str` provides the text view; a `list` provides the token IDs
        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    @property
    def target(self) -> _BoundPromptSequence:
        """The sequence to search for, bound to :attr:`tokenizer`."""
        return self._to_bound_seq(self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Return the replacement for item ``item_idx`` of :attr:`modality`,
        bound to :attr:`tokenizer`.

        When the replacement is a callable, its bound result is memoized per
        index. A fixed replacement is re-bound on every call.
        """
        repl = self._replacement
        if not callable(repl):
            return self._to_bound_seq(repl)

        if item_idx in self._replacement_cache:
            return self._replacement_cache[item_idx]

        bound = self._to_bound_seq(repl(item_idx))
        self._replacement_cache[item_idx] = bound
        return bound


178
179
180
class _TokenMatch(NamedTuple):
    """Half-open ``[start_idx, end_idx)`` span of a match in a token list."""

    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    The scan goes left to right and never yields overlapping occurrences.

    Note that empty matches are ignored.
    """
    needle_len = len(match_ids)
    if not needle_len:
        return

    last_start = len(token_ids) - needle_len
    pos = 0
    while pos <= last_start:
        stop = pos + needle_len
        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)
        # Exclude overlapping matches by resuming after this one
        pos = stop
209
210


211
212
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    A location in the prompt where the target of :attr:`prompt_repl` was
    found. Subclasses define how the span is stored.
    """

    prompt_repl: BoundPromptReplacement

    @property
    def modality(self) -> str:
        """The modality of the matched replacement."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index of the first element covered by the match."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index one past the last element covered by the match."""
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A replacement target found in a list of token IDs."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        token_match = self.match
        return token_match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A replacement target found in prompt text (as a regex match)."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        start, _ = self.match.span()
        return start

    @property
    def end_idx(self) -> int:
        _, end = self.match.span()
        return end

259

260
@dataclass
class PlaceholderInfo:
    """Location of the placeholder tokens of one multi-modal item."""

    modality: str
    item_idx: int
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` (offset and length)."""
        return PlaceholderRange(offset=self.start_idx, length=self.length)
276
277
278
279


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches are grouped by replacement, in the order of
    :code:`prompt_repls`, and are ordered left to right within each group.
    """
    found = list[_PromptReplacementTokenMatch]()
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        found.extend(
            _PromptReplacementTokenMatch(prompt_repl, token_match)
            for token_match in iter_token_matches(prompt, target_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Targets are matched literally, not as regular expressions. Within one
    target, matches do not overlap and are ordered left to right.

    Note that empty matches are ignored, consistent with
    :func:`iter_token_matches`. An empty target would otherwise match at
    every position of :code:`prompt`.
    """
    matches = list[_PromptReplacementTextMatch]()
    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            continue

        # Escape the target so that regex metacharacters match literally
        pattern = re.escape(target_text)
        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return matches


def _resolve_matches(
    prompt: _PromptSeq,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same index of :code:`prompt`.
    """
    all_matches = list[_PromptReplacementMatch]()
    for modality_matches in mm_matches.values():
        all_matches.extend(modality_matches)

    # owner[pos] is the match already covering `pos`, if any
    owner: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in all_matches:
        for pos in range(match.start_idx, match.end_idx):
            prev = owner[pos]
            if prev is not None:
                raise ValueError("Found overlapping matches "
                                 f"({prev} and {match}) "
                                 f"at index={pos} of prompt={prompt}")

            owner[pos] = match

    return sorted(all_matches, key=lambda m: m.start_idx)


def _replace_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    Returns the pieces of the new prompt for the caller to concatenate.
    Each piece is an unchanged segment followed by one replacement, and the
    last piece is the trailing segment. For each modality, matches beyond
    its count in :code:`mm_item_counts` are left unreplaced.
    """
    pieces = list[_S]()
    cursor = 0
    item_idx_by_modality = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality
        item_idx = item_idx_by_modality[modality]

        # Surplus matches stay in the prompt as-is
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        start_idx, end_idx = match.start_idx, match.end_idx
        bound_repl = match.prompt_repl.get_replacement(item_idx)

        if isinstance(prompt, str):
            repl_seq = bound_repl.text
        else:
            repl_seq = bound_repl.token_ids

        pieces.append(prompt[cursor:start_idx] + repl_seq)
        cursor = end_idx
        item_idx_by_modality[modality] += 1

    pieces.append(prompt[cursor:])
    return pieces


def replace_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, :code:`prompt` itself is returned.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt
377
378


379
380
def replace_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    If :code:`mm_matches` is empty, :code:`prompt` is returned unchanged.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return "".join(pieces)

    return prompt
391
392


393
394
395
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[PlaceholderInfo]:
    """
    Scan :code:`prompt` left to right for the placeholder tokens of each
    item of :code:`modality`. At most :code:`modal_item_count`
    non-overlapping placeholders are yielded.

    At each position the replacements are tried in order, using the
    replacement for the next unmatched item, and the first full match wins.
    Empty replacements are skipped.
    """
    if modal_item_count == 0:
        return

    prompt_len = len(prompt)
    item_idx = 0
    pos = 0

    while pos < prompt_len:
        matched_end: Optional[int] = None

        for repl_info in modality_repls:
            repl_tokens = repl_info.get_replacement(item_idx).token_ids
            repl_len = len(repl_tokens)
            end = pos + repl_len
            if repl_len == 0 or end > prompt_len:
                continue

            if prompt[pos:end] == repl_tokens:
                yield PlaceholderInfo(
                    modality=modality,
                    item_idx=item_idx,
                    start_idx=pos,
                    replacement=repl_tokens,
                )
                matched_end = end
                break

        if matched_end is None:
            pos += 1
            continue

        item_idx += 1
        if item_idx >= modal_item_count:
            return

        # Exclude overlapping matches by resuming after the placeholder
        pos = matched_end
437
438


439
def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
    """
    For each modality, yield each set of placeholder tokens found in
    :code:`prompt`.

    Modalities in :code:`mm_item_counts` that have no entry in
    :code:`mm_prompt_repls` are skipped.

    Note that empty matches are ignored.
    """
    for modality, modal_item_count in mm_item_counts.items():
        if modality not in mm_prompt_repls:
            continue

        yield from _iter_modality_placeholders(
            prompt,
            modality,
            mm_prompt_repls[modality],
            modal_item_count,
        )

459

460
def find_mm_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderInfo]]:
    """
    Locate the placeholder tokens of each multi-modal item in
    :code:`prompt`, grouped by modality.
    """
    placeholders = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


469
470
471
472
473
474
475
476
class ProcessingCache:
    """
    LRU cache of processed multi-modal items (:class:`MultiModalKwargsItem`).

    Entries are keyed by a hash of the dependencies listed in :meth:`get`.
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    @staticmethod
    def _get_cache_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Hash all dependencies of a processed item into one cache key."""
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        """
        If :attr:`debug_cache_hit_ratio_steps` is set (and non-zero), log the
        hit ratio at DEBUG level. Logging happens whenever the cache's total
        stat count is a multiple of that value.
        """
        steps = self.debug_cache_hit_ratio_steps
        if steps:
            cache_stats = self._cache.stat()
            if cache_stats.total % steps == 0:
                logger.debug("ProcessingCache: hit_ratio = %.2f",
                             cache_stats.hit_ratio)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor

        Returns ``None`` on a cache miss.
        """
        self._maybe_log_cache_stats()

        key = self._get_cache_key(model_id, modality, input_item,
                                  input_kwargs)
        return self._cache.get(key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        key = self._get_cache_key(model_id, modality, input_item,
                                  input_kwargs)
        self._cache.put(key, output_kwargs)
528
529


530
531
class BaseProcessingInfo:
    """Base class containing information to perform processing."""

    # NOTE(review): this class does not inherit from `ABC`, so the
    # `@abstractmethod` markers below are not enforced when instantiating it;
    # subclasses are still expected to override those methods.

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` field of the model config."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer of the processing context."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF config from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# Type of the processing-info object a processor is parameterized over;
# it must be `BaseProcessingInfo` or a subclass of it.
_I = TypeVar("_I", bound=BaseProcessingInfo)
580

581
582

class BaseMultiModalProcessor(ABC, Generic[_I]):
583
    """
584
    Abstract base class to process multi-modal inputs to be used in vLLM.
585
586

    Not to be confused with :class:`transformers.ProcessorMixin`.
587
588
    """

589
    def __init__(self,
                 info: _I,
                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Args:
            info: Information needed to perform processing.
            dummy_inputs: Builder of dummy inputs, used when processing
                multi-modal items that are missing from the cache.
            cache: Cache of processed multi-modal items; ``None`` disables
                caching.
            enable_sanity_checks: Whether to run extra consistency checks.
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        # Subclasses may customize the parser via `_get_data_parser`
        self.data_parser = self._get_data_parser()

604
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply`."""
        result = self.apply(prompt, mm_data, hf_processor_mm_kwargs)
        return result
611

612
613
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a parser to preprocess multi-modal data items
        before passing them to :meth:`_get_hf_mm_data`.

        You can support additional modalities by creating a subclass
        of :class:`MultiModalDataParser` that has additional subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than its
                ``--limit-mm-per-prompt`` setting allows (default 1).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)
        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt

        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            num_items = len(items)
            if num_items <= limit:
                continue

            raise ValueError(
                f"You set {modality}={limit} (or defaulted to 1) in "
                f"`--limit-mm-per-prompt`, but passed {num_items} "
                f"{modality} items in the same prompt.")

        return mm_items
642

643
644
645
646
647
648
649
650
651
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        The returned mapping is passed to
        :meth:`MultiModalKwargs.from_hf_inputs` by
        :meth:`_apply_hf_processor`.
        """
        raise NotImplementedError

652
653
    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: in :meth:`_apply_hf_processor_missing`, this method
              is called on text-only and multimodal-only inputs separately,
              instead of passing them in the same call.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError
673

674
    def _find_mm_placeholders(
        self,
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderInfo]]:
        """
        Locate the placeholder tokens of each multi-modal item in
        ``new_token_ids``, grouped by modality.

        Thin wrapper around :func:`find_mm_placeholders`.
        """
        placeholders = find_mm_placeholders(
            mm_prompt_repls,
            new_token_ids,
            mm_item_counts,
        )
        return placeholders
682

683
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split ``mm_items`` into the data to pass to the HF processor and the
        data to pass through without processing.

        Returns:
            A tuple ``(processor_data, passthrough_data)``.
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for modality_items in mm_items.values():
            processor_data.update(modality_items.get_processor_data())
            passthrough_data.update(modality_items.get_passthrough_data())

        return processor_data, passthrough_data

696
697
698
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.
        """
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        processor_inputs = dict(text=prompt, **mm_data)

        return self.info.ctx.call_hf_processor(
            hf_processor,
            processor_inputs,
            mm_kwargs,
        )

714
715
    def _apply_hf_processor(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Wrapper of :meth:`_call_hf_processor` that applies
        additional pre-processing and post-processing.

        Returns:
            The prompt token IDs output by the HF processor, and the
            multi-modal kwargs built from the remaining outputs plus any
            passthrough data.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        # Passthrough data skips the HF processor but still becomes kwargs
        processed_data.update(passthrough_data)

        # The unpacking requires exactly one row of input IDs (one prompt)
        input_ids = processed_data.pop("input_ids")
        (prompt_ids, ) = input_ids.tolist()

        fields_config = self._get_mm_fields_config(processed_data,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargs.from_hf_inputs(processed_data,
                                                    fields_config)

        return prompt_ids, mm_kwargs

    def _apply_hf_processor_missing(
        self,
        prompt_text: str,
        mm_missing_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text, but only on the
        multi-modal data that are missing from the cache.

        Note:
            We pass prompt text and multi-modal data into the HF processor
            in separate calls to avoid HF prompt replacement being done for
            cached items; instead, we rely on our own prompt replacement logic
            (:meth:`_get_prompt_replacements`) for the full text.

        Returns:
            The token IDs of ``prompt_text`` (processed without any
            multi-modal data), and the multi-modal kwargs of the missing
            items (processed together with a dummy prompt).
        """
        mm_missing_counts = mm_missing_data_items.get_all_counts()

        # Text-only pass: tokenize the real prompt without multi-modal data
        prompt_ids, _ = self._apply_hf_processor(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        # Some HF processors (e.g. Qwen2-VL) expect corresponding
        # multi-modal tokens to be in the prompt text
        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_missing_counts,
        )

        # Multi-modal pass: only the missing items are processed; the token
        # IDs of the dummy prompt are discarded
        _, mm_missing_kwargs = self._apply_hf_processor(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_missing_kwargs

    def _cached_apply_hf_processor(
        self,
        prompt_text: str,
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Caching is bypassed when no cache is configured or when there is any
        passthrough data. Otherwise, only the items missing from the cache
        are processed and then stored. The returned kwargs merge cached and
        newly processed items in their original order.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor(
                prompt_text=prompt_text,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            )

        # Look up every item; `None` marks a cache miss
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        # Gather the raw data of the items that missed the cache
        mm_missing_data = dict[str, list[object]]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            modality_items = mm_data_items[modality]
            mm_missing_data[modality] = [
                modality_items[idx] for idx, kw_item in enumerate(kw_items)
                if kw_item is None
            ]
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
            prompt_text=prompt_text,
            mm_missing_data_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        # Index of the next newly processed item to consume, per modality
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is not None:
                    merged_kw_items.append(kw_item)
                    continue

                # Cache miss: take the next processed item and store it
                kw_item = mm_missing_kwargs.get_item(
                    modality,
                    mm_missing_next_idx[modality],
                )
                cache.put(
                    model_id,
                    modality,
                    mm_data_items[modality][idx],
                    hf_processor_mm_kwargs,
                    kw_item,
                )
                mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every newly processed item must have been consumed exactly once
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        return prompt_ids, mm_kwargs
864

865
    def _bind_and_group_repls(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> dict[str, list[BoundPromptReplacement]]:
        """
        Bind each of ``prompt_repls`` to the tokenizer and group the bound
        replacements by modality.
        """
        tokenizer = self.info.get_tokenizer()
        bound_repls = [repl.bind(tokenizer) for repl in prompt_repls]
        grouped = full_groupby_modality(bound_repls)
        return dict(grouped)
873

874
875
876
877
    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called even if we
        detect that HF has performed processing via
        :meth:`_find_mm_placeholders`.

        This is useful in cases where :meth:`_find_mm_placeholders`
        cannot be reliably used to detect whether HF has performed processing.

        Returns:
            ``False`` in this base implementation.
        """
        return False

886
887
888
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
        """
        Insert the sequences described by ``mm_prompt_repls`` into the prompt.

        Matching is first attempted directly on ``token_ids``. If every
        modality in ``mm_item_counts`` has at least as many token-level
        matches as items, the replacement is carried out in token space.
        Otherwise the prompt is decoded, matched and replaced as text, and
        then re-encoded with ``add_special_tokens=False``.

        Returns:
            A tuple of the new token IDs, their decoded text, and the
            placeholders that ``_find_mm_placeholders`` locates in the new
            token IDs.
        """
        tokenizer = self.info.get_tokenizer()

        token_matches_by_modality = {
            modality: find_token_matches(token_ids, repls)
            for modality, repls in mm_prompt_repls.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we fall back to performing string
        # replacement on the decoded token IDs, then encode them back.
        can_replace_tokens = all(
            len(token_matches_by_modality.get(modality, ())) >= count
            for modality, count in mm_item_counts.items())

        if can_replace_tokens:
            new_token_ids = replace_token_matches(
                token_ids,
                token_matches_by_modality,
                mm_item_counts,
            )
            new_text = decode_tokens(tokenizer, new_token_ids)
            applied_matches = token_matches_by_modality
        else:
            decoded_text = decode_tokens(tokenizer, token_ids)
            text_matches_by_modality = {
                modality: find_text_matches(decoded_text, repls)
                for modality, repls in mm_prompt_repls.items()
            }
            new_text = replace_text_matches(
                decoded_text,
                text_matches_by_modality,
                mm_item_counts,
            )
            new_token_ids = encode_tokens(tokenizer,
                                          new_text,
                                          add_special_tokens=False)
            applied_matches = text_matches_by_modality

        # Both match kinds expose the replacement that produced them.
        matched_repls = {
            modality: [match.prompt_repl for match in matches]
            for modality, matches in applied_matches.items()
        }

        placeholders = self._find_mm_placeholders(
            matched_repls,
            new_token_ids,
            mm_item_counts,
        )

        return new_token_ids, new_text, placeholders
956

957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Check that ``mm_kwargs`` contains ``item_count`` items for every
        modality listed in ``mm_item_counts``.

        A modality that is absent from ``mm_kwargs.modalities`` is treated
        as having zero items.

        Raises:
            RuntimeError: If any modality's item count does not match.
        """
        for modality, item_count in mm_item_counts.items():
            items = (mm_kwargs.get_items(modality)
                     if modality in mm_kwargs.modalities else [])
            num_found = len(items)
            if num_found == item_count:
                continue

            raise RuntimeError(
                f"Expected there to be {item_count} {modality} items in "
                f"keyword arguments corresponding to {item_count} "
                f"{modality} data items, but only found {num_found}! "
                "There is likely a problem with your "
                "implementation of merged multi-modal processor for this "
                "model (usually arising from an inconsistency between "
                "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderInfo]],
        mm_item_counts: Mapping[str, int],
        *,
        allow_missing: bool = False,
    ) -> Mapping[str, int]:
        """
        Compare the number of placeholders found per modality with the
        number of input items.

        Args:
            mm_placeholders: Placeholders found, keyed by modality.
            mm_item_counts: Number of input items, keyed by modality.
            allow_missing: When ``True``, a count mismatch is not an error.

        Returns:
            For every modality in ``mm_item_counts``, the expected count minus
            the number of placeholders found. The value is negative if more
            placeholders than items were found.

        Raises:
            RuntimeError: On a count mismatch when ``allow_missing`` is False.
        """
        missing_repl_counts: dict[str, int] = {}

        for modality, item_count in mm_item_counts.items():
            num_found = len(mm_placeholders.get(modality, []))

            if not allow_missing and num_found != item_count:
                raise RuntimeError(
                    f"Expected there to be {item_count} prompt replacements "
                    f"corresponding to {item_count} {modality} items, but only "
                    f"found {num_found} prompt replacements! Either "
                    "the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_replacements`).")

            missing_repl_counts[modality] = item_count - num_found

        return missing_repl_counts

1005
1006
1007
1008
    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Raises:
            RuntimeError: If the processed keyword arguments or the final
                placeholders do not match the number of input items.
            ValueError: If the HF processor output contains placeholders for
                only some of the items of a modality.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt_text,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        # Placeholders that the HF processor may already have inserted.
        hf_mm_placeholders = self._find_mm_placeholders(
            mm_prompt_repls,
            prompt_ids,
            mm_item_counts,
        )

        if self._always_apply_prompt_replacements():
            mm_missing_repl_counts = mm_item_counts
            mm_missing_repls = dict(mm_prompt_repls)
        else:
            mm_missing_repl_counts = self._validate_mm_placeholders(
                hf_mm_placeholders,
                mm_item_counts,
                allow_missing=True,
            )

            # A modality is either fully handled by HF (nothing missing) or
            # not at all (everything missing); anything in between is
            # unsupported.
            mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
            for modality, missing_repl_count in mm_missing_repl_counts.items():
                if missing_repl_count == 0:
                    mm_missing_repls[modality] = []
                elif missing_repl_count == mm_item_counts.get(modality, 0):
                    mm_missing_repls[modality] = mm_prompt_repls[modality]
                else:
                    raise ValueError("Partial prompt replacement within "
                                     f"{modality=} is not supported")

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them.
        # Iterate over the replacement lists themselves: `.items()` would
        # yield (key, value) pairs of length 2, so the check could never pass.
        if all(len(repls) == 0 for repls in mm_missing_repls.values()):
            tokenizer = self.info.get_tokenizer()
            prompt_text = decode_tokens(tokenizer, prompt_ids)
            mm_placeholders = hf_mm_placeholders
        else:
            (
                prompt_ids,
                prompt_text,
                missing_mm_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                mm_missing_repls,
                mm_missing_repl_counts,
            )

            mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )