# processing.py (25.1 KB)
1
2
import re
from abc import ABC, abstractmethod
3
from collections import UserDict
4
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
8

9
import numpy as np
10
import torch
11
from PIL.Image import Image
12
from transformers import BatchFeature, ProcessorMixin
13
from typing_extensions import assert_never
14

15
from vllm.inputs import DummyData, InputProcessingContext
16
from vllm.logger import init_logger
17
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
18
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
19
20
21
22
23

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

24
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Either raw prompt text or the corresponding list of token IDs.
_PromptSeq = Union[str, list[int]]
# Constrained type variable: helpers given text return text, and helpers
# given token IDs return token IDs.
_S = TypeVar("_S", str, list[int])
28

29
30

@dataclass
class PromptReplacement:
    """
    Describes a find-and-replace operation on the prompt for one modality.

    The target and replacement may be given as text or as token IDs; call
    :meth:`bind` to attach a tokenizer before applying it.
    """

    modality: str
    """The modality for which the replacement is made"""

    target: _PromptSeq
    """The text or token sequence to find and replace."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`, output the
    replacement text or token sequence.

    For convenience, you can pass in the replacement instead of a function
    if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
        """Attach ``tokenizer`` so text/token IDs can be converted lazily."""
        bound_fields = dict(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )
        return _BoundPromptReplacement(**bound_fields)
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Tokenize ``text`` regardless of the tokenizer backend.

    Mirrors HF's :code:`tokenizer.encode(text, add_special_tokens=...)`.
    For :class:`MistralTokenizer`, the flag is translated into the
    ``bos``/``eos`` arguments of the wrapped tokenizer's ``encode``.
    """
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    # The wrapped Mistral tokenizer has no ``add_special_tokens`` flag; the
    # same switch controls both the BOS and the EOS token.
    inner = tokenizer.tokenizer
    return inner.encode(text, bos=add_special_tokens, eos=add_special_tokens)
73
74


75
76
77
78
79
80
81
82
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`_encode`.

    Note:
        ``lru_cache`` hands back the same list object for repeated calls
        with equal arguments, so callers must not mutate the result.
    """
    token_ids = _encode(tokenizer,
                        text,
                        add_special_tokens=add_special_tokens)
    return token_ids
83
84


85
86
87
88
89
90
91
92
93
94
95
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Detokenize ``token_ids`` regardless of the tokenizer backend.

    Equivalent to HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`; the same
    call is made for every tokenizer type.
    """
    decoded = tokenizer.decode(token_ids,
                               skip_special_tokens=skip_special_tokens)
    return decoded
96
97


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`_decode`.

    ``token_ids`` is a tuple so that it is hashable and can take part in the
    cache key; it is turned back into a list before decoding.
    """
    ids_list = [*token_ids]
    return _decode(tokenizer,
                   ids_list,
                   skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    modality: str

113

114
class _HasModalityProp(Protocol):
115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    @property
    def modality(self) -> str:
        ...


# Anything carrying a ``modality`` string, whether stored as an attribute
# (e.g. a dataclass field) or computed by a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group ``values`` by their ``modality`` using :func:`full_groupby`.

    Returns the items view produced by :func:`full_groupby`, pairing each
    modality with the list of values that carry it.
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence bound to a tokenizer.

    At least one of the text / token ID forms must be supplied. A missing
    form is derived on first access via the cached encode/decode helpers
    and then stored on the instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        has_text = self._text is not None
        has_token_ids = self._token_ids is not None
        if not (has_text or has_token_ids):
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs when necessary."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token ID form, encoded from the text when necessary."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer, self._text)
        return self._token_ids


@dataclass
class _BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer.

    When the replacement is a callable, its result is memoized per item
    index, so the callable runs at most once for each item.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # item index -> bound replacement (only for callable replacements)
        self._replacement_cache = dict[int, _BoundPromptSequence]()

    def _to_bound_sequence(self, seq: _PromptSeq) -> _BoundPromptSequence:
        # Put ``seq`` into the slot matching its type; the other slot stays
        # empty and is filled lazily by ``_BoundPromptSequence``.
        is_text = isinstance(seq, str)
        is_token_ids = isinstance(seq, list)
        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=seq if is_text else None,
            _token_ids=seq if is_token_ids else None,
        )

    @property
    def target(self) -> _BoundPromptSequence:
        """The sequence to search for, bound to :attr:`tokenizer`."""
        return self._to_bound_sequence(self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """Return the bound replacement for the ``item_idx``-th item."""
        replacement = self._replacement
        if not callable(replacement):
            # Static replacements are rebound on each call and not cached.
            return self._to_bound_sequence(replacement)

        cached = self._replacement_cache.get(item_idx)
        if cached is not None:
            return cached

        bound = self._to_bound_sequence(replacement(item_idx))
        self._replacement_cache[item_idx] = bound
        return bound


class ImageSize(NamedTuple):
    """Width and height of an image, ordered like PIL's ``Image.size``."""

    width: int
    height: int


class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    Normalized form of :class:`MultiModalDataDict`: every modality maps to a
    list of items, even when only a single item was provided.
    """

    @property
    def image(self) -> list[ImageItem]:
        """The image items; raises :class:`KeyError` if there are none."""
        return self["image"]

    @property
    def video(self) -> list[VideoItem]:
        """The video items; raises :class:`KeyError` if there are none."""
        return self["video"]

    @property
    def audio(self) -> list[AudioItem]:
        """The audio items; raises :class:`KeyError` if there are none."""
        return self["audio"]

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the ``item_idx``-th image.

        Array and tensor images are unpacked as ``(channels, height, width)``.
        """
        image = self.image[item_idx]

        if isinstance(image, Image):
            width, height = image.size
            return ImageSize(width, height)

        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, height, width = image.shape
            return ImageSize(width=width, height=height)

        assert_never(image)


def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
    """
    Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.

    A value that is already a list of items is kept as-is; anything else is
    wrapped in a one-element list. Videos are special-cased because even a
    single video can itself be a list, so a video value only counts as
    multiple items when ``is_list_of(value, list)`` holds.
    """
    multi_data = MultiModalDataItems()

    for modality, value in data.items():
        if modality == "video":
            is_multi = is_list_of(value, list)
        else:
            is_multi = isinstance(value, list)

        items = value if is_multi else [value]
        multi_data[modality] = items  # type: ignore[index]

    return multi_data


258
259
260
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
261
262


263
264
265
266
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are found left to right and never overlap: after a match, the
    search resumes right after its end. Note that empty matches are ignored.
    """
    match_len = len(match_ids)
    if not match_len:
        return

    # Last index at which a full-length match can still begin.
    last_start = len(token_ids) - match_len
    idx = 0
    while idx <= last_start:
        end = idx + match_len
        if token_ids[idx:end] != match_ids:
            idx += 1
            continue

        yield _TokenMatch(start_idx=idx, end_idx=end)

        # Exclude overlapping matches
        idx = end
289
290


291
292
293
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    A location in the prompt where the target of :attr:`prompt_repl` was
    found. Subclasses define how :attr:`start_idx` and :attr:`end_idx` are
    obtained.
    """

    prompt_repl: _BoundPromptReplacement

    @property
    def modality(self) -> str:
        """Modality of the matched replacement."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index at which the match begins."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index just past the end of the match."""
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match found by searching the prompt's token IDs."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        """Token index at which the match begins."""
        start, _ = self.match
        return start

    @property
    def end_idx(self) -> int:
        """Token index just past the end of the match."""
        _, end = self.match
        return end


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match found by searching the prompt text with a regular expression."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        """Character offset at which the regex match begins."""
        start, _ = self.match.span()
        return start

    @property
    def end_idx(self) -> int:
        """Character offset just past the end of the regex match."""
        _, end = self.match.span()
        return end

339
340
341
342

class _PlaceholderInfo(NamedTuple):
    """Placeholder tokens of one multi-modal item, located in the prompt."""

    modality: str
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Express the placeholder as an offset/length range."""
        return PlaceholderRange(offset=self.start_idx, length=self.length)
354
355
356
357


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Results are grouped by replacement, in the order of ``prompt_repls``.
    """
    found = list[_PromptReplacementTokenMatch]()
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        found.extend(
            _PromptReplacementTokenMatch(prompt_repl, match)
            for match in iter_token_matches(prompt, target_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    The target text is matched literally, not as a regex. As in
    :func:`iter_token_matches`, empty targets are ignored: an empty pattern
    would otherwise match at every position of the prompt and cause
    replacements to be inserted at arbitrary offsets.
    """
    matches = list[_PromptReplacementTextMatch]()
    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            continue

        pattern = re.escape(target_text)
        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return matches


def _resolve_matches(
    prompt: _PromptSeq,
    matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same index of the prompt.
    """
    # owner[i] is the match that has already claimed prompt index i, if any.
    owner: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            prev = owner[idx]
            if prev is not None:
                raise ValueError("Found overlapping matches "
                                 f"({prev} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            owner[idx] = match

    return sorted(matches, key=lambda m: m.start_idx)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Split ``prompt`` around ``matches`` and substitute the replacements.

    Matches are applied in prompt order. For each modality, its ``i``-th
    match receives the replacement for item ``i``. Once every item of a
    modality has been used, further matches of that modality are left in
    the prompt untouched. A modality missing from ``mm_item_counts`` counts
    as having zero items, so all of its matches are left untouched.

    Returns:
        Consecutive segments of the new prompt. Concatenating them yields the
        full result (string join for text, flattening for token IDs).

    Raises:
        ValueError: If any matches overlap (see :func:`_resolve_matches`).
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    # Use ``.get`` lookups below: the prompt may contain targets for a
    # modality that was not supplied in the multi-modal data, which
    # previously raised a KeyError here.
    next_idx_by_modality = dict[str, int]()

    for match in _resolve_matches(prompt, matches):
        modality = match.modality

        item_idx = next_idx_by_modality.get(modality, 0)
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx

        replacement = match.prompt_repl.get_replacement(item_idx)

        if isinstance(prompt, str):
            out_seqs.append(prompt[prev_end_idx:start_idx] +
                            replacement.text)
        else:
            out_seqs.append(prompt[prev_end_idx:start_idx] +
                            replacement.token_ids)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] = item_idx + 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements of :code:`matches` to the token IDs in
    :code:`prompt` and return the new token IDs.

    When there are no matches, :code:`prompt` itself is returned.
    """
    if matches:
        segments = _replace_matches(prompt, matches, mm_item_counts)
        return flatten_2d_lists(segments)

    return prompt
452
453


454
455
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements of :code:`matches` to :code:`prompt` and return
    the new text (the original text when there are no matches).
    """
    if matches:
        return "".join(_replace_matches(prompt, matches, mm_item_counts))

    return prompt
466
467


468
469
470
471
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield the placeholders of up to ``modal_item_count`` items of
    ``modality`` found in ``prompt``.

    The prompt is scanned left to right. At each position, the replacements
    in ``modality_repls`` are tried in order, using the replacement for the
    next unmatched item. The first non-empty replacement whose token IDs
    match is yielded, and scanning resumes after it, so placeholders never
    overlap.
    """
    if modal_item_count == 0:
        return

    prompt_len = len(prompt)
    item_idx = 0
    pos = 0

    while pos < prompt_len:
        matched: Optional[list[int]] = None

        for repl_info in modality_repls:
            candidate = repl_info.get_replacement(item_idx).token_ids
            end = pos + len(candidate)
            # Skip empty replacements and ones that would run past the end.
            fits = bool(candidate) and end <= prompt_len
            if fits and prompt[pos:end] == candidate:
                matched = candidate
                break

        if matched is None:
            pos += 1
            continue

        yield _PlaceholderInfo(
            modality=modality,
            start_idx=pos,
            replacement=matched,
        )

        item_idx += 1
        if item_idx >= modal_item_count:
            return

        # Exclude overlapping matches
        pos += len(matched)
511
512
513


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Modalities are visited in the iteration order of ``mm_item_counts``;
    modalities without any replacement in ``prompt_repls`` are skipped.
    Note that empty matches are ignored.
    """
    repls_by_modality = dict(full_groupby_modality(prompt_repls))

    for modality, modal_item_count in mm_item_counts.items():
        modality_repls = repls_by_modality.get(modality)
        if modality_repls is None:
            continue

        yield from _iter_modality_placeholders(
            prompt,
            modality,
            modality_repls,
            modal_item_count,
        )

534

535
536
537
538
539
class ProcessorInputs(NamedTuple):
    """
    Keyword arguments to :meth:`BaseMultiModalProcessor`

    The field order matches the positional parameters of
    :meth:`BaseMultiModalProcessor.apply`, so an instance can be unpacked
    directly into that call.
    """

    prompt_text: str
    mm_data: MultiModalDataDict
    mm_processor_kwargs: Mapping[str, object]
540
541


542
class BaseMultiModalProcessor(ABC):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Equivalent to :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def _get_hf_processor(self) -> ProcessorMixin:
        """
        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs.

        Note:
            :meth:`_apply_hf_processor` forwards every entry of
            ``mm_processor_kwargs`` to this method. This base implementation
            takes no keyword arguments, so it raises :class:`TypeError` if
            any are given.
        """
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Note:
            Even when the HF processor already performs replacement for us,
            we still use this replacement information to determine
            the placeholder token positions for each multi-modal item.
        """
        raise NotImplementedError

    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> list[_PlaceholderInfo]:
        """Collect the placeholders of each item found in ``new_token_ids``."""
        return list(
            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on the prompt and the raw multi-modal data.

        Image/video/audio inputs that look like embeddings (a 3-D tensor, or
        a non-empty list of 2-D tensors) bypass the HF processor. They are
        added to the output under ``<modality>_embeds``.

        Raises:
            RuntimeError: If the HF processor call fails; the original
                exception is chained as the cause.
        """
        hf_processor = self._get_hf_processor(**mm_processor_kwargs)

        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()
        for k, v in mm_data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi).
                    # Guard ``v[0]`` so an empty list cannot raise
                    # IndexError (``is_list_of`` presumably inspects only
                    # the first element, so it may accept ``[]``).
                    passthrough_data[f"{k}_embeds"] = v
                else:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

        assert callable(hf_processor)
        mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
            hf_processor,
            mm_processor_kwargs,
        )

        try:
            hf_inputs = hf_processor(
                text=prompt,  # type: ignore
                **processor_data,
                **mm_processor_kwargs,
                return_tensors="pt",
            )
        except Exception as exc:
            data = dict(text=prompt, **processor_data)

            raise RuntimeError(
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={mm_processor_kwargs}") from exc

        hf_inputs.update(passthrough_data)

        return hf_inputs

    def _bind_prompt_replacements(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> list[_BoundPromptReplacement]:
        """Bind each replacement to this processor's tokenizer."""
        tokenizer = self._get_tokenizer()

        return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Insert the placeholder tokens of each multi-modal item into the prompt.

        Returns:
            The new token IDs, the corresponding decoded text, and the
            placeholders located in the new token IDs.
        """
        tokenizer = self._get_tokenizer()

        token_matches = find_token_matches(token_ids, prompt_repls)

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        #
        # The token path is only taken when *every* modality that has items
        # has enough token matches. A modality with no token match at all
        # must count as zero matches rather than being skipped; otherwise a
        # prompt without any token match would bypass the text fallback and
        # end up with no placeholders.
        match_counts = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }
        if all(
            match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_item_counts,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_item_counts,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids,
                                               mm_item_counts)

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        tokenizer = self._get_tokenizer()

        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
                                             mm_processor_kwargs)
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        mm_items = to_multi_format(mm_data)
        prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
                                                     mm_processor_kwargs)
        all_prompt_repls = self._bind_prompt_replacements(prompt_repls)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids, mm_item_counts)

        if all_placeholders:
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                all_prompt_repls,
                mm_item_counts,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

    @abstractmethod
    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the multi-modal portion of the input which, after processing,
        results in `mm_max_tokens` in :meth:`get_dummy_data`.
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
        """
        Build :class:`DummyData` from :meth:`_get_dummy_mm_inputs`.

        The processed token IDs are right-padded with ``0`` up to
        ``seq_len``. If they are already longer, a warning is logged and they
        are kept as-is (not truncated). A warning is also logged for each
        modality whose placeholder count differs from ``mm_max_tokens``.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        processor_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(*processor_inputs)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        total_placeholders_by_modality = dict[str, int]()
        for modality, placeholders in placeholders_by_modality.items():
            num_placeholders = sum(item["length"] for item in placeholders)
            max_tokens = mm_max_tokens[modality]

            if num_placeholders != max_tokens:
                logger.warning(
                    "The processed dummy data has a total of %d placeholder "
                    "tokens for the '%s' modality, which is not the expected "
                    "%d tokens.", num_placeholders, modality, max_tokens)

            total_placeholders_by_modality[modality] = num_placeholders

        total_len = len(prompt_token_ids)
        if total_len > seq_len:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, total_placeholders_by_modality)

        # A negative count yields an empty list, so nothing is added when
        # the prompt already exceeds ``seq_len``.
        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )