processing.py 27.5 KB
Newer Older
1
2
import re
from abc import ABC, abstractmethod
3
from collections import UserDict
4
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
8

9
import numpy as np
10
import torch
11
from PIL.Image import Image
12
from transformers import BatchFeature, ProcessorMixin
13
from typing_extensions import assert_never
14

15
from vllm.inputs import DummyData, InputProcessingContext
16
from vllm.logger import init_logger
17
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
18
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
19

20
from .audio import resample_audio
21
22
23
24
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

25
# Module-level logger, named after this module.
logger = init_logger(__name__)
26
27

# Constrained type variable: a prompt is either text (``str``) or a list of
# token IDs (``list[int]``); functions using it return the same kind they get.
_S = TypeVar("_S", str, list[int])
28
# Alias for "text or token IDs" where the two kinds need not be tied together.
_PromptSeq = Union[str, list[int]]
29

30
31

@dataclass
class PromptReplacement:
    """
    Describes one find-and-replace rule to apply to a prompt for a modality.

    The rule is tokenizer-agnostic; call :meth:`bind` to attach a tokenizer.
    """

    modality: str
    """Name of the modality that this replacement belongs to."""

    target: _PromptSeq
    """What to look for in the prompt: either text or a token sequence."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    What to substitute for :attr:`target`.

    This is either a function that receives the index of the processed item
    within :attr:`modality` and returns the replacement text or token
    sequence, or (as a shortcut when the result does not depend on the item)
    the replacement text or token sequence itself.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
        """Return this rule paired with ``tokenizer``."""
        bound_fields = {
            "tokenizer": tokenizer,
            "modality": self.modality,
            "_target": self.target,
            "_replacement": self.replacement,
        }
        return _BoundPromptReplacement(**bound_fields)
56
57


58
59
60
61
62
63
64
65
66
67
68
69
70
71
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Tokenize ``text`` in the same way as HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`, regardless of
    the tokenizer backend.

    For a :class:`MistralTokenizer`, the wrapped tokenizer is called with
    both ``bos`` and ``eos`` set to ``add_special_tokens``.
    """
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    mistral_backend = tokenizer.tokenizer
    return mistral_backend.encode(
        text,
        bos=add_special_tokens,
        eos=add_special_tokens,
    )
74
75


76
77
78
79
80
81
82
83
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized :func:`_encode` (LRU cache holding up to 2048 entries).

    All arguments take part in the cache key, so they must be hashable.
    The returned list is the cached object itself.
    """
    token_ids = _encode(tokenizer, text, add_special_tokens=add_special_tokens)
    return token_ids
84
85


86
87
88
89
90
91
92
93
94
95
96
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Detokenize ``token_ids`` in the same way as HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`,
    regardless of the tokenizer backend.
    """
    text = tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    return text
97
98


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized :func:`_decode` (LRU cache holding up to 2048 entries).

    The token IDs are taken as a tuple so that they can be part of the cache
    key; they are converted to a list before decoding.
    """
    id_list = list(token_ids)
    return _decode(tokenizer, id_list, skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    """Structural type for objects exposing ``modality`` as an attribute."""
    modality: str

114

115
class _HasModalityProp(Protocol):
    """Structural type for objects exposing ``modality`` as a property."""

    @property
    def modality(self) -> str:
        """Name of the modality this object belongs to."""


# Any object that exposes ``modality``, either as an attribute or a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group ``values`` by their ``modality`` using :func:`full_groupby`."""

    def modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence tied to a tokenizer, available both as text and as
    token IDs.

    Only one representation has to be supplied; the other one is derived
    through the tokenizer the first time it is requested and then stored.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        has_text = self._text is not None
        has_token_ids = self._token_ids is not None

        if not (has_text or has_token_ids):
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The sequence as text, decoded from the token IDs if necessary."""
        text = self._text
        if text is not None:
            return text

        assert self._token_ids is not None
        text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        self._text = text
        return text

    @property
    def token_ids(self) -> list[int]:
        """The sequence as token IDs, encoded from the text if necessary."""
        token_ids = self._token_ids
        if token_ids is not None:
            return token_ids

        assert self._text is not None
        token_ids = _cached_encode(self.tokenizer, self._text)
        self._token_ids = token_ids
        return token_ids


@dataclass
class _BoundPromptReplacement:
    """
    A replacement rule (modality, target, replacement) paired with the
    tokenizer used to convert its sequences between text and token IDs.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Results of a callable replacement, keyed by item index
        self._replacement_cache: dict[int, _BoundPromptSequence] = {}

    def _bind_seq(self, seq: _PromptSeq) -> _BoundPromptSequence:
        """Wrap a raw text / token ID sequence with this rule's tokenizer."""
        is_text = isinstance(seq, str)
        is_token_ids = isinstance(seq, list)

        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=seq if is_text else None,
            _token_ids=seq if is_token_ids else None,
        )

    @property
    def target(self) -> _BoundPromptSequence:
        """The sequence to search for, bound to the tokenizer."""
        return self._bind_seq(self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Return the replacement sequence for the item at ``item_idx``.

        A callable replacement is invoked with ``item_idx`` and its bound
        result is remembered per index; a constant replacement is bound anew
        on every call.
        """
        raw = self._replacement
        if not callable(raw):
            return self._bind_seq(raw)

        cache = self._replacement_cache
        try:
            return cache[item_idx]
        except KeyError:
            pass

        bound = self._bind_seq(raw(item_idx))
        if item_idx is not None:
            cache[item_idx] = bound

        return bound


class ImageSize(NamedTuple):
    """Image dimensions as a ``(width, height)`` pair."""
    width: int
    height: int


class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    Same contents as :class:`MultiModalDataDict`, except that the value of
    every modality has been normalized to a collection of items.
    """

    @staticmethod
    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
        """
        Build a :class:`MultiModalDataItems` out of a
        :class:`MultiModalDataDict`, wrapping single items in a list.
        """
        normalized = MultiModalDataItems()

        for modality, value in data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if modality == "video":
                # Special case since even a single item can be a list
                is_collection = (isinstance(value, torch.Tensor)
                                 or is_list_of(value, list))
            elif modality in ("image", "audio"):
                is_collection = isinstance(value, (torch.Tensor, list))
            else:
                is_collection = isinstance(value, list)

            normalized[modality] = (  # type: ignore[index]
                value if is_collection else [value])

        return normalized

    # NOTE: If a modality (e.g. `image`) is absent, the getters below hand out
    # a fresh empty list that is not stored in this dictionary, so appending
    # to it has no effect. The getters are annotated as `Sequence` to
    # discourage callers from mutating the returned value.
    @property
    def images(self) -> Sequence[ImageItem]:
        """Items stored under ``"image"`` (empty if there are none)."""
        return self.get("image", [])

    @property
    def videos(self) -> Sequence[VideoItem]:
        """Items stored under ``"video"`` (empty if there are none)."""
        return self.get("video", [])

    @property
    def audios(self) -> Sequence[AudioItem]:
        """Items stored under ``"audio"`` (empty if there are none)."""
        return self.get("audio", [])

    def get_item_counts(self) -> Mapping[str, int]:
        """Return the number of items held for each modality."""
        counts = dict[str, int]()
        for modality, items in self.items():
            counts[modality] = len(items)

        return counts

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the image at ``item_idx``.

        For array / tensor images, the last two of the three dimensions are
        read as height and width respectively.
        """
        image = self.images[item_idx]

        if isinstance(image, Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            _leading, height, width = image.shape
            return ImageSize(width, height)

        assert_never(image)

    def get_audio_with_sr(
        self,
        item_idx: int,
        *,
        default_sr: float,
    ) -> tuple[np.ndarray, float]:
        """
        Return the audio at ``item_idx`` together with its sampling rate.

        A tuple item is returned unchanged; a list or NumPy array item is
        paired with ``default_sr`` (lists are converted to arrays first).
        """
        item = self.audios[item_idx]

        if isinstance(item, tuple):
            return item

        if isinstance(item, list):
            waveform = np.array(item)
        elif isinstance(item, np.ndarray):
            waveform = item
        else:
            assert_never(item)

        return waveform, default_sr

    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
        """
        Resample every audio item to ``new_sr`` and store the results back
        under ``"audio"``.

        With :code:`drop_sr=True`, the stored items are the bare NumPy
        arrays, which implicitly means that their sampling rate is the one
        expected by the model; otherwise each item is stored as a
        :code:`(audio, new_sr)` tuple.
        """
        if not self.audios:
            return

        resampled = []
        for item_idx, _ in enumerate(self.audios):
            waveform, orig_sr = self.get_audio_with_sr(item_idx,
                                                       default_sr=new_sr)
            waveform = resample_audio(waveform,
                                      orig_sr=orig_sr,
                                      target_sr=new_sr)

            if drop_sr:
                resampled.append(waveform)
            else:
                resampled.append((waveform, new_sr))

        self["audio"] = resampled
307
308


309
310
311
class _TokenMatch(NamedTuple):
    """Half-open index range ``[start_idx, end_idx)`` of a token match."""
    start_idx: int
    end_idx: int
312
313


314
315
316
317
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Scan :code:`token_ids` from left to right and yield every
    non-overlapping occurrence of :code:`match_ids`.

    Note that an empty :code:`match_ids` yields nothing.
    """
    match_len = len(match_ids)
    if match_len == 0:
        return

    # Last position at which a full-length match can still start
    last_start = len(token_ids) - match_len

    pos = 0
    while pos <= last_start:
        stop = pos + match_len

        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)

        # Resume after this match so that matches never overlap
        pos = stop
340
341


342
343
344
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    Base class for an occurrence of a replacement target inside a prompt.

    Subclasses define where the occurrence starts and ends.
    """

    prompt_repl: _BoundPromptReplacement

    @property
    def modality(self) -> str:
        """Modality of the replacement rule that produced this match."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt at which the match begins."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt at which the match ends."""
        raise NotImplementedError

    def __repr__(self) -> str:
        shown = (
            ("modality", self.modality),
            ("start_idx", self.start_idx),
            ("end_idx", self.end_idx),
        )
        args = ", ".join(f"{name}={value!r}" for name, value in shown)

        return f"{type(self).__name__}({args})"


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match found in a token ID prompt; indices count tokens."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        token_match = self.match
        return token_match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match found in a text prompt; indices count characters."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        first, _ = self.match.span()
        return first

    @property
    def end_idx(self) -> int:
        _, last = self.match.span()
        return last

390
391
392
393

class _PlaceholderInfo(NamedTuple):
    """A run of placeholder tokens located in a token ID prompt."""

    modality: str
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of tokens in :attr:`replacement`."""
        num_tokens = len(self.replacement)
        return num_tokens

    def to_range(self) -> PlaceholderRange:
        """Express this placeholder as an offset/length range."""
        span = {"offset": self.start_idx, "length": self.length}
        return PlaceholderRange(**span)
405
406
407
408


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Collect every occurrence in :code:`prompt` of the target token IDs of
    each rule in :code:`prompt_repls`, grouped in rule order.
    """
    found = list[_PromptReplacementTokenMatch]()

    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        found.extend(
            _PromptReplacementTokenMatch(prompt_repl, match)
            for match in iter_token_matches(prompt, target_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Collect every occurrence in :code:`prompt` of the target text of each
    rule in :code:`prompt_repls`, grouped in rule order.

    The target text is matched literally (it is escaped before being used
    as a regular expression).
    """
    found = list[_PromptReplacementTextMatch]()

    for prompt_repl in prompt_repls:
        literal_pattern = re.escape(prompt_repl.target.text)
        found.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(literal_pattern, prompt))

    return found


def _resolve_matches(
    prompt: _PromptSeq,
    matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
    """
    Verify that no two of the given :code:`matches` cover the same position
    of :code:`prompt`, then return them ordered by start position.

    Raises:
        ValueError: If two matches overlap.
    """
    # For each prompt position, the match that has claimed it (if any)
    claimed_by: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            earlier = claimed_by[idx]
            if earlier is not None:
                raise ValueError(f"Found overlapping matches "
                                 f"({earlier} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            claimed_by[idx] = match

    def start_of(match: _PromptReplacementMatch) -> int:
        return match.start_idx

    return sorted(matches, key=start_of)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Split :code:`prompt` into pieces in which the matched targets have been
    swapped for their replacements.

    The matches are first validated and ordered by :func:`_resolve_matches`.
    The n-th applied match of a modality receives the replacement for item
    index n of that modality. Once a modality has used up its item count
    (a modality missing from :code:`mm_item_counts` counts as zero items),
    its remaining matches are skipped, so their original content stays in
    the output.

    Returns:
        The pieces of the new prompt, in order; concatenating them yields
        the full prompt. The pieces have the same type as :code:`prompt`.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    # Number of matches already replaced for each modality
    next_idx_by_modality = dict[str, int]()

    for match in _resolve_matches(prompt, matches):
        modality = match.modality

        # Use `.get` so that a match whose modality has no input items
        # is skipped instead of raising `KeyError`
        item_idx = next_idx_by_modality.get(modality, 0)
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx

        replacement = match.prompt_repl.get_replacement(item_idx)

        # Emit the untouched part before the match, then the replacement
        if isinstance(prompt, str):
            out_seqs.append(prompt[prev_end_idx:start_idx] + replacement.text)
        else:
            out_seqs.append(prompt[prev_end_idx:start_idx] +
                            replacement.token_ids)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] = item_idx + 1

    # Whatever follows the last applied match is kept as-is
    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Return the token IDs of :code:`prompt` after applying :code:`matches`.

    Without any matches, :code:`prompt` itself is returned.
    """
    if matches:
        pieces = _replace_matches(prompt, matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt
503
504


505
506
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Return the text of :code:`prompt` after applying :code:`matches`.

    Without any matches, :code:`prompt` itself is returned.
    """
    if matches:
        pieces = _replace_matches(prompt, matches, mm_item_counts)
        return "".join(pieces)

    return prompt
517
518


519
520
521
522
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Scan :code:`prompt` from left to right and yield the placeholder tokens
    of up to :code:`modal_item_count` items of a single modality.

    At each position, the replacement tokens that each rule produces for
    the next unmatched item are compared against the prompt; the first rule
    that fits wins. Empty replacements are never matched.
    """
    if modal_item_count == 0:
        return

    total = len(prompt)
    item_index = 0

    cursor = 0
    while cursor < total:
        for repl_info in modality_repls:
            repl_tokens = repl_info.get_replacement(item_index).token_ids
            width = len(repl_tokens)
            stop = cursor + width

            if width == 0 or stop > total:
                continue
            if prompt[cursor:stop] != repl_tokens:
                continue

            yield _PlaceholderInfo(
                modality=modality,
                start_idx=cursor,
                replacement=repl_tokens,
            )

            item_index += 1
            if item_index >= modal_item_count:
                return

            # Resume after this placeholder so that matches never overlap
            cursor = stop
            break
        else:
            # No rule matched at this position
            cursor += 1
562
563
564


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
    """
    Yield the placeholder tokens present in :code:`prompt`, one modality of
    :code:`mm_item_counts` after another.

    Modalities without any rule in :code:`prompt_repls` are skipped.
    Note that empty matches are ignored.
    """
    repls_by_modality = dict(full_groupby_modality(prompt_repls))

    for modality, modal_item_count in mm_item_counts.items():
        modality_repls = repls_by_modality.get(modality)
        if modality_repls is None:
            continue

        yield from _iter_modality_placeholders(
            prompt,
            modality,
            modality_repls,
            modal_item_count,
        )

585

586
587
588
589
590
class ProcessorInputs(NamedTuple):
    """
    Arguments to :meth:`BaseMultiModalProcessor.apply`, in positional order.
    """
    prompt_text: str
    mm_data: MultiModalDataDict
    mm_processor_kwargs: Mapping[str, object]
591
592


593
class BaseMultiModalProcessor(ABC):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Subclasses must implement :meth:`_get_prompt_replacements` and
    :meth:`_get_dummy_mm_inputs`.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def _get_hf_processor(self) -> ProcessorMixin:
        """
        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    def _get_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """Normalize the multi-modal data so each modality maps to a list."""
        return MultiModalDataItems.from_dict(mm_data)

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Note:
            Even when the HF processor already performs replacement for us,
            we still use this replacement information to determine
            the placeholder token positions for each multi-modal item.
        """
        raise NotImplementedError

    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> list[_PlaceholderInfo]:
        """Locate the placeholder tokens present in ``new_token_ids``."""
        return list(
            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))

    def _get_processor_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Split the multi-modal items into the data to feed to the HF
        processor and the data that should bypass it.

        Returns:
            A ``(processor_data, passthrough_data)`` pair. For the image,
            video and audio modalities, a 3-D tensor or a non-empty list of
            tensors whose first element is 2-D goes to ``passthrough_data``
            under ``"<modality>_embeds"``; anything else goes to
            ``processor_data`` under the plural key. Other modalities go to
            ``processor_data`` under their own key.
        """
        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()

        for k, v in mm_items.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi)
                    passthrough_data[f"{k}_embeds"] = v
                else:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        prompt: str,
        processor_data: Mapping[str, object],
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Invoke ``hf_processor`` through the processing context."""
        return self.ctx.call_hf_processor(
            hf_processor,
            prompt,
            processor_data,
            mm_processor_kwargs,
        )

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_items: MultiModalDataItems,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on the prompt and the multi-modal items, then
        merge the passthrough data into its outputs.
        """
        # some mm_processor_kwargs may be used in processor initialization
        # instead of processor call
        hf_processor = self._get_hf_processor(**mm_processor_kwargs)

        processor_data, passthrough_data = self._get_processor_data(mm_items)

        hf_inputs = self._call_hf_processor(
            hf_processor,
            prompt=prompt,
            processor_data=processor_data,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        hf_inputs.update(passthrough_data)

        return hf_inputs

    def _bind_prompt_replacements(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> list[_BoundPromptReplacement]:
        """Bind each replacement rule to this processor's tokenizer."""
        tokenizer = self._get_tokenizer()

        return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Apply the replacement rules to ``token_ids``.

        Replacement is done directly on the token IDs when every modality
        has at least as many token-level matches as it has items; otherwise
        it is done on the decoded text, which is then encoded again.

        Returns:
            The new token IDs, the corresponding text, and the placeholders
            found in the new token IDs.
        """
        tokenizer = self._get_tokenizer()

        token_matches = find_token_matches(token_ids, prompt_repls)
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        #
        # NOTE: The check must go over every modality that has items.
        # Going over only the modalities that produced token matches would
        # make it vacuously true when nothing matched, so the fallback to
        # text replacement would never be taken.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_item_counts,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_item_counts,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids,
                                               mm_item_counts)

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        mm_items = self._get_mm_items(mm_data)

        hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
                                             mm_processor_kwargs)
        # The HF processor output is expected to hold exactly one sequence
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
                                                     mm_processor_kwargs)
        all_prompt_repls = self._bind_prompt_replacements(prompt_repls)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        mm_item_counts = mm_items.get_item_counts()
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids, mm_item_counts)

        if all_placeholders:
            tokenizer = self._get_tokenizer()
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                all_prompt_repls,
                mm_item_counts,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

    @abstractmethod
    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the multi-modal portion of the input which, after processing,
        results in `mm_max_tokens` in :meth:`get_dummy_data`.
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
        """
        Process the dummy inputs of :meth:`_get_dummy_mm_inputs` and wrap
        the result into :class:`DummyData`.

        A warning is logged when the number of placeholder tokens of a
        modality differs from ``mm_max_tokens``, and when the processed
        prompt is longer than ``seq_len``. A shorter prompt is padded with
        token ID 0 up to ``seq_len``.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        processor_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(*processor_inputs)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        total_placeholders_by_modality = dict[str, int]()
        for modality, placeholders in placeholders_by_modality.items():
            num_placeholders = sum(item["length"] for item in placeholders)
            max_tokens = mm_max_tokens[modality]

            if num_placeholders != max_tokens:
                logger.warning(
                    "The processed dummy data has a total of %d placeholder "
                    "tokens for the '%s' modality, which is not the expected "
                    "%d tokens.", num_placeholders, modality, max_tokens)

            total_placeholders_by_modality[modality] = num_placeholders

        total_len = len(prompt_token_ids)
        if total_len > seq_len:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, total_placeholders_by_modality)

        # Pad in place up to `seq_len`; a non-positive difference adds nothing
        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )