# processing.py
1
2
import re
from abc import ABC, abstractmethod
3
from collections import UserDict
4
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
8

9
import numpy as np
10
import torch
11
from PIL.Image import Image
12
from transformers import BatchFeature, ProcessorMixin
13
from typing_extensions import assert_never
14

15
from vllm.inputs import DummyData, InputProcessingContext
16
from vllm.logger import init_logger
17
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
18
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
19
20
21
22
23

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

24
# Module-level logger for this file.
logger = init_logger(__name__)

# A prompt (or a piece of one) given either as text or as token IDs.
_PromptSeq = Union[str, list[int]]

# Constrained type variable: generic helpers over ``_S`` keep the prompt
# representation (text in -> text out, token IDs in -> token IDs out).
_S = TypeVar("_S", str, list[int])
28

29
30

@dataclass
class PromptReplacement:
    """
    Describes how to replace part of a prompt for one modality.

    Call :meth:`bind` to attach a tokenizer before using it for matching.
    """

    modality: str
    """Name of the modality this replacement applies to."""

    target: _PromptSeq
    """Text or token IDs to search for in the prompt."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Either a function that receives the index of the processed item within
    :attr:`modality` and returns the replacement text or token IDs, or, when
    the replacement does not depend on the item, that text or those token
    IDs directly.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
        """Return a copy of this replacement bound to ``tokenizer``."""
        return _BoundPromptReplacement(
            modality=self.modality,
            tokenizer=tokenizer,
            _replacement=self.replacement,
            _target=self.target,
        )
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Convert ``text`` to token IDs regardless of the tokenizer backend.

    Behaves like HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`; for a
    :class:`MistralTokenizer` the flag is forwarded as both ``bos`` and
    ``eos`` of the wrapped tokenizer's ``encode``.
    """
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    # The wrapped Mistral tokenizer takes separate BOS/EOS switches instead
    # of a single ``add_special_tokens`` flag.
    inner = tokenizer.tokenizer
    return inner.encode(text, bos=add_special_tokens, eos=add_special_tokens)
73
74


75
76
77
78
79
80
81
82
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized :func:`_encode`.

    NOTE: repeated calls with the same arguments return the very same list
    object, so callers must treat the result as read-only.
    """
    token_ids = _encode(tokenizer, text, add_special_tokens=add_special_tokens)
    return token_ids
83
84


85
86
87
88
89
90
91
92
93
94
95
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Convert ``token_ids`` back to text regardless of the tokenizer backend.

    Behaves like HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
    """
    decoded = tokenizer.decode(token_ids,
                               skip_special_tokens=skip_special_tokens)
    return decoded
96
97


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized :func:`_decode`.

    The token IDs are taken as a tuple because ``lru_cache`` requires
    hashable arguments; they are converted back to a list for decoding.
    """
    id_list = [*token_ids]
    return _decode(tokenizer, id_list, skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    """Structural type: anything with a plain ``modality`` string attribute."""

    modality: str

113

114
class _HasModalityProp(Protocol):
    """Structural type: anything exposing ``modality`` as a read-only property."""

    @property
    def modality(self) -> str:
        """Name of the modality."""


# Any object carrying a ``modality`` string, as a field or as a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group ``values`` by their ``modality`` using :func:`full_groupby`."""

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence bound to a tokenizer, available both as text and as
    token IDs.

    At least one form must be supplied; the missing one is derived on first
    access through the memoized encode/decode helpers and then stored on
    the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        # Refuse to build a sequence with neither representation.
        missing_both = self._text is None and self._token_ids is None
        if missing_both:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """Text form; decoded from the token IDs on first use if needed."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        decoded = _cached_decode(self.tokenizer, tuple(self._token_ids))
        self._text = decoded
        return decoded

    @property
    def token_ids(self) -> list[int]:
        """Token-ID form; encoded from the text on first use if needed."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        encoded = _cached_encode(self.tokenizer, self._text)
        self._token_ids = encoded
        return encoded


@dataclass
class _BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer.

    Target and replacements are handed out as :class:`_BoundPromptSequence`
    objects, so callers can use either their text or their token IDs.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Results of a callable ``_replacement``, keyed by item index.
        self._replacement_cache: dict[int, _BoundPromptSequence] = {}

    @property
    def target(self) -> _BoundPromptSequence:
        """The search target, wrapped in a new bound sequence on each access."""
        raw = self._target
        as_text = raw if isinstance(raw, str) else None
        as_ids = raw if isinstance(raw, list) else None

        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=as_text,
            _token_ids=as_ids,
        )

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Return the bound replacement for the ``item_idx``-th item of
        :attr:`modality`.

        A callable replacement is invoked with ``item_idx`` and its result
        is cached per index; a constant replacement is wrapped afresh on
        every call.
        """
        repl = self._replacement
        is_dynamic = callable(repl)

        if is_dynamic:
            if item_idx in self._replacement_cache:
                return self._replacement_cache[item_idx]
            repl = repl(item_idx)

        bound = _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=repl if isinstance(repl, str) else None,
            _token_ids=repl if isinstance(repl, list) else None,
        )

        # Only callable results are cached; ``None`` is never a cache key.
        if is_dynamic and item_idx is not None:
            self._replacement_cache[item_idx] = bound

        return bound


class ImageSize(NamedTuple):
    """Size of an image, as ``(width, height)``."""

    width: int
    height: int


class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    Like :class:`MultiModalDataDict`, except that every entry has been
    normalized into a list of items.
    """

    @property
    def image(self) -> list[ImageItem]:
        """Items of the ``image`` modality; ``KeyError`` if absent."""
        return self["image"]

    @property
    def video(self) -> list[VideoItem]:
        """Items of the ``video`` modality; ``KeyError`` if absent."""
        return self["video"]

    @property
    def audio(self) -> list[AudioItem]:
        """Items of the ``audio`` modality; ``KeyError`` if absent."""
        return self["audio"]

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the width and height of the ``item_idx``-th image.

        PIL images report their size directly; arrays and tensors must be
        3-D, and their last two dimensions are read as height and width.
        """
        image = self.image[item_idx]

        if isinstance(image, (np.ndarray, torch.Tensor)):
            # NOTE(review): assumes channels-first (C, H, W); an HWC array
            # would produce an incorrect size here - confirm with callers.
            _, height, width = image.shape
            return ImageSize(width, height)

        if isinstance(image, Image):
            return ImageSize(*image.size)

        assert_never(image)


def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
    """
    Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
    by wrapping each single item into a one-element list.
    """
    multi_data = MultiModalDataItems()

    for modality, value in data.items():
        if modality == "video":
            # A single video item may itself be a list, so only a list of
            # lists is treated as already holding multiple items.
            already_multi = is_list_of(value, list)
        else:
            already_multi = isinstance(value, list)

        items = value if already_multi else [value]
        multi_data[modality] = items  # type: ignore[index]

    return multi_data


258
259
260
class _TokenMatch(NamedTuple):
    """Location of a match in a token list; ``end_idx`` is exclusive."""

    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each non-overlapping occurrence of :code:`match_ids` in
    :code:`token_ids`, scanning from left to right.

    An empty :code:`match_ids` yields nothing (empty matches are ignored).
    """
    needle_len = len(match_ids)
    if not needle_len:
        return

    last_start = len(token_ids) - needle_len
    pos = 0
    while pos <= last_start:
        stop = pos + needle_len
        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)
        # Resume right after this match so reported matches never overlap.
        pos = stop
289
290


291
292
293
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    Abstract location in a prompt where the target of ``prompt_repl`` was
    found; subclasses supply the span for token or text prompts.
    """

    prompt_repl: _BoundPromptReplacement

    @property
    def modality(self) -> str:
        """Modality of the replacement that produced this match."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Inclusive start of the matched span."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Exclusive end of the matched span."""
        raise NotImplementedError

    def __repr__(self) -> str:
        details = ", ".join([
            f"modality={self.modality!r}",
            f"start_idx={self.start_idx!r}",
            f"end_idx={self.end_idx!r}",
        ])
        return f"{type(self).__name__}({details})"


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A replacement target found in a prompt given as token IDs."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        token_match = self.match
        return token_match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A replacement target found in a prompt given as text."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        begin, _ = self.match.span()
        return begin

    @property
    def end_idx(self) -> int:
        _, finish = self.match.span()
        return finish

339
340
341
342

class _PlaceholderInfo(NamedTuple):
    """A run of placeholder tokens located in a prompt for one item."""

    modality: str
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder tokens, i.e. ``len(replacement)``."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Express this placeholder as a :class:`PlaceholderRange`."""
        start, size = self.start_idx, self.length
        return PlaceholderRange(offset=start, length=size)
354
355
356
357


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Locate the token-ID target of every entry of :code:`prompt_repls`
    in :code:`prompt`, grouped in the order of :code:`prompt_repls`.
    """
    found: list[_PromptReplacementTokenMatch] = []
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        for match in iter_token_matches(prompt, target_ids):
            found.append(_PromptReplacementTokenMatch(prompt_repl, match))
    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Targets are matched literally (regex metacharacters are escaped) and
    without overlap per target. Empty targets are ignored, as in
    :func:`iter_token_matches`: an empty pattern would otherwise match at
    every position of the prompt and scatter replacements throughout it.
    """
    matches = list[_PromptReplacementTextMatch]()

    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            continue

        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(re.escape(target_text), prompt))

    return matches


def _resolve_matches(
    prompt: _PromptSeq,
    matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
    """
    Verify that no two of :code:`matches` overlap, then return them sorted
    by start index (the sort is stable for equal starts).

    Raises:
        ValueError: If two matches cover the same position of the prompt.
    """
    owner: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            prev = owner[idx]
            if prev is not None:
                raise ValueError("Found overlapping matches "
                                 f"({prev} and {match}) "
                                 f"at index={idx} of prompt={prompt}")
            owner[idx] = match

    ordered = list(matches)
    ordered.sort(key=lambda m: m.start_idx)
    return ordered


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Split :code:`prompt` around :code:`matches`, substituting each applied
    match with its replacement.

    Matches are applied in order of position. The ``i``-th applied match of
    a modality receives the replacement for item ``i`` of that modality.
    Once a modality has used up its ``mm_item_counts`` items, further
    matches for it are left untouched. A modality that is absent from
    ``mm_item_counts`` counts as having zero items, so it no longer raises
    ``KeyError``.

    Returns:
        Pieces of the new prompt, to be concatenated by the caller.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = dict[str, int]()

    for match in _resolve_matches(prompt, matches):
        modality = match.modality

        item_idx = next_idx_by_modality.get(modality, 0)
        if item_idx >= mm_item_counts.get(modality, 0):
            # No item left for this match; its span stays in the output
            # because prev_end_idx is not advanced past it.
            continue

        replacement = match.prompt_repl.get_replacement(item_idx)
        if isinstance(prompt, str):
            repl_seq = replacement.text
        else:
            repl_seq = replacement.token_ids

        out_seqs.append(prompt[prev_end_idx:match.start_idx] + repl_seq)

        prev_end_idx = match.end_idx
        next_idx_by_modality[modality] = item_idx + 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements of :code:`matches` to the token IDs in
    :code:`prompt`.

    With no matches, :code:`prompt` itself is returned (not a copy).
    """
    if matches:
        pieces = _replace_matches(prompt, matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt
452
453


454
455
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements of :code:`matches` to the text in
    :code:`prompt`.

    With no matches, :code:`prompt` is returned unchanged.
    """
    if matches:
        return "".join(_replace_matches(prompt, matches, mm_item_counts))

    return prompt
466
467


468
469
470
471
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Scan :code:`prompt` left to right for the replacement tokens of each
    successive item of :code:`modality`.

    At every position, the replacements in :code:`modality_repls` are
    tried in order for the current item. The first non-empty one that
    fits and matches is yielded, and the scan jumps past it. Stops after
    :code:`modal_item_count` placeholders have been found.
    """
    if modal_item_count == 0:
        return

    prompt_len = len(prompt)
    item_index = 0
    pos = 0

    while pos < prompt_len:
        for repl_info in modality_repls:
            repl_tokens = repl_info.get_replacement(item_index).token_ids
            repl_len = len(repl_tokens)
            stop = pos + repl_len

            if repl_len == 0 or stop > prompt_len:
                continue
            if prompt[pos:stop] != repl_tokens:
                continue

            yield _PlaceholderInfo(
                modality=modality,
                start_idx=pos,
                replacement=repl_tokens,
            )

            item_index += 1
            if item_index >= modal_item_count:
                return

            # Skip past this placeholder so found placeholders never overlap.
            pos = stop
            break
        else:
            # No replacement matched at this position.
            pos += 1
511
512
513


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`, modality
    by modality in the order of :code:`mm_item_counts`.

    Modalities without any replacement are skipped, and empty matches are
    ignored.
    """
    grouped = dict(full_groupby_modality(prompt_repls))

    for modality, count in mm_item_counts.items():
        modality_repls = grouped.get(modality)
        if modality_repls is None:
            continue

        yield from _iter_modality_placeholders(
            prompt,
            modality,
            modality_repls,
            count,
        )

534

535
536
537
538
539
class ProcessorInputs(NamedTuple):
    """
    Positional arguments for :meth:`BaseMultiModalProcessor.apply`, in
    order, so an instance can be splatted into that call.
    """

    prompt_text: str
    mm_data: MultiModalDataDict
    mm_processor_kwargs: Mapping[str, object]
540
541


542
class BaseMultiModalProcessor(ABC):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def _get_hf_processor(self) -> ProcessorMixin:
        """
        Return the HF processor to apply to the inputs.

        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs. Only the entries
        of ``mm_processor_kwargs`` whose names match those keyword arguments
        are passed in (see :meth:`_get_hf_processor_init_kwargs`).
        """
        return self.ctx.get_hf_processor()

    def _get_hf_processor_init_kwargs(
        self,
        mm_processor_kwargs: Mapping[str, object],
    ) -> dict[str, object]:
        """
        Select the entries of ``mm_processor_kwargs`` that
        :meth:`_get_hf_processor` accepts as keyword arguments.

        Forwarding every entry would raise :class:`TypeError` as soon as a
        kwarg meant only for the processor call is given. In particular, the
        base implementation accepts no keyword arguments at all.
        """
        import inspect

        signature = inspect.signature(self._get_hf_processor)
        params = list(signature.parameters.values())

        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
            # The override takes ``**kwargs``, so it can receive everything.
            return dict(mm_processor_kwargs)

        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                         inspect.Parameter.KEYWORD_ONLY)
        accepted = {p.name for p in params if p.kind in keyword_kinds}

        return {k: v for k, v in mm_processor_kwargs.items() if k in accepted}

    def _get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the processing context."""
        return self.ctx.tokenizer

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Note:
            Even when the HF processor already performs replacement for us,
            we still use this replacement information to determine
            the placeholder token positions for each multi-modal item.
        """
        raise NotImplementedError

    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> list[_PlaceholderInfo]:
        """Collect the placeholders of ``all_prompt_repls`` in the tokens."""
        return list(
            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))

    def _get_processor_data(
        self,
        mm_data: MultiModalDataDict,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Split ``mm_data`` into inputs for the HF processor and data that
        bypasses it.

        Returns:
            ``(processor_data, passthrough_data)``. For image/video/audio,
            a 3-D tensor or a non-empty list of 2-D tensors is passed
            through under ``"<modality>_embeds"``. Other values are given to
            the processor under the plural key (e.g. ``images``). Any other
            key is given to the processor unchanged.
        """
        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()
        for k, v in mm_data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi); the length
                    # check avoids an IndexError on an empty list.
                    passthrough_data[f"{k}_embeds"] = v
                else:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

        return processor_data, passthrough_data

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on ``prompt`` together with the raw entries of
        ``mm_data``, then merge the passed-through embedding inputs into the
        result.

        Raises:
            RuntimeError: If the HF processor call fails; the original
                exception is chained as the cause.
        """
        # some mm_processor_kwargs may be used in processor initialization
        # instead of processor call
        hf_processor = self._get_hf_processor(
            **self._get_hf_processor_init_kwargs(mm_processor_kwargs))

        processor_data, passthrough_data = self._get_processor_data(mm_data)

        assert callable(hf_processor)
        mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
            hf_processor,
            mm_processor_kwargs,
        )

        try:
            hf_inputs = hf_processor(
                text=prompt,  # type: ignore
                **processor_data,
                **mm_processor_kwargs,
                return_tensors="pt",
            )
        except Exception as exc:
            # A dict display (unlike ``dict(text=..., **...)``) cannot raise
            # on a duplicate "text" key and mask the original error.
            data = {"text": prompt, **processor_data}

            raise RuntimeError(
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={mm_processor_kwargs}") from exc

        hf_inputs.update(passthrough_data)

        return hf_inputs

    def _bind_prompt_replacements(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> list[_BoundPromptReplacement]:
        """Bind every replacement to this processor's tokenizer."""
        tokenizer = self._get_tokenizer()

        return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Insert the replacements of ``prompt_repls`` into ``token_ids``.

        Returns:
            The new token IDs, their decoded text, and the placeholders
            located in the new token IDs.
        """
        tokenizer = self._get_tokenizer()

        token_matches = find_token_matches(token_ids, prompt_repls)
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        # ----
        # Every modality with items must be checked, including those with no
        # token match at all; otherwise such a modality would be silently
        # left without replacements.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_item_counts,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_item_counts,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids,
                                               mm_item_counts)

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        tokenizer = self._get_tokenizer()

        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
                                             mm_processor_kwargs)
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        mm_items = to_multi_format(mm_data)
        prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
                                                     mm_processor_kwargs)
        all_prompt_repls = self._bind_prompt_replacements(prompt_repls)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids, mm_item_counts)

        if all_placeholders:
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                all_prompt_repls,
                mm_item_counts,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

    @abstractmethod
    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the multi-modal portion of the input which, after processing,
        results in `mm_max_tokens` in :meth:`get_dummy_data`.
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
        """
        Build dummy data from :meth:`_get_dummy_mm_inputs`, zero-padded to
        ``seq_len`` tokens.

        Logs a warning when the placeholder count of a modality differs from
        ``mm_max_tokens``, or when the processed prompt exceeds ``seq_len``.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        processor_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(*processor_inputs)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        total_placeholders_by_modality = dict[str, int]()
        for modality, placeholders in placeholders_by_modality.items():
            num_placeholders = sum(item["length"] for item in placeholders)
            max_tokens = mm_max_tokens[modality]

            if num_placeholders != max_tokens:
                logger.warning(
                    "The processed dummy data has a total of %d placeholder "
                    "tokens for the '%s' modality, which is not the expected "
                    "%d tokens.", num_placeholders, modality, max_tokens)

            total_placeholders_by_modality[modality] = num_placeholders

        total_len = len(prompt_token_ids)
        if total_len > seq_len:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, total_placeholders_by_modality)

        # Pad with token ID 0 up to seq_len (no-op if already longer).
        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )