"vllm/model_executor/models/glm4_moe.py" did not exist on "63e7176f265be43dcc425f5ab4ab45c90234f5c3"
processing.py 25 KB
Newer Older
1
2
import re
from abc import ABC, abstractmethod
3
from collections import UserDict
4
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
8

9
import numpy as np
10
import torch
11
from PIL.Image import Image
12
from transformers import BatchFeature, ProcessorMixin
13
from typing_extensions import assert_never
14

15
from vllm.inputs import DummyData, InputProcessingContext
16
from vllm.logger import init_logger
17
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
18
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
19
20
21
22
23

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

24
logger = init_logger(__name__)

# A prompt (or a fragment of one), given either as text or as token IDs.
_PromptSeq = Union[str, list[int]]
# Constrained type variable for functions that work on either prompt form
# and return pieces of the same form.
_S = TypeVar("_S", str, list[int])
28

29
30

@dataclass
class PromptReplacement:
    """
    Defines how occurrences of :attr:`target` in the prompt are replaced
    for each item of :attr:`modality`.
    """

    modality: str
    """The modality for which the replacement is made"""

    target: _PromptSeq
    """The text or token sequence to find and replace."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`, output the
    replacement text or token sequence.

    For convenience, you can pass in the replacement instead of a function
    if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
        """
        Attach ``tokenizer`` to this replacement, so that its target and
        replacement can be converted between text and token IDs on demand.
        """
        bound = _BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )
        return bound
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    For :class:`MistralTokenizer`, the wrapped ``tokenizer.tokenizer`` is
    called with separate ``bos``/``eos`` flags, both set from
    ``add_special_tokens``.
    """
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    # The inner Mistral tokenizer controls BOS and EOS individually.
    inner = tokenizer.tokenizer
    return inner.encode(text, bos=add_special_tokens, eos=add_special_tokens)
73
74


75
76
77
78
79
80
81
82
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized version of :func:`_encode`.

    All arguments (including ``tokenizer``) must be hashable. Repeated calls
    with equal arguments return the *same* list object, so callers must not
    mutate the result.
    """
    token_ids = _encode(tokenizer,
                        text,
                        add_special_tokens=add_special_tokens)
    return token_ids
83
84


85
86
87
88
89
90
91
92
93
94
95
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.

    Unlike :func:`_encode`, no backend-specific dispatch is performed: the
    tokenizer's own ``decode`` is called directly.
    """
    decoded = tokenizer.decode(token_ids,
                               skip_special_tokens=skip_special_tokens)
    return decoded
96
97


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized version of :func:`_decode`.

    ``token_ids`` is a tuple rather than a list so that it is hashable and
    can be part of the cache key.
    """
    as_list = [*token_ids]
    return _decode(tokenizer, as_list, skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    modality: str

113

114
class _HasModalityProp(Protocol):
115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    @property
    def modality(self) -> str:
        ...


# Anything with a ``modality`` string, whether it is stored as a field or
# computed by a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Convenience function to apply :func:`full_groupby` based on modality.

    Each value is keyed by its ``modality`` attribute.
    """

    def _modality_of(item: _M) -> str:
        return item.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt fragment bound to a tokenizer, holding its text and/or token IDs.

    Whichever form was not provided is computed on first access (through
    :func:`_cached_decode` / :func:`_cached_encode`) and stored on the
    instance. A token ID list obtained this way comes from the encode cache
    and may be shared, so it must not be mutated.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        has_text = self._text is not None
        has_token_ids = self._token_ids is not None
        if not (has_text or has_token_ids):
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs if not given."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token ID form, encoded from the text if not given."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer, self._text)
        return self._token_ids


@dataclass
class _BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer.

    The target and the per-item replacements are exposed as
    :class:`_BoundPromptSequence` objects, so callers can read either their
    text or their token ID form. These objects are memoized, so the lazily
    computed form is reused across calls instead of being recomputed.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Bound replacements produced by a callable `_replacement`,
        # keyed by item index.
        self._replacement_cache = dict[int, _BoundPromptSequence]()
        # Bound target / bound non-callable replacement, created on first
        # access. Without these, every access built a fresh sequence object
        # and discarded its lazily computed text/token IDs. This matters
        # because `get_replacement` is called once per prompt position while
        # searching for placeholders.
        self._bound_target: Optional[_BoundPromptSequence] = None
        self._bound_static_replacement: Optional[_BoundPromptSequence] = None

    def _bind_seq(self, seq: _PromptSeq) -> _BoundPromptSequence:
        """Wrap a text (``str``) or token ID (``list``) sequence."""
        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    @property
    def target(self) -> _BoundPromptSequence:
        """The bound text or token sequence to search for in the prompt."""
        if self._bound_target is None:
            self._bound_target = self._bind_seq(self._target)

        return self._bound_target

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Return the bound replacement for the ``item_idx``-th item of
        :attr:`modality`.

        A callable replacement is invoked once per item index and its result
        is cached. A fixed replacement is bound once and shared by all items.
        """
        replacement = self._replacement

        if not callable(replacement):
            if self._bound_static_replacement is None:
                self._bound_static_replacement = self._bind_seq(replacement)

            return self._bound_static_replacement

        bound_replacement = self._replacement_cache.get(item_idx)
        if bound_replacement is None:
            bound_replacement = self._bind_seq(replacement(item_idx))
            self._replacement_cache[item_idx] = bound_replacement

        return bound_replacement


class ImageSize(NamedTuple):
    """Width and height of an image, in that order."""

    width: int
    height: int


class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    As :class:`MultiModalDataDict`, but normalized such that each entry
    corresponds to a list.
    """

    @property
    def image(self) -> list[ImageItem]:
        """The image items; raises :class:`KeyError` if there are none."""
        return self["image"]

    @property
    def video(self) -> list[VideoItem]:
        """The video items; raises :class:`KeyError` if there are none."""
        return self["video"]

    @property
    def audio(self) -> list[AudioItem]:
        """The audio items; raises :class:`KeyError` if there are none."""
        return self["audio"]

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the ``item_idx``-th image.

        PIL images report their own size. Arrays and tensors must be 3-D and
        are unpacked as ``(_, height, width)``. NOTE(review): this presumes a
        channels-first layout; confirm against the callers.
        """
        img = self.image[item_idx]

        if isinstance(img, Image):
            return ImageSize(*img.size)

        if isinstance(img, (np.ndarray, torch.Tensor)):
            _, height, width = img.shape
            return ImageSize(width, height)

        assert_never(img)


def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
    """
    Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.

    Every value that is not already a list is wrapped in a one-element list.
    Videos are special-cased because a single video may itself be a list:
    a video value is kept as-is only when :func:`is_list_of` reports it to
    be a list of lists.
    """
    multi_data = MultiModalDataItems()

    for modality, value in data.items():
        if modality == "video":
            # Special case since even a single item can be a list
            already_multi = is_list_of(value, list)
        else:
            already_multi = isinstance(value, list)

        multi_data[modality] = value if already_multi else [value]  # type: ignore[index]

    return multi_data


258
259
260
class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
261
262


263
264
265
266
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are reported left to right and never overlap: after a match,
    scanning resumes at the match's end.

    Note that empty matches are ignored.
    """
    match_len = len(match_ids)
    if not match_len:
        return

    last_start = len(token_ids) - match_len
    pos = 0
    while pos <= last_start:
        end = pos + match_len
        if token_ids[pos:end] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=end)

        # Exclude overlapping matches
        pos = end
289
290


291
292
293
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    A location in the prompt where the target of :attr:`prompt_repl` was
    found. Subclasses define how the matched span is stored.
    """

    prompt_repl: _BoundPromptReplacement

    @property
    def modality(self) -> str:
        """Modality of the matched prompt replacement."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index at which the match begins (inclusive)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index at which the match ends (exclusive)."""
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        body = (f"modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r}")
        return f"{cls_name}({body})"


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match located in the token ID form of the prompt."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        span = self.match
        return span.start_idx

    @property
    def end_idx(self) -> int:
        span = self.match
        return span.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match located in the text form of the prompt (character offsets)."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        start, _ = self.match.span()
        return start

    @property
    def end_idx(self) -> int:
        _, end = self.match.span()
        return end

339
340
341
342

class _PlaceholderInfo(NamedTuple):
    """
    A run of placeholder tokens, belonging to one multi-modal item, found in
    the token IDs of the prompt.
    """

    modality: str
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange`."""
        return PlaceholderRange(offset=self.start_idx, length=self.length)
354
355
356
357


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Results are grouped by replacement, in the order of
    :code:`prompt_repls`; within a group they are in prompt order.
    """
    found: list[_PromptReplacementTokenMatch] = []
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        found.extend(
            _PromptReplacementTokenMatch(prompt_repl, match)
            for match in iter_token_matches(prompt, target_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Each target text is searched for literally (it is regex-escaped), and
    :func:`re.finditer` reports non-overlapping occurrences.
    """
    found: list[_PromptReplacementTextMatch] = []
    for prompt_repl in prompt_repls:
        pattern = re.escape(prompt_repl.target.text)
        found.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return found


def _resolve_matches(
    prompt: _PromptSeq,
    matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same index of the prompt.
    """
    # For every position in the prompt, the match (if any) that covers it.
    owner: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            prev = owner[idx]
            if prev is not None:
                raise ValueError("Found overlapping matches "
                                 f"({prev} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            owner[idx] = match

    return sorted(matches, key=lambda m: m.start_idx)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
    mm_items: MultiModalDataItems,
) -> list[_S]:
    """
    Split ``prompt`` around ``matches`` and substitute each match with the
    replacement for its multi-modal item.

    Within a modality, the k-th replaced match (in prompt order) uses the
    k-th item of that modality. Matches beyond the number of available items
    are left unchanged. This includes every match of a modality that is
    absent from ``mm_items``.

    Returns:
        The pieces of the new prompt, in order; concatenating them yields
        the full replaced prompt.

    Raises:
        ValueError: If ``matches`` overlap (see :func:`_resolve_matches`).
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = dict[str, int]()

    for match in _resolve_matches(prompt, matches):
        modality = match.modality
        # A modality without data has nothing to substitute. Treat it like
        # one whose items are all used up, instead of raising KeyError.
        modal_items = mm_items.get(modality, [])

        item_idx = next_idx_by_modality.get(modality, 0)
        if item_idx >= len(modal_items):
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx

        replacement = match.prompt_repl.get_replacement(item_idx)

        if isinstance(prompt, str):
            repl_seq = replacement.text
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
        else:
            repl_seq = replacement.token_ids
            out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] = item_idx + 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
    mm_items: MultiModalDataItems,
) -> list[int]:
    """
    Apply the replacements described by :code:`matches` to the token IDs
    in :code:`prompt`.

    When there are no matches, :code:`prompt` itself (not a copy) is
    returned.
    """
    if matches:
        pieces = _replace_matches(prompt, matches, mm_items)
        return flatten_2d_lists(pieces)

    return prompt
453
454


455
456
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
    mm_items: MultiModalDataItems,
) -> str:
    """
    Apply the replacements described by :code:`matches` to the text
    :code:`prompt`.

    When there are no matches, :code:`prompt` is returned unchanged.
    """
    if matches:
        return "".join(_replace_matches(prompt, matches, mm_items))

    return prompt
467
468


469
470
471
472
473
474
475
476
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
    modal_items: list[Any],
) -> Iterable[_PlaceholderInfo]:
    """
    Scan ``prompt`` left to right for the replacement tokens of each item of
    ``modality``, yielding one placeholder per item found.

    At every position, the replacements are tried in order using the index
    of the next unmatched item. The first one that matches is yielded, and
    scanning resumes after it, so matches never overlap. Empty replacements
    are skipped. Scanning stops once all items have been found.
    """
    num_items = len(modal_items)
    if num_items == 0:
        return

    prompt_len = len(prompt)
    item_idx = 0
    pos = 0

    while pos < prompt_len:
        for repl_info in modality_repls:
            repl_tokens = repl_info.get_replacement(item_idx).token_ids
            repl_len = len(repl_tokens)
            end = pos + repl_len

            if repl_len == 0 or end > prompt_len:
                continue
            if prompt[pos:end] != repl_tokens:
                continue

            yield _PlaceholderInfo(
                modality=modality,
                start_idx=pos,
                replacement=repl_tokens,
            )

            item_idx += 1
            if item_idx >= num_items:
                return

            # Exclude overlapping matches
            pos = end
            break
        else:
            # No replacement starts here; try the next position.
            pos += 1
512
513
514


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
    mm_items: MultiModalDataItems,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Modalities are visited in the order of :code:`mm_items`; modalities
    without any replacement in :code:`prompt_repls` are skipped.

    Note that empty matches are ignored.
    """
    repls_by_modality = dict(full_groupby_modality(prompt_repls))

    for modality, modal_items in mm_items.items():
        modality_repls = repls_by_modality.get(modality)
        if modality_repls is None:
            continue

        yield from _iter_modality_placeholders(
            prompt,
            modality,
            modality_repls,
            modal_items,
        )

535

536
537
538
539
540
class ProcessorInputs(NamedTuple):
    """
    Positional arguments to :meth:`BaseMultiModalProcessor.apply`, in order.
    """

    # The prompt text.
    prompt_text: str
    # The raw multi-modal data.
    mm_data: MultiModalDataDict
    # Extra keyword arguments used when creating/calling the HF processor.
    mm_processor_kwargs: Mapping[str, object]
541
542


543
class BaseMultiModalProcessor(ABC):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Subclasses implement :meth:`_get_prompt_replacements` and
    :meth:`_get_dummy_mm_inputs`.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def _get_hf_processor(self) -> ProcessorMixin:
        """
        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs.

        NOTE(review): :meth:`_apply_hf_processor` forwards every entry of
        ``mm_processor_kwargs`` to this method. This base implementation
        accepts none, so it raises :class:`TypeError` whenever such kwargs
        are given and a subclass has not overridden it.
        """
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer of the processing context."""
        return self.ctx.tokenizer

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Note:
            Even when the HF processor already performs replacement for us,
            we still use this replacement information to determine
            the placeholder token positions for each multi-modal item.
        """
        raise NotImplementedError

    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
        mm_items: MultiModalDataItems,
    ) -> list[_PlaceholderInfo]:
        """
        Locate the replacement tokens of ``all_prompt_repls`` in
        ``new_token_ids`` (see :func:`iter_placeholders`).
        """
        return list(
            iter_placeholders(all_prompt_repls, new_token_ids, mm_items))

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on the prompt text and the multi-modal data.

        For image/video/audio, a 3-D tensor or a non-empty list of 2-D
        tensors is treated as pre-computed embeddings. These bypass the HF
        processor and are added to its output under ``<modality>_embeds``.
        Other image/video/audio data is passed under the plural key
        (``images``, ...); any other key is passed as-is.

        Raises:
            RuntimeError: If the HF processor call fails.
        """
        hf_processor = self._get_hf_processor(**mm_processor_kwargs)

        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()
        for k, v in mm_data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi).
                    # `is_list_of` may accept an empty list, so check the
                    # length before looking at the first element.
                    passthrough_data[f"{k}_embeds"] = v
                else:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

        assert callable(hf_processor)
        mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
            hf_processor,
            mm_processor_kwargs,
        )

        try:
            hf_inputs = hf_processor(
                text=prompt,  # type: ignore
                **processor_data,
                **mm_processor_kwargs,
                return_tensors="pt",
            )
        except Exception as exc:
            data = dict(text=prompt, **processor_data)

            raise RuntimeError(
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={mm_processor_kwargs}") from exc

        hf_inputs.update(passthrough_data)

        return hf_inputs

    def _bind_prompt_replacements(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> list[_BoundPromptReplacement]:
        """Bind each of ``prompt_repls`` to this processor's tokenizer."""
        tokenizer = self._get_tokenizer()

        return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]

    def _apply_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Apply ``prompt_repls`` to ``token_ids`` and locate the resulting
        placeholder tokens.

        Returns:
            The new token IDs, their decoded text, and the placeholders.
        """
        tokenizer = self._get_tokenizer()

        token_matches = find_token_matches(token_ids, prompt_repls)
        num_token_matches = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        # ----
        # Every modality that has replacement rules must be checked, not
        # only those with token matches. Otherwise a modality with *zero*
        # token matches (the case above) would be ignored and the text-based
        # fallback skipped.
        use_token_matches = all(
            num_token_matches.get(modality, 0) >= len(mm_items.get(modality, []))
            for modality in {repl.modality for repl in prompt_repls})

        if use_token_matches:
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_items,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_items,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids,
                                               mm_items)

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        tokenizer = self._get_tokenizer()

        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
                                             mm_processor_kwargs)
        # The HF output holds a batch containing exactly one prompt.
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        mm_items = to_multi_format(mm_data)
        prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
                                                     mm_processor_kwargs)
        all_prompt_repls = self._bind_prompt_replacements(prompt_repls)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids, mm_items)

        if all_placeholders:
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                mm_items,
                hf_inputs,
                prompt_ids,
                all_prompt_repls,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

    @abstractmethod
    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the multi-modal portion of the input which, after processing,
        results in `mm_max_tokens` in :meth:`get_dummy_data`.
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
        """
        Build :class:`DummyData` by applying this processor to the inputs
        returned by :meth:`_get_dummy_mm_inputs`.

        A warning is logged when the placeholder token count of a modality
        differs from ``mm_max_tokens``, or when the processed prompt is
        longer than ``seq_len``. The prompt token IDs are right-padded with
        ``0`` up to ``seq_len``. A longer prompt is not truncated.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        processor_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(*processor_inputs)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        total_placeholders_by_modality = dict[str, int]()
        for modality, placeholders in placeholders_by_modality.items():
            num_placeholders = sum(item["length"] for item in placeholders)
            max_tokens = mm_max_tokens[modality]

            if num_placeholders != max_tokens:
                logger.warning(
                    "The processed dummy data has a total of %d placeholder "
                    "tokens for the '%s' modality, which is not the expected "
                    "%d tokens.", num_placeholders, modality, max_tokens)

            total_placeholders_by_modality[modality] = num_placeholders

        total_len = len(prompt_token_ids)
        if total_len > seq_len:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, total_placeholders_by_modality)

        # Negative counts produce an empty list, so over-long prompts are
        # left unchanged.
        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )