processing.py 38.2 KB
Newer Older
1
import pickle
2
3
import re
from abc import ABC, abstractmethod
4
from collections import defaultdict
5
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
6
from dataclasses import dataclass, field
7
from functools import lru_cache
8
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
9

10
import numpy as np
11
import numpy.typing as npt
12
import torch
13
from blake3 import blake3
14
from PIL import Image
15
from transformers import BatchFeature, ProcessorMixin
16

17
from vllm.inputs import DummyData, InputProcessingContext
18
from vllm.logger import init_logger
19
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
20
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
21

22
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
23
24
                     MultiModalInputsV2, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
25
from .parse import MultiModalDataItems, MultiModalDataParser
26

27
logger = init_logger(__name__)
28
29

# Type variable for helpers that accept a prompt in either form and hand
# back the same form: text (``str``) or token IDs (``list[int]``).
_S = TypeVar("_S", str, list[int])
# A prompt, target or replacement expressed as text or as token IDs.
_PromptSeq = Union[str, list[int]]
31

32
33

@dataclass
class PromptReplacement:
    """
    Describes one find-and-replace rule to apply to a prompt for a modality.

    Use :meth:`bind` to attach a tokenizer before the rule is applied.
    """

    modality: str
    """The modality for which the replacement is made."""

    target: _PromptSeq
    """The text or token sequence to find and replace."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`, output the
    replacement text or token sequence.

    For convenience, you can pass in the replacement instead of a function
    if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
        """Return a copy of this rule associated with ``tokenizer``."""
        return _BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )
58
59


60
61
62
63
64
65
66
67
68
69
70
71
72
73
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    Args:
        tokenizer: The tokenizer used to encode ``text``.
        text: The text to encode.
        add_special_tokens: Forwarded as ``add_special_tokens`` to the
            tokenizer, or as both ``bos`` and ``eos`` for a
            :class:`MistralTokenizer`.

    Returns:
        The token IDs produced by the tokenizer.
    """
    if isinstance(tokenizer, MistralTokenizer):
        # The wrapped tokenizer exposes separate BOS/EOS switches instead of
        # a single `add_special_tokens` flag, so the flag drives both.
        return tokenizer.tokenizer.encode(text,
                                          bos=add_special_tokens,
                                          eos=add_special_tokens)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
76
77


78
79
80
81
82
83
84
85
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`_encode`.

    Results are kept in an LRU cache of up to 2048 entries keyed by the
    arguments, so repeated calls with equal arguments return the very same
    list object; callers should treat the result as read-only.
    """
    encoded_ids = _encode(tokenizer,
                          text,
                          add_special_tokens=add_special_tokens)
    return encoded_ids
86
87


88
89
90
91
92
93
94
95
96
97
98
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Decode ``token_ids`` to text; the backend-agnostic counterpart of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
    """
    decoded_text = tokenizer.decode(
        token_ids,
        skip_special_tokens=skip_special_tokens,
    )
    return decoded_text
99
100


101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`_decode`.

    ``token_ids`` is taken as a tuple so that it can serve as part of the
    cache key; it is converted back to a list before decoding.
    """
    ids_as_list = list(token_ids)
    return _decode(
        tokenizer,
        ids_as_list,
        skip_special_tokens=skip_special_tokens,
    )


class _HasModalityAttr(Protocol):
    """Structural type: any object with a ``modality`` string attribute."""

    modality: str

116

117
class _HasModalityProp(Protocol):
    """Structural type: any object exposing ``modality`` as a property."""

    @property
    def modality(self) -> str:
        ...


# Type variable for any object that exposes a ``modality`` string, either as
# a plain attribute or as a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group ``values`` by their ``modality`` using :func:`full_groupby`."""

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence tied to a tokenizer, available as text and token IDs.

    At least one of the two forms must be supplied; the other is derived
    lazily on first access (through the cached encode/decode helpers) and
    then stored on the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The sequence as text, decoded from the token IDs if needed."""
        if self._text is None:
            assert self._token_ids is not None
            # Tuples are hashable, which the cached decoder requires.
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The sequence as token IDs, encoded from the text if needed."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer, self._text)

        return self._token_ids


@dataclass
class _BoundPromptReplacement:
    """
    A :class:`PromptReplacement` associated with a tokenizer, so that its
    target and replacements can be obtained as text or as token IDs.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Per-item cache, only populated when `_replacement` is callable.
        self._replacement_cache = dict[int, _BoundPromptSequence]()

    @property
    def target(self) -> _BoundPromptSequence:
        """The sequence to search for, bound to :attr:`tokenizer`."""
        target = self._target

        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=target if isinstance(target, str) else None,
            _token_ids=target if isinstance(target, list) else None,
        )

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Return the replacement for the item at ``item_idx`` of this modality.

        When the replacement is a callable, it is invoked with ``item_idx``
        and the bound result is cached per index; a constant replacement is
        bound afresh on every call.
        """
        replacement = self._replacement
        if callable(replacement):
            cache_key = item_idx
            if cache_key in self._replacement_cache:
                return self._replacement_cache[cache_key]

            replacement = replacement(item_idx)
        else:
            cache_key = None

        bound_replacement = _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=replacement if isinstance(replacement, str) else None,
            _token_ids=replacement if isinstance(replacement, list) else None,
        )

        if cache_key is not None:
            self._replacement_cache[cache_key] = bound_replacement

        return bound_replacement


206
207
208
class _TokenMatch(NamedTuple):
    """Span ``[start_idx, end_idx)`` of a match within a token sequence."""

    start_idx: int
    end_idx: int
209
210


211
212
213
214
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Occurrences are reported left to right and never overlap: after a match,
    scanning resumes at the end of that match.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
237
238


239
240
241
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    Base class for an occurrence of a replacement target within a prompt.

    Subclasses supply the matched span through :attr:`start_idx` and
    :attr:`end_idx`.
    """

    prompt_repl: _BoundPromptReplacement

    @property
    def modality(self) -> str:
        """The modality of the replacement rule that produced this match."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt where the match starts."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt where the match ends."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match found in a token ID prompt; the span comes from ``match``."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match found in a text prompt; the span comes from the regex match."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

287
288
289
290

class _PlaceholderInfo(NamedTuple):
    """Where a replacement token sequence sits within a token ID prompt."""

    modality: str
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of tokens in :attr:`replacement`."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Express this placeholder as an offset/length range."""
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )
302
303
304
305


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches are grouped by replacement rule, in the order of
    :code:`prompt_repls`; they are not sorted by position.
    """
    return [
        _PromptReplacementTokenMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
    ]


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    The target text is escaped before searching, so it is matched literally
    rather than as a regular expression.
    """
    return [
        _PromptReplacementTextMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
    ]


def _resolve_matches(
    prompt: _PromptSeq,
    matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same position of the prompt.
    """
    # One slot per prompt position, recording which match claimed it.
    seen_matches: list[Optional[_PromptReplacementMatch]] = [None
                                                             ] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Split :code:`prompt` into segments with the matches replaced.

    Each applied match contributes one segment made of the untouched prompt
    preceding it followed by its replacement; the remainder of the prompt is
    appended as the last segment. Joining the segments gives the new prompt.

    A match is skipped (left as-is in the prompt) once the number of applied
    matches for its modality has reached the count in
    :code:`mm_item_counts`; modalities missing from the mapping count as 0.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, matches):
        modality = match.modality

        item_idx = next_idx_by_modality[modality]
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx

        repl_info = match.prompt_repl
        replacement = repl_info.get_replacement(item_idx)

        # The two branches are identical apart from which form of the
        # replacement is used; it must match the form of the prompt.
        if isinstance(prompt, str):
            repl_seq = replacement.text
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
        else:
            repl_seq = replacement.token_ids
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] += 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply :code:`prompt_repls` to :code:`prompt`.

    Returns :code:`prompt` itself (not a copy) when there are no matches.
    """
    if not matches:
        return prompt

    token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)
400
401


402
403
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply :code:`prompt_repls` to :code:`prompt`.

    Returns :code:`prompt` unchanged when there are no matches.
    """
    if not matches:
        return prompt

    texts = _replace_matches(prompt, matches, mm_item_counts)

    return "".join(texts)
414
415


416
417
418
419
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield the placeholders of one modality found in :code:`prompt`.

    The prompt is scanned left to right. At each position, the replacement
    of every rule in :code:`modality_repls` for the current item index is
    compared against the prompt tokens starting there; the first one that
    matches is reported and the item index advances. Scanning stops once
    :code:`modal_item_count` placeholders have been yielded.
    """
    if modal_item_count == 0:
        return

    prompt_len = len(prompt)
    item_index = 0

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for repl_info in modality_repls:
            replacement = repl_info.get_replacement(item_index)
            repl_tokens = replacement.token_ids
            repl_len = len(repl_tokens)
            end_idx = start_idx + repl_len

            # Skip empty replacements and ones that would run past the end.
            if repl_len == 0 or end_idx > prompt_len:
                continue

            if prompt[start_idx:end_idx] == repl_tokens:
                yield _PlaceholderInfo(
                    modality=modality,
                    start_idx=start_idx,
                    replacement=repl_tokens,
                )

                item_index += 1
                if item_index >= modal_item_count:
                    return

                # Exclude overlapping matches
                start_idx = end_idx
                found = True
                break

        if not found:
            start_idx += 1
459
460
461


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Placeholders are yielded modality by modality, in the iteration order of
    :code:`mm_item_counts`; modalities without any replacement rule in
    :code:`prompt_repls` are skipped.

    Note that empty matches are ignored.
    """
    repls_by_modality = dict(full_groupby_modality(prompt_repls))

    for modality, modal_item_count in mm_item_counts.items():
        if modality in repls_by_modality:
            yield from _iter_modality_placeholders(
                prompt,
                modality,
                repls_by_modality[modality],
                modal_item_count,
            )

482

483
484
485
@dataclass
class ProcessorInputs:
    """Keyword arguments to :meth:`BaseMultiModalProcessor`."""

    # The prompt text to process.
    prompt_text: str
    # The multi-modal data accompanying the prompt.
    mm_data: MultiModalDataDict
    # Extra keyword arguments for the HF processor; empty by default.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)


class ProcessingCache:
    """
    LRU cache of processed multi-modal items.

    Entries are keyed by a hash of everything the processed output depends
    on (see :meth:`get`).
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    def _maybe_log_cache_stats(self) -> None:
        """Log the hit ratio every ``debug_cache_hit_ratio_steps`` lookups."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        cache_stats = self._cache.stat()
        if cache_stats.total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         cache_stats.hit_ratio)

    def _serialize_item(self, obj: object) -> bytes:
        """
        Convert a leaf value to bytes for hashing.

        Arrays, tensors, numbers and images are prefixed with a short header
        describing their layout (dtype and shape, or mode and size), because
        the raw buffer alone does not distinguish e.g. a 2x3 array from a
        3x2 array, or the integer 0 from the float 0.0.
        """
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj

        # Convertible to NumPy arrays
        if isinstance(obj, torch.Tensor):
            # `Tensor.numpy()` requires a CPU tensor that does not track
            # gradients, hence the detach/cpu round trip.
            obj = obj.detach().cpu().numpy()
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            header = f"{obj.dtype.str}{obj.shape}".encode("utf-8")
            return header + obj.tobytes()

        if isinstance(obj, Image.Image):
            header = f"{obj.mode}{obj.size}".encode("utf-8")
            return header + obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    def _item_to_bytes(
        self,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """
        Flatten ``obj`` into ``(key, value)`` byte pairs.

        Lists, tuples and dicts are walked recursively, extending the key
        with ``.<index>`` or ``.<dict key>``; every other value is a leaf.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from self._item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from self._item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = self._serialize_item(key)
            value_bytes = self._serialize_item(obj)
            yield key_bytes, value_bytes

    def _hash_kwargs(self, **kwargs: object) -> str:
        """Return a hex digest covering every key and value in ``kwargs``."""
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in self._item_to_bytes(k, v):
                # Length-prefix each chunk so that the boundary between a
                # key and its value cannot shift without changing the hash.
                for chunk in (k_bytes, v_bytes):
                    hasher.update(len(chunk).to_bytes(8, "little"))
                    hasher.update(chunk)

        return hasher.hexdigest()

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional["MultiModalKwargsItem"]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = self._hash_kwargs(model_id=model_id,
                                      **{modality: input_item},
                                      **input_kwargs)
        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: "MultiModalKwargsItem",
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._hash_kwargs(model_id=model_id,
                                      **{modality: input_item},
                                      **input_kwargs)
        self._cache.put(cache_key, output_kwargs)
600
601


602
class BaseMultiModalProcessor(ABC):
603
    """
604
    Abstract base class to process multi-modal inputs to be used in vLLM.
605
606
    """

607
608
609
610
611
    def __init__(self,
                 ctx: InputProcessingContext,
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Args:
            ctx: The input processing context, stored as ``self.ctx``.
            cache: Optional cache of processed multi-modal items.
            enable_sanity_checks: Whether to run the extra consistency
                checks guarded by ``self.enable_sanity_checks``.
        """
        super().__init__()

        self.ctx = ctx
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks
617

618
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for calling ``self.apply`` with the same arguments."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
625

626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

649
650
651
652
653
654
655
656
657
658
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that preprocesses multi-modal data items
        before they are handed to :meth:`_get_hf_mm_data`.

        To support additional modalities, override this method and return
        a subclass of :class:`MultiModalDataParser` with extra subparsers.
        """
        default_parser = MultiModalDataParser()
        return default_parser

659
660
661
662
663
    def _get_hf_processor(self) -> ProcessorMixin:
        """
        Return the HF processor obtained from ``self.ctx``.

        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by ``self.ctx``."""
        ctx_tokenizer = self.ctx.tokenizer
        return ctx_tokenizer

669
    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than its entry in the
                ``limit_per_prompt`` configuration allows (1 when the
                modality has no entry).
        """
        parser = self._get_data_parser()
        mm_items = parser.parse_mm_data(mm_data)

        mm_limits = self.ctx.get_mm_config().limit_per_prompt
        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items
690

691
692
693
694
695
696
697
698
699
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Given the HF-processed data, output the metadata of each field.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

700
701
    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: in :meth:`_apply_hf_processor_missing`, this method
              is called on text-only and multimodal-only inputs separately,
              instead of passing them in the same call.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError
721

722
723
    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> list[_PlaceholderInfo]:
        """
        Collect the placeholders that :func:`iter_placeholders` finds in
        ``new_token_ids`` into a list.
        """
        return list(
            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
730

731
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Merge the per-modality data of ``mm_items`` into two dictionaries.

        Returns:
            A ``(processor_data, passthrough_data)`` pair, built from each
            modality's ``get_processor_data()`` and
            ``get_passthrough_data()`` respectively. On duplicate keys,
            later modalities overwrite earlier ones.
        """
        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()

        for items in mm_items.values():
            processor_data.update(items.get_processor_data())
            passthrough_data.update(items.get_passthrough_data())

        return processor_data, passthrough_data

744
745
746
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor on the prompt text and
        associated multi-modal data.

        ``mm_kwargs`` is used twice: as keyword arguments to
        :meth:`_get_hf_processor` and as the kwargs forwarded to
        ``self.ctx.call_hf_processor``.
        """
        return self.ctx.call_hf_processor(
            self._get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )

762
763
    def _apply_hf_processor(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Wrapper of :meth:`_call_hf_processor` that applies
        additional pre-processing and post-processing.

        Returns:
            The prompt token IDs taken from the processor's ``input_ids``,
            and the remaining outputs (plus passthrough data) wrapped as
            :class:`MultiModalKwargs`.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        processed_data.update(passthrough_data)

        # Single-element unpacking: `input_ids` must hold exactly one
        # sequence, otherwise this raises.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
        )

        return prompt_ids, mm_kwargs

    def _apply_hf_processor_missing(
        self,
        prompt_text: str,
        mm_missing_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text, but only on the
        multi-modal data that are missing from the cache.

        Note:
            We pass prompt text and multi-modal data into the HF processor
            in separate calls to avoid HF prompt replacement being done for
            cached items; instead, we rely on our own prompt replacement logic
            (:meth:`_get_prompt_replacements`) for the full text.

        Returns:
            The token IDs from the text-only call, and the multi-modal
            kwargs from the call on the missing items.
        """
        mm_missing_counts = mm_missing_data_items.get_all_counts()

        # Text-only call: no multi-modal items and no processor kwargs.
        prompt_ids, _ = self._apply_hf_processor(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        # Some HF processors (e.g. Qwen2-VL) expect corresponding
        # multi-modal tokens to be in the prompt text
        dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts)

        _, mm_missing_kwargs = self._apply_hf_processor(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_missing_kwargs

    def _cached_apply_hf_processor(
        self,
        prompt_text: str,
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Falls back to a plain :meth:`_apply_hf_processor` call when there is
        no cache or when any passthrough data is present.
        """
        cache = self.cache
        model_id = self.ctx.model_config.model

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor(
                prompt_text=prompt_text,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            )

        # Look every item up in the cache; misses are recorded as `None`.
        mm_maybe_cached_kw_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        mm_missing_idxs = {
            modality:
            [idx for idx, item in enumerate(kw_items) if item is None]
            for modality, kw_items in mm_maybe_cached_kw_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._to_mm_items(mm_missing_data)

        prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
            prompt_text=prompt_text,
            mm_missing_data_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        # For each modality, the index of the next freshly processed item
        # to consume from `mm_missing_kwargs`.
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        # Interleave cached and freshly processed items in the original
        # order, storing the fresh ones in the cache along the way.
        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in mm_maybe_cached_kw_items.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly processed item must have been consumed.
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

        if self.enable_sanity_checks:
            mm_item_counts = mm_data_items.get_all_counts()

            for modality, item_count in mm_item_counts.items():
                for item_idx in range(item_count):
                    try:
                        mm_kwargs.get_item(modality, item_idx)
                    except Exception as e:
                        # Make it easy to set a breakpoint in the debugger
                        raise e

        return prompt_ids, mm_kwargs
920

921
922
    def _bind_prompt_replacements(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> list[_BoundPromptReplacement]:
        """
        Bind each prompt replacement to this processor's tokenizer.

        The tokenizer is fetched once via :meth:`_get_tokenizer` and passed
        to every :meth:`PromptReplacement.bind` call; the bound replacements
        are returned in the same order as the input.
        """
        shared_tokenizer = self._get_tokenizer()

        bound_repls = []
        for unbound in prompt_repls:
            bound_repls.append(unbound.bind(shared_tokenizer))

        return bound_repls
928

929
930
931
932
933
934
935
936
937
938
939
    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called even if we
        detect that HF has performed processing via :meth:`_find_placeholders`.

        This is useful in cases where :meth:`_find_placeholders` cannot be
        reliably used to detect whether HF has performed processing or not.
        """
        return False

940
941
942
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Substitute the replacement targets in the prompt and locate the
        resulting placeholders.

        The replacement is done directly on ``token_ids`` when every modality
        in ``mm_item_counts`` has at least as many token-level matches as it
        has items. Otherwise the prompt is decoded, replaced as text, and
        encoded again.

        Returns a tuple of the updated token IDs, the matching prompt text,
        and the placeholders found in the updated token IDs.
        """
        tokenizer = self._get_tokenizer()

        token_matches = find_token_matches(token_ids, prompt_repls)
        found_per_modality = {
            modality: len(modality_matches)
            for modality, modality_matches in full_groupby_modality(
                token_matches)
        }

        # A target that is not a special token may be tokenized differently
        # inside the prompt, since tokens can straddle the boundaries of the
        # target text.
        # ----
        # e.g. when looking for "foo" in "food": if "food" is a single token,
        # the token ID of "foo" never shows up in the prompt.
        # ----
        # Searching for every possible tokenization of the target would be
        # inefficient, so in that case we fall back to string replacement on
        # the decoded prompt and re-encode the result.
        needs_text_fallback = any(
            found_per_modality.get(modality, 0) < required
            for modality, required in mm_item_counts.items())

        if needs_text_fallback:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(text, text_matches, mm_item_counts)

            token_ids = _encode(tokenizer, text)
            used_matches = text_matches
        else:
            token_ids = replace_token_matches(token_ids, token_matches,
                                              mm_item_counts)

            text = _decode(tokenizer, token_ids)
            used_matches = token_matches

        matched_repls = [found.prompt_repl for found in used_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids,
                                               mm_item_counts)

        return token_ids, text, placeholders
993

994
995
996
997
    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Turn a prompt and its multi-modal data into vLLM model inputs.

        The pipeline is:

        1. Run the (cached) HF processor on the prompt text together with the
           multi-modal items, yielding token IDs and processed tensors.
        2. Unless the HF processor already inserted them, replace the target
           sequences in the token IDs with placeholder tokens. The number of
           placeholder tokens equals the feature size of the multi-modal
           data outputted by the multi-modal encoder.
        3. Collect the placeholder ranges per modality from the final token
           IDs and check that they line up with the input items.

        Raises:
            AssertionError: If placeholders are found for a modality that has
                no input items, or their number differs from the number of
                input items for that modality.
        """
        mm_items = self._to_mm_items(mm_data)

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt_text,
            mm_items,
            hf_processor_mm_kwargs,
        )

        prompt_repls = self._bind_prompt_replacements(
            self._get_prompt_replacements(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))

        mm_item_counts = mm_items.get_all_counts()
        all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
                                                   mm_item_counts)

        # Placeholders already present mean the HF processor inserted them,
        # so we only have to insert them ourselves when none were found
        # (or when the subclass forces it).
        if not all_placeholders or self._always_apply_prompt_replacements():
            replaced = self._apply_prompt_replacements(
                prompt_ids,
                prompt_repls,
                mm_item_counts,
            )
            prompt_ids, prompt_text, all_placeholders = replaced
        else:
            prompt_text = _decode(self._get_tokenizer(), prompt_ids)

        err_suffix = ("This suggests a problem with your implementation of "
                      "the merged multi-modal processor for this model, "
                      "particularly in the `_get_prompt_replacements` method.")
        mm_placeholders = dict[str, list[PlaceholderRange]]()

        # NOTE: the names `modality` and `placeholders` are part of the
        # rendered error messages below (via the `{name=}` f-string form).
        for modality, placeholders in full_groupby_modality(all_placeholders):
            if modality not in mm_items:
                raise AssertionError(
                    f"Expected no placeholders for {modality=}, "
                    f"but found {placeholders=}. Input items: {mm_items}"
                    f"\n{err_suffix}")

            if len(placeholders) != len(mm_items[modality]):
                raise AssertionError(
                    f"Expected length of {placeholders=} for {modality=} "
                    f"to equal that of input items: {mm_items[modality]}"
                    f"\n{err_suffix}")

            mm_placeholders[modality] = [
                info.to_range() for info in placeholders
            ]

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )
1077

1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """
        Create ``num_images`` dummy RGB images of size ``(width, height)``
        filled with color ``0``.

        Every entry of the returned list refers to the same image object.
        """
        blank_image = Image.new("RGB", (width, height), color=0)
        return [blank_image for _ in range(num_images)]

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        video = np.zeros((num_frames, width, height, 3))
        return [video] * num_videos

1108
    @abstractmethod
    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Construct the dummy processor inputs used by :meth:`get_dummy_data`.

        ``mm_counts`` maps each modality to the number of items to include.
        After processing, the returned inputs should yield `mm_max_tokens`
        placeholder tokens, which :meth:`get_dummy_data` verifies.

        Subclasses must implement this method.
        """
        raise NotImplementedError

1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
    def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]:
        mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt
        supported_mm_limits = self.get_supported_mm_limits()

        mm_limits = {
            modality: mm_limit_per_prompt.get(modality, 1)
            for modality in supported_mm_limits
        }

        for modality, supported_limit in supported_mm_limits.items():
            limit = mm_limits[modality]
            if supported_limit is not None and supported_limit < limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but this model only supports "
                    f"at most {supported_limit} {modality} items.")

        return mm_limits

    def get_dummy_data(self, seq_len: int) -> DummyData:
        """
        Build dummy data from the inputs of :meth:`_get_dummy_mm_inputs`.

        The dummy inputs are run through :meth:`apply`, and the total number
        of placeholder tokens per modality is checked against
        ``max tokens per item * item count``. The resulting token IDs are
        padded with ``0`` up to ``seq_len``; if they are already longer than
        ``seq_len``, a warning is logged and no padding is added.

        Raises:
            AssertionError: If :meth:`get_supported_mm_limits` and
                :meth:`get_mm_max_tokens_per_item` disagree on the set of
                modalities, or the processed dummy data does not contain
                the expected number of placeholder tokens.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        mm_counts = self._get_and_validate_dummy_mm_counts()
        mm_max_tokens_per_item = self.get_mm_max_tokens_per_item()
        if mm_counts.keys() != mm_max_tokens_per_item.keys():
            raise AssertionError(
                "The keys returned by `get_supported_mm_limits`"
                f"({set(mm_counts.keys())}) should be the same as those "
                "returned by `get_mm_max_tokens_per_item` "
                f"({set(mm_max_tokens_per_item.keys())})")

        dummy_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(
            prompt_text=dummy_inputs.prompt_text,
            mm_data=dummy_inputs.mm_data,
            hf_processor_mm_kwargs=dummy_inputs.hf_processor_mm_kwargs,
        )

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        # Actual vs. worst-case placeholder token totals, per modality.
        actual_totals = dict[str, int]()
        expected_totals = dict[str, int]()
        for modality, ranges in placeholders_by_modality.items():
            actual_totals[modality] = sum(rng["length"] for rng in ranges)
            expected_totals[modality] = (mm_max_tokens_per_item[modality] *
                                         mm_counts[modality])

        if actual_totals != expected_totals:
            raise AssertionError(
                f"The processed dummy data has a total of "
                f"{actual_totals} placeholder tokens, which "
                f"is not the expected {expected_totals} "
                "tokens.")

        total_len = len(prompt_token_ids)
        if total_len > seq_len:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, actual_totals)

        # A non-positive pad length yields an empty list, i.e. no padding.
        pad_len = seq_len - total_len
        prompt_token_ids.extend([0] * pad_len)

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )