"vscode:/vscode.git/clone" did not exist on "5e714f7ff4166d62b5f923766a0268a1758f2a61"
processing.py 33.8 KB
Newer Older
1
import pickle
2
3
4
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
8

9
import numpy as np
10
import torch
11
from blake3 import blake3
12
from PIL.Image import Image
13
from transformers import BatchFeature, ProcessorMixin
14

15
from vllm.inputs import DummyData, InputProcessingContext
16
from vllm.logger import init_logger
17
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
18
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of
19

20
21
22
from .inputs import (MultiModalDataDict, MultiModalDataItems,
                     MultiModalFieldConfig, MultiModalFieldItem,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange)
23

24
logger = init_logger(__name__)
25
26

_S = TypeVar("_S", str, list[int])
27
_PromptSeq = Union[str, list[int]]
28

29
30

@dataclass
class PromptReplacement:
    """
    Describes one find-and-replace rule to apply to a prompt for a
    single modality.
    """

    modality: str
    """The modality for which the replacement is made."""

    target: _PromptSeq
    """The text or token sequence to find and replace."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`, output the
    replacement text or token sequence.

    For convenience, you can pass in the replacement instead of a function
    if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
        """
        Return a :class:`_BoundPromptReplacement` carrying the same modality,
        target and replacement together with ``tokenizer``.
        """
        return _BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    Args:
        tokenizer: The tokenizer used to encode ``text``.
        text: The text to encode.
        add_special_tokens: Forwarded to the HF tokenizer as-is; for a
            :class:`MistralTokenizer` it is passed as both ``bos`` and
            ``eos``.

    Returns:
        The token IDs produced by the tokenizer.
    """
    if isinstance(tokenizer, MistralTokenizer):
        # The wrapped Mistral tokenizer takes `bos`/`eos` flags instead of
        # `add_special_tokens`.
        return tokenizer.tokenizer.encode(text,
                                          bos=add_special_tokens,
                                          eos=add_special_tokens)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
73
74


75
76
77
78
79
80
81
82
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`_encode`.

    The returned list object is the one held by the cache, so callers
    should treat it as read-only.
    """
    token_ids = _encode(tokenizer,
                        text,
                        add_special_tokens=add_special_tokens)
    return token_ids
83
84


85
86
87
88
89
90
91
92
93
94
95
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Turn ``token_ids`` back into text; a backend-agnostic counterpart of
    HF's :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
    """
    decoded_text = tokenizer.decode(
        token_ids,
        skip_special_tokens=skip_special_tokens,
    )
    return decoded_text
96
97


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`_decode`.

    ``token_ids`` is a tuple so that it can serve as part of the cache key;
    it is converted to a list before being handed to :func:`_decode`.
    """
    ids_as_list = list(token_ids)
    return _decode(tokenizer,
                   ids_as_list,
                   skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    modality: str

113

114
class _HasModalityProp(Protocol):
115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    @property
    def modality(self) -> str:
        ...


_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group ``values`` by their ``modality`` using :func:`full_groupby`.

    Returns:
        The ``(modality, values)`` pairs produced by :func:`full_groupby`.
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence tied to a tokenizer, readable both as text and as
    token IDs.

    Only one of the two representations has to be supplied; the other is
    computed on first access via :func:`_cached_encode` /
    :func:`_cached_decode` and then stored on the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The sequence as text, decoded from the token IDs if needed."""
        if self._text is None:
            assert self._token_ids is not None
            # The cached decoder needs a hashable argument, hence the tuple.
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The sequence as token IDs, encoded from the text if needed."""
        if self._token_ids is None:
            assert self._text is not None
            self._token_ids = _cached_encode(self.tokenizer, self._text)

        return self._token_ids


@dataclass
class _BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer, so that its target
    and replacement can be read as :class:`_BoundPromptSequence` objects.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Per-item results of a callable replacement, keyed by item index.
        self._replacement_cache = dict[int, _BoundPromptSequence]()

    @property
    def target(self) -> _BoundPromptSequence:
        """The sequence to search for, bound to :attr:`tokenizer`."""
        target = self._target

        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=target if isinstance(target, str) else None,
            _token_ids=target if isinstance(target, list) else None,
        )

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Return the replacement sequence for the item at ``item_idx``.

        A callable replacement is invoked with ``item_idx`` and its bound
        result is cached per index; a constant replacement is wrapped anew
        on every call.
        """
        replacement = self._replacement
        if callable(replacement):
            cache_key = item_idx
            if cache_key in self._replacement_cache:
                return self._replacement_cache[cache_key]

            replacement = replacement(item_idx)
        else:
            cache_key = None

        bound_replacement = _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=replacement if isinstance(replacement, str) else None,
            _token_ids=replacement if isinstance(replacement, list) else None,
        )

        if cache_key is not None:
            self._replacement_cache[cache_key] = bound_replacement

        return bound_replacement


203
204
205
class _TokenMatch(NamedTuple):
    """Half-open span ``[start_idx, end_idx)`` of a match in a token list."""
    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are reported left to right and never overlap: after a match,
    scanning resumes at the end of that match.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    while start_idx < prompt_len - match_len + 1:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
234
235


236
237
238
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    Base class for one occurrence of a replacement target in a prompt.

    Subclasses supply :attr:`start_idx` and :attr:`end_idx`.
    """

    prompt_repl: _BoundPromptReplacement

    @property
    def modality(self) -> str:
        """The modality of the replacement that produced this match."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt at which the match starts."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt at which the match ends."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match found in a token ID sequence; indices come from ``match``."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match found in prompt text; indices come from the regex match."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

284
285
286
287

class _PlaceholderInfo(NamedTuple):
    """Where a run of placeholder tokens for one item sits in the prompt."""

    modality: str
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of tokens in :attr:`replacement`."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Express this placeholder as an offset/length range."""
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )
299
300
301
302


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Results are grouped by replacement, in the order of ``prompt_repls``;
    within one replacement they follow :func:`iter_token_matches`.
    """
    return [
        _PromptReplacementTokenMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
    ]


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Targets are matched literally (the text is regex-escaped). Results are
    grouped by replacement, in the order of ``prompt_repls``.

    Note that empty matches are ignored, consistent with
    :func:`iter_token_matches`.
    """
    text_matches = list[_PromptReplacementTextMatch]()

    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            # An empty pattern matches at every position of the prompt,
            # which would insert the replacement where nothing was found.
            continue

        pattern = re.escape(target_text)
        text_matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return text_matches


def _resolve_matches(
    prompt: _PromptSeq,
    matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same index of ``prompt``.
    """
    # seen_matches[i] is the match that already claimed position i, if any.
    seen_matches: list[Optional[_PromptReplacementMatch]] = (
        [None] * len(prompt))

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Split ``prompt`` around ``matches`` and substitute the replacements.

    Matches are processed in prompt order (see :func:`_resolve_matches`).
    For each modality, at most ``mm_item_counts[modality]`` matches are
    replaced; further matches, and matches of a modality that is absent
    from ``mm_item_counts``, are left untouched in the output.

    Returns:
        The pieces of the new prompt, in order; concatenating them gives
        the full result.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    # Index of the next item to consume for each modality seen so far.
    next_idx_by_modality = dict[str, int]()

    for match in _resolve_matches(prompt, matches):
        modality = match.modality

        # Use `.get` so that a match for a modality without any input items
        # is skipped instead of raising `KeyError`.
        item_idx = next_idx_by_modality.get(modality, 0)
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx

        repl_info = match.prompt_repl
        replacement = repl_info.get_replacement(item_idx)

        if isinstance(prompt, str):
            repl_seq = replacement.text
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
        else:
            repl_seq = replacement.token_ids
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] = item_idx + 1

    # Whatever follows the last replaced match (including skipped matches)
    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply :code:`prompt_repls` to :code:`prompt`.

    When ``matches`` is empty, ``prompt`` itself is returned (not a copy).
    """
    if not matches:
        return prompt

    token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)
397
398


399
400
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply :code:`prompt_repls` to :code:`prompt`.

    When ``matches`` is empty, ``prompt`` is returned unchanged.
    """
    if not matches:
        return prompt

    texts = _replace_matches(prompt, matches, mm_item_counts)

    return "".join(texts)
411
412


413
414
415
416
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield the placeholders of one modality found in ``prompt``.

    The prompt is scanned left to right. At each position, the replacement
    tokens that each entry of ``modality_repls`` gives for the next
    unmatched item are compared against the prompt; the first one that
    matches is reported. Scanning stops once ``modal_item_count``
    placeholders have been found.
    """
    if modal_item_count == 0:
        return

    prompt_len = len(prompt)
    item_index = 0

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for repl_info in modality_repls:
            replacement = repl_info.get_replacement(item_index)
            repl_tokens = replacement.token_ids
            repl_len = len(repl_tokens)
            end_idx = start_idx + repl_len

            # Skip empty replacements and ones that cannot fit in the
            # remainder of the prompt.
            if repl_len == 0 or end_idx > prompt_len:
                continue

            if prompt[start_idx:end_idx] == repl_tokens:
                yield _PlaceholderInfo(
                    modality=modality,
                    start_idx=start_idx,
                    replacement=repl_tokens,
                )

                item_index += 1
                if item_index >= modal_item_count:
                    return

                # Exclude overlapping matches
                start_idx = end_idx
                found = True
                break

        if not found:
            start_idx += 1
456
457
458


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Modalities are visited in the iteration order of ``mm_item_counts``;
    a modality with no entry in ``prompt_repls`` yields nothing.

    Note that empty matches are ignored.
    """
    repls_by_modality = dict(full_groupby_modality(prompt_repls))

    for modality, modal_item_count in mm_item_counts.items():
        if modality in repls_by_modality:
            yield from _iter_modality_placeholders(
                prompt,
                modality,
                repls_by_modality[modality],
                modal_item_count,
            )

479

480
481
482
@dataclass
class ProcessorInputs:
    """Keyword arguments to :meth:`BaseMultiModalProcessor`."""

    # The prompt text to process.
    prompt_text: str
    # The multi-modal data accompanying the prompt.
    mm_data: MultiModalDataDict
    # Extra keyword arguments for the HF processor; empty by default.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)


class ProcessingCache:
    """
    Cache of processed multi-modal items, keyed by a BLAKE3 digest of
    everything the processed output depends on (see :meth:`get`).
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, Mapping[str,
                                            MultiModalFieldItem]](capacity)

    def _maybe_log_cache_stats(self) -> None:
        """Log the hit ratio every ``debug_cache_hit_ratio_steps`` lookups."""
        steps = self.debug_cache_hit_ratio_steps
        if not steps:
            return

        cache_stats = self._cache.stat()
        if cache_stats.total % steps == 0:
            logger.debug("ProcessingCache: hit_ratio = %.2f",
                         cache_stats.hit_ratio)

    def _serialize_item(self, obj: object) -> bytes:
        """
        Convert a leaf value into bytes suitable for hashing.

        Arrays and images are prefixed with a small header describing their
        layout, because the raw buffer alone is ambiguous: a 2x3 and a 3x2
        array (or two images with the same pixels but different sizes/modes)
        would otherwise serialize identically and share a cache entry.
        """
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj

        # Convertible to NumPy arrays
        if isinstance(obj, torch.Tensor):
            # `Tensor.numpy()` raises for tensors that require grad or that
            # do not live in host memory, so detach and copy to CPU first.
            obj = obj.detach().cpu().numpy()
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            header = f"{obj.dtype.str}{obj.shape}".encode("utf-8")
            return header + obj.tobytes()

        if isinstance(obj, Image):
            header = f"{obj.mode}{obj.size}".encode("utf-8")
            return header + obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    def _item_to_bytes(
        self,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """
        Flatten ``obj`` into ``(key_bytes, value_bytes)`` pairs.

        Lists, tuples and dicts are walked recursively; each leaf gets a
        dotted key built from ``key`` and the index/dict key on its path.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from self._item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from self._item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = self._serialize_item(key)
            value_bytes = self._serialize_item(obj)
            yield key_bytes, value_bytes

    def _hash_kwargs(self, **kwargs: object) -> str:
        """Return the hex digest identifying the given keyword arguments."""
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in self._item_to_bytes(k, v):
                # Length-prefix every chunk so that the boundary between a
                # key and its value (and between entries) is unambiguous.
                for chunk in (k_bytes, v_bytes):
                    hasher.update(len(chunk).to_bytes(8, "little"))
                    hasher.update(chunk)

        return hasher.hexdigest()

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[Mapping[str, "MultiModalFieldItem"]]:
        """
        Get a processed multi-modal item from the cache
        according to its dependencies, including:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        cache_key = self._hash_kwargs(model_id=model_id,
                                      **{modality: input_item},
                                      **input_kwargs)
        return self._cache.get(cache_key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: Mapping[str, "MultiModalFieldItem"],
    ) -> None:
        """
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._hash_kwargs(model_id=model_id,
                                      **{modality: input_item},
                                      **input_kwargs)
        self._cache.put(cache_key, output_kwargs)
598
599


600
class BaseMultiModalProcessor(ABC):
601
    """
602
    Abstract base class to process multi-modal inputs to be used in vLLM.
603
604
    """

605
606
607
608
609
    def __init__(self,
                 ctx: InputProcessingContext,
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """
        Args:
            ctx: The input processing context; used to obtain the HF
                processor, the tokenizer and the model config.
            cache: Optional cache of processed multi-modal items. When
                ``None``, processed items are not cached.
            enable_sanity_checks: Whether to run the extra consistency
                checks performed while building the multi-modal kwargs.
        """
        super().__init__()

        self.ctx = ctx
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks
615

616
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply` with the same arguments."""
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
623
624
625
626
627
628

    def _get_hf_processor(self) -> ProcessorMixin:
        """
        Return the HF processor obtained from the processing context.

        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        tokenizer = self.ctx.tokenizer
        return tokenizer

634
635
636
637
638
639
    def _get_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """Wrap the raw ``mm_data`` dict in a :class:`MultiModalDataItems`."""
        mm_items = MultiModalDataItems.from_dict(mm_data)
        return mm_items

640
641
642
643
644
645
646
647
648
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe each field of the HF-processed data.

        Must be implemented by subclasses: given ``hf_inputs`` (the output
        of the HF processor), return the metadata of each field.
        """
        raise NotImplementedError

649
650
    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Note:
            Even when the HF processor already performs replacement for us,
            we still use this replacement information to determine
            the placeholder token positions for each multi-modal item.
        """
        raise NotImplementedError
666

667
668
    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> list[_PlaceholderInfo]:
        """
        Collect the placeholders that :func:`iter_placeholders` finds in
        ``new_token_ids`` into a list.
        """
        return list(
            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
675

676
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Split ``mm_items`` into data for the HF processor and data that
        bypasses it.

        Returns:
            A ``(processor_data, passthrough_data)`` pair. Image, video and
            audio items go into ``processor_data`` under the plural key
            (e.g. ``images``), unless they are recognized as embedding
            inputs, in which case they go into ``passthrough_data`` under
            ``<modality>_embeds``. Items of any other modality are put into
            ``processor_data`` under their original key.
        """
        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()

        for k, v in mm_items.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi)
                    passthrough_data[f"{k}_embeds"] = v
                elif len(v) > 0:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

        return processor_data, passthrough_data

702
703
704
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Invoke the HF processor on ``prompt`` together with ``mm_data``.

        ``mm_kwargs`` is used both to obtain the processor (see
        :meth:`_get_hf_processor`) and as the keyword arguments of the call.
        """
        return self.ctx.call_hf_processor(
            self._get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            mm_kwargs,
        )

716
717
    def _apply_hf_processor(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text and multi-modal data.

        Returns:
            The prompt token IDs and the multi-modal kwargs built from the
            rest of the processor output (plus any passthrough data).
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        processed_data.update(passthrough_data)

        # `input_ids` is expected to hold exactly one sequence; the
        # single-element unpacking fails otherwise.
        prompt_ids, = processed_data.pop("input_ids").tolist()

        mm_kwargs = MultiModalKwargs.from_hf_inputs(
            processed_data,
            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
            enable_sanity_checks=self.enable_sanity_checks,
        )

        return prompt_ids, mm_kwargs

    def _apply_hf_processor_missing(
        self,
        prompt_text: str,
        mm_missing_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ):
        """
        Run the HF processor for the prompt text and, separately, for the
        multi-modal items that were not found in the cache.

        Note: The text and the multi-modal data go through the HF processor
        in two independent calls so that HF does not perform prompt
        replacement for cached items; the full text is instead handled by
        our own prompt replacement logic.

        Returns:
            The token IDs of ``prompt_text`` and the processed kwargs of the
            missing items.
        """
        missing_counts = mm_missing_data_items.get_item_counts()

        # Pass 1: tokenize the real prompt without any multi-modal data.
        text_only_result = self._apply_hf_processor(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )
        prompt_ids = text_only_result[0]

        # Pass 2: process the missing items against a dummy prompt, since
        # some HF processors (e.g. Qwen2-VL) expect corresponding
        # multi-modal tokens to be in the prompt text.
        dummy_inputs = self._get_dummy_mm_inputs(missing_counts)
        missing_result = self._apply_hf_processor(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )
        mm_missing_kwargs = missing_result[1]

        return prompt_ids, mm_missing_kwargs

    def _cached_apply_hf_processor(
        self,
        prompt_text: str,
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Each multi-modal item is looked up in the cache individually; only
        the items that are missing are sent through the HF processor, and
        their results are stored in the cache afterwards.

        Returns:
            The prompt token IDs and the multi-modal kwargs assembled from
            cached and freshly processed items, in the original item order.
        """
        cache = self.cache
        model_id = self.ctx.model_config.model

        # Without a cache, or when embedding inputs are present, fall back
        # to processing everything in one go.
        if cache is None or mm_data_items.has_embedding_inputs():
            return self._apply_hf_processor(
                prompt_text=prompt_text,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            )

        # Per modality: the cached result of each item, or None on a miss.
        mm_maybe_cached_field_items = {
            modality: [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]
            for modality, items in mm_data_items.items()
        }

        # Per modality: the positions (and then the data) of the misses.
        mm_missing_idxs = {
            modality: [idx for idx, out in enumerate(fields) if out is None]
            for modality, fields in mm_maybe_cached_field_items.items()
        }
        mm_missing_data = {
            modality: [mm_data_items[modality][idx] for idx in idxs]
            for modality, idxs in mm_missing_idxs.items()
        }
        mm_missing_data_items = self._get_mm_items(mm_missing_data)

        prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
            prompt_text=prompt_text,
            mm_missing_data_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        # Per modality: how many freshly processed items have been consumed
        # so far while merging below.
        mm_missing_next_idx = {
            modality: 0
            for modality in mm_missing_data_items
        }

        # Merge cached and freshly processed items back into the original
        # order, filling the cache with the fresh ones along the way.
        mm_merged_field_items = dict[str, list[Mapping[str,
                                                       MultiModalFieldItem]]]()
        for modality, modal_items_lst in mm_maybe_cached_field_items.items():
            merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]()

            for idx, modal_items in enumerate(modal_items_lst):
                if modal_items is None:
                    # Cache miss: take the next freshly processed item
                    modal_items = mm_missing_kwargs.get_items_by_modality(
                        modality,
                        mm_missing_next_idx[modality],
                    )

                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        modal_items,
                    )

                    mm_missing_next_idx[modality] += 1

                merged_modal_items_lst.append(modal_items)

            mm_merged_field_items[modality] = merged_modal_items_lst

        if self.enable_sanity_checks:
            # Every freshly processed item must have been consumed exactly
            # once by the merge above.
            mm_missing_counts = mm_missing_data_items.get_item_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in mm_missing_next_idx.items()), dict(
                    mm_missing_next_idx=mm_missing_next_idx,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items_by_modality(
            mm_merged_field_items,
            enable_sanity_checks=self.enable_sanity_checks,
        )

        if self.enable_sanity_checks:
            # Check that every original item can be retrieved from the
            # merged kwargs.
            mm_item_counts = mm_data_items.get_item_counts()

            for modality, item_count in mm_item_counts.items():
                for item_idx in range(item_count):
                    try:
                        mm_kwargs.get_items_by_modality(modality, item_idx)
                    except Exception as e:
                        # Make it easy to set a breakpoint in the debugger
                        raise e

        return prompt_ids, mm_kwargs
879

880
881
    def _bind_prompt_replacements(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> list[_BoundPromptReplacement]:
        """Bind every replacement in ``prompt_repls`` to our tokenizer."""
        tokenizer = self._get_tokenizer()

        return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]
887

888
889
890
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Replace the targets of ``prompt_repls`` in the prompt.

        Replacement is done on the token IDs when enough token-level matches
        exist for every modality; otherwise it is done on the decoded text,
        which is then re-encoded.

        Returns:
            The new token IDs, the corresponding text, and the placeholders
            located in the new token IDs.
        """
        tokenizer = self._get_tokenizer()

        token_matches = find_token_matches(token_ids, prompt_repls)
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_item_counts,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_item_counts,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids,
                                               mm_item_counts)

        return token_ids, text, placeholders
941

942
943
944
945
    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The processing consists of the following stages:

        1. Run the HF Processor on the prompt text together with the
           multi-modal data, which yields token IDs and processed tensors.
        2. Find sequences in the token IDs and replace them with placeholder
           tokens. The number of placeholder tokens equals the feature size
           of the multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Stage 2 is skipped when placeholders are already found in the
        output of stage 1.
        """
        mm_items = self._get_mm_items(mm_data)

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt_text, mm_items, hf_processor_mm_kwargs)

        prompt_repls = self._bind_prompt_replacements(
            self._get_prompt_replacements(mm_items, hf_processor_mm_kwargs,
                                          mm_kwargs))

        # The HF processor may have inserted the placeholder tokens already,
        # in which case we must not insert them a second time
        mm_item_counts = mm_items.get_item_counts()
        found_placeholders = self._find_placeholders(
            prompt_repls,
            prompt_ids,
            mm_item_counts,
        )

        if not found_placeholders:
            replaced = self._apply_prompt_replacements(
                prompt_ids,
                prompt_repls,
                mm_item_counts,
            )
            prompt_ids, prompt_text, found_placeholders = replaced
        else:
            prompt_text = _decode(self._get_tokenizer(), prompt_ids)

        placeholder_ranges = {
            modality: [info.to_range() for info in infos]
            for modality, infos in full_groupby_modality(found_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=placeholder_ranges,
        )
1008
1009

    @abstractmethod
    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Construct the dummy processor inputs for the given item counts.

        Subclasses must return inputs whose multi-modal portion, once
        processed, produces the `mm_max_tokens` placeholder tokens that
        :meth:`get_dummy_data` checks for.
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
        """
        Build dummy data by processing the dummy multi-modal inputs.

        The inputs from :meth:`_get_dummy_mm_inputs` are passed through
        :meth:`apply`. The total placeholder length of each modality must
        equal its entry in ``mm_max_tokens``. The prompt token IDs are then
        padded with ``0`` up to ``seq_len``; a warning is logged if they
        already exceed ``seq_len``.

        Raises:
            AssertionError: If the per-modality placeholder totals differ
                from the expected values in ``mm_max_tokens``.
        """
        # Imported lazily to avoid a circular import
        from vllm.sequence import SequenceData

        dummy_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(
            prompt_text=dummy_inputs.prompt_text,
            mm_data=dummy_inputs.mm_data,
            hf_processor_mm_kwargs=dummy_inputs.hf_processor_mm_kwargs,
        )

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        actual_totals = {
            modality: sum(rng["length"] for rng in ranges)
            for modality, ranges in placeholders_by_modality.items()
        }
        expected_totals = {
            modality: mm_max_tokens[modality]
            for modality in placeholders_by_modality
        }
        if actual_totals != expected_totals:
            raise AssertionError(
                f"The processed dummy data has a total of "
                f"{actual_totals} placeholder tokens, which "
                f"is not the expected {expected_totals} "
                "tokens.")

        total_len = len(prompt_token_ids)
        if total_len > seq_len:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, actual_totals)

        # A non-positive pad length yields an empty list, so an over-long
        # prompt is left as it is
        pad_len = seq_len - total_len
        prompt_token_ids.extend([0] * pad_len)

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )