processing.py 25.1 KB
Newer Older
1
2
3
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
4
from dataclasses import dataclass
5
from functools import lru_cache
6
7
from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union, cast)
8

9
10
import torch
from transformers import BatchFeature, ProcessorMixin
11
from typing_extensions import TypeAlias, TypedDict
12

13
from vllm.inputs import DummyData, InputProcessingContext
14
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
15
16
from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of,
                        resolve_mm_processor_kwargs)
17
18
19
20
21

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

def bind_prompt_sequence(
    seq: Union[str, list[int]],
    tokenizer: AnyTokenizer,
) -> "_BoundPromptSequence":
    """
    Attach :code:`tokenizer` to a prompt sequence given either as text or as
    a list of token IDs; the missing representation is derived lazily the
    first time it is requested.
    """
    if isinstance(seq, str):
        text, token_ids = seq, None
    elif isinstance(seq, list):
        text, token_ids = None, seq
    else:
        text, token_ids = None, None

    return _BoundPromptSequence(
        tokenizer=tokenizer,
        _text=text,
        _token_ids=token_ids,
    )


38
# Type of a single multi-modal item (e.g. an image or audio clip).
_T = TypeVar("_T")
# Representation of a prompt: either text or a list of token IDs.
_S = TypeVar("_S", str, list[int])
40

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

@dataclass
class PromptReplacement(Generic[_S, _T]):
    """
    Specifies a sequence to look for in the prompt and how to build the
    sequence that replaces it, as repetitions of :code:`repl_unit`.

    Both :code:`target` and :code:`repl_unit` may be given either as text or
    as token IDs.
    """

    target: _S
    """The text or token sequence to find and replace."""

    repl_unit: _S
    """
    The unit making up the replacement text or token sequence.

    See :code:`repl_count` for more details.
    """

    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
    """
    Given the original multi-modal items for this modality, HF-processed data,
    and index of the processed item, output the number of repetitions of
    :code:`repl_unit` to build up the replacement text or token sequence.

    For convenience, you can pass in an integer if the number of repetitions is
    a constant.
    """

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        shown = f"target={self.target!r}, repl_unit={self.repl_unit!r}"
        return f"{cls_name}({shown})"

    def bind(
        self,
        modality: str,
        tokenizer: AnyTokenizer,
    ) -> "_BoundPromptReplacement[_T]":
        """
        Tie :code:`target` and :code:`repl_unit` to :code:`tokenizer` and tag
        the result with :code:`modality`.
        """
        bound_target = bind_prompt_sequence(self.target, tokenizer)
        bound_unit = bind_prompt_sequence(self.repl_unit, tokenizer)

        return _BoundPromptReplacement(
            modality=modality,
            target=bound_target,
            repl_unit=bound_unit,
            repl_count=self.repl_count,
        )
79
80
81
82


@dataclass
class ModalityProcessingMetadata(Generic[_T]):
    """Processing metadata for a single modality."""

    prompt_repls: Sequence[Union[PromptReplacement[str, _T],
                                 PromptReplacement[list[int], _T]]]
    """
    Defines each text or token sequence to replace in the HF-processed prompt.

    This is skipped if the HF-processed prompt is found to already contain
    the replacement prompts.
    """


class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    # Every key is optional (``total=False``). Each maps a built-in modality
    # name to its processing metadata, parameterized by that modality's
    # item type.
    image: ModalityProcessingMetadata[ImageItem]
    video: ModalityProcessingMetadata[VideoItem]
    audio: ModalityProcessingMetadata[AudioItem]


MultiModalProcessingMetadata: TypeAlias = Mapping[
    str, ModalityProcessingMetadata[Any]]
"""
A mapping from each modality type to process to its
:class:`ModalityProcessingMetadata`.

Note:
    This mapping also accepts modality keys defined outside
    :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""


114
115
116
117
118
119
120
121
122
123
124
125
126
127
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    For :class:`MistralTokenizer`, the wrapped ``tokenizer.tokenizer.encode``
    takes separate ``bos``/``eos`` flags, and both are set from
    :code:`add_special_tokens`.
    """
    if isinstance(tokenizer, MistralTokenizer):
        return tokenizer.tokenizer.encode(text,
                                          bos=add_special_tokens,
                                          eos=add_special_tokens)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
130
131


132
133
134
135
136
137
138
139
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized :func:`_encode`, keyed on ``(tokenizer, text,
    add_special_tokens)`` and holding at most 2048 entries.

    NOTE(review): the cached list object is handed to every caller that hits
    the cache, so callers must not mutate the result.
    """
    token_ids = _encode(
        tokenizer,
        text,
        add_special_tokens=add_special_tokens,
    )
    return token_ids
140
141


142
143
144
145
146
147
148
149
150
151
152
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Turn :code:`token_ids` back into text. This mirrors HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)` for any
    supported tokenizer backend.
    """
    decoded = tokenizer.decode(token_ids,
                               skip_special_tokens=skip_special_tokens)
    return decoded
153
154


155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized :func:`_decode` holding at most 2048 entries.

    :code:`token_ids` is taken as a tuple so that it is hashable for
    :func:`functools.lru_cache`. It is turned back into a list before
    decoding.
    """
    ids_as_list = [*token_ids]
    return _decode(tokenizer,
                   ids_as_list,
                   skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    """Structural type for objects exposing ``modality`` as an attribute."""

    modality: str


class _HasModalityProp(Protocol):
    """Structural type for objects exposing ``modality`` as a property."""

    @property
    def modality(self) -> str:
        ...


# Anything with a string ``modality``, whether an attribute or a property;
# this is the element type accepted by :func:`full_groupby_modality`.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group :code:`values` by their ``modality`` via :func:`full_groupby`."""

    def _get_modality(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_get_modality)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence tied to a tokenizer.

    At least one of the text and token-ID forms must be provided. The other
    form is computed through the cached encode/decode helpers the first time
    it is requested, and is then stored on the instance.
    """

    tokenizer: AnyTokenizer
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        both_missing = self._text is None and self._token_ids is None
        if both_missing:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The text form, decoded from the token IDs on first access."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The token-ID form, encoded from the text on first access."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer, self._text)
        return self._token_ids

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(_text={self._text!r}, _token_ids={self._token_ids!r})"


@dataclass
class _BoundPromptReplacement(Generic[_T]):
    """
    A prompt replacement whose target and replacement unit are bound to a
    tokenizer, tagged with the modality it applies to.
    """

    modality: str
    target: _BoundPromptSequence
    repl_unit: _BoundPromptSequence
    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]

    def get_count(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> int:
        """
        Return how many times :code:`repl_unit` is repeated for the item at
        :code:`item_idx`. This is :code:`repl_count` itself when it is an
        ``int``; otherwise it is the result of calling it.
        """
        count_or_fn = self.repl_count
        if not isinstance(count_or_fn, int):
            return count_or_fn(mm_items, hf_inputs, item_idx)

        return count_or_fn


def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
    """
    Convert a :class:`MultiModalDataDict` containing single data items
    to a dictionary holding a list of data items per entry.

    A value that is already a list is kept as-is; any other value is wrapped
    in a single-element list. ``"video"`` is special-cased because even a
    single video can be a list. Its value is only kept as-is when
    :func:`is_list_of` accepts it as a list of lists.
    """
    multi_data = dict[str, list[Any]]()

    for k, v in data.items():
        if k == "video":
            # Special case since even a single item can be a list
            is_multi = is_list_of(v, list)
        else:
            is_multi = isinstance(v, list)

        multi_data[k] = v if is_multi else [v]  # type: ignore

    return multi_data


260
261
262
# Half-open span ``[start_idx, end_idx)`` of one match inside a token list.
_TokenMatch = NamedTuple("_TokenMatch", [("start_idx", int), ("end_idx", int)])
263
264


265
266
267
268
269
270
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each non-overlapping occurrence of :code:`match_ids` in
    :code:`token_ids`, scanning from left to right.

    An empty :code:`match_ids` yields nothing, since it would otherwise
    "match" (with zero length) at every position.
    """
    match_len = len(match_ids)
    if match_len == 0:
        return

    last_start_idx = len(token_ids) - match_len
    start_idx = 0
    while start_idx <= last_start_idx:
        end_idx = start_idx + match_len
        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
            # Resume after this match to exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
281
282


283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
class _PromptReplacementMatch(ABC, Generic[_T, _S]):
    """
    One occurrence of a prompt replacement's target within a prompt.

    ``_S`` is the prompt representation the match was found in: ``str`` for
    text matches or ``list[int]`` for token matches.
    """

    prompt_repl: _BoundPromptReplacement[_T]

    @property
    def modality(self) -> str:
        """The modality of the matched prompt replacement."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt of the first element of the match."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt one past the last element of the match."""
        raise NotImplementedError

    @property
    @abstractmethod
    def repl_unit(self) -> _S:
        """The replacement unit, in the same representation as the prompt."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
    """
    A match of a prompt replacement's target in a list of token IDs.

    ``repr=False`` keeps the compact ``__repr__`` inherited from
    :class:`_PromptReplacementMatch`.
    """

    prompt_repl: _BoundPromptReplacement[_T]
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx

    @property
    def repl_unit(self) -> list[int]:
        return self.prompt_repl.repl_unit.token_ids
326

327
328
329
330
331
332
333
334
335
336
337
338
339
340

@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
    """
    A match of a prompt replacement's target in prompt text, backed by a
    :class:`re.Match`.

    ``repr=False`` keeps the compact ``__repr__`` inherited from
    :class:`_PromptReplacementMatch`.
    """

    prompt_repl: _BoundPromptReplacement[_T]
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()

    @property
    def repl_unit(self) -> str:
        return self.prompt_repl.repl_unit.text


class _PlaceholderInfo(NamedTuple):
    """
    A run of :code:`unit_count` back-to-back copies of :code:`unit` that
    starts at :code:`start_idx` in the token IDs.
    """

    modality: str
    start_idx: int
    unit: list[int]
    unit_count: int

    @property
    def length(self) -> int:
        """Total number of tokens covered by this run."""
        return self.unit_count * len(self.unit)

    def to_range(self) -> PlaceholderRange:
        """Express this run as a :class:`PlaceholderRange`."""
        return PlaceholderRange(offset=self.start_idx, length=self.length)
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTokenMatch[_T]]:
    """
    Collect every occurrence of each replacement's target token IDs in
    :code:`prompt`. Results are grouped per replacement, in the order of
    :code:`prompt_repls`.
    """
    found: list[_PromptReplacementTokenMatch[_T]] = []
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        found.extend(
            _PromptReplacementTokenMatch(prompt_repl, token_match)
            for token_match in iter_token_matches(prompt, target_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTextMatch[_T]]:
    """
    Collect every literal (regex-escaped) occurrence of each replacement's
    target text in :code:`prompt`. Results are grouped per replacement, in
    the order of :code:`prompt_repls`.
    """
    found: list[_PromptReplacementTextMatch[_T]] = []
    for prompt_repl in prompt_repls:
        pattern = re.escape(prompt_repl.target.text)
        for text_match in re.finditer(pattern, prompt):
            found.append(_PromptReplacementTextMatch(prompt_repl, text_match))

    return found


def _resolve_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch[_T, _S]],
) -> list[_PromptReplacementMatch[_T, _S]]:
    """
    Resolve :code:`matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same index of :code:`prompt`.
    """
    # Records which match (if any) has claimed each index of the prompt
    seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \
        = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch[_T, _S]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> list[_S]:
    """
    Split :code:`prompt` into segments in which each match is replaced.

    Matches are applied in prompt order. The n-th match of a modality is
    expanded using the n-th item of that modality, into
    :code:`repl_unit * repl_count`. Matches beyond the number of available
    items are left unchanged and stay part of the following segment.

    Returns:
        The segments of the new prompt, for the caller to concatenate.

    Raises:
        ValueError: If :code:`matches` overlap (see :func:`_resolve_matches`).
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}

    for match in _resolve_matches(prompt, matches):
        modality = match.modality
        mm_items = mm_items_by_modality[modality]

        item_idx = next_idx_by_modality[modality]
        if item_idx >= len(mm_items):
            # No item left for this match; keep the original content
            continue

        start_idx = match.start_idx
        end_idx = match.end_idx

        repl_unit = match.repl_unit
        repl_info = match.prompt_repl
        repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx)

        out_seqs.append(prompt[prev_end_idx:start_idx] +
                        repl_unit * repl_count)

        prev_end_idx = end_idx
        next_idx_by_modality[modality] += 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementMatch[_T, list[int]]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> list[int]:
    """
    Apply the token-level :code:`matches` to :code:`prompt` and return the
    new token IDs. :code:`prompt` itself is returned when there are no
    matches.
    """
    if matches:
        segments = _replace_matches(prompt, matches, mm_items_by_modality,
                                    hf_inputs)
        return flatten_2d_lists(segments)

    return prompt
462
463


464
465
466
467
468
469
470
471
472
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementMatch[_T, str]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> str:
    """
    Apply the text-level :code:`matches` to :code:`prompt` and return the
    new text. :code:`prompt` itself is returned when there are no matches.
    """
    if not matches:
        return prompt

    texts = _replace_matches(
        prompt,
        matches,
        mm_items_by_modality,
        hf_inputs,
    )

    return "".join(texts)
482
483


484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
def _merge_placeholder_matches(
    matches: Iterable[_PromptReplacementTokenMatch],
) -> Iterable[_PromptReplacementTokenMatch]:
    """
    Coalesce back-to-back matches of the same prompt replacement.

    Matches are visited in order of :code:`start_idx`. A match that starts
    exactly where the pending (possibly already merged) match ends, and whose
    prompt replacement compares equal, is folded into it. Each resulting
    match is yielded in order.
    """
    pending: Optional[_PromptReplacementTokenMatch] = None

    for match in sorted(matches, key=lambda m: m.start_idx):
        if pending is not None:
            contiguous = pending.end_idx == match.start_idx
            if contiguous and pending.prompt_repl == match.prompt_repl:
                pending = _PromptReplacementTokenMatch(
                    pending.prompt_repl,
                    match=_TokenMatch(pending.start_idx, match.end_idx),
                )
                continue

            yield pending

        pending = match

    if pending is not None:
        yield pending


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    prompt: list[int],
    *,
    min_unit_count: int = 1,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each run of placeholder tokens found in :code:`prompt`.

    A run consists of back-to-back occurrences of one prompt replacement's
    (non-empty) :code:`repl_unit` token IDs. Runs with fewer than
    :code:`min_unit_count` units are skipped.

    Raises:
        ValueError: If :code:`min_unit_count` is not positive (raised when
            iteration starts, since this is a generator).
    """
    if min_unit_count <= 0:
        raise ValueError("`min_unit_count` must be a positive integer")

    def _unit_matches() -> Iterable[_PromptReplacementTokenMatch]:
        for prompt_repl in prompt_repls:
            unit_ids = prompt_repl.repl_unit.token_ids
            if not unit_ids:
                continue  # skip empty units, which cannot form a run
            for token_match in iter_token_matches(prompt, unit_ids):
                yield _PromptReplacementTokenMatch(prompt_repl, token_match)

    for merged in _merge_placeholder_matches(_unit_matches()):
        unit = merged.repl_unit
        span = merged.end_idx - merged.start_idx
        placeholder = _PlaceholderInfo(
            modality=merged.modality,
            start_idx=merged.start_idx,
            unit=unit,
            unit_count=span // len(unit),
        )

        if placeholder.unit_count >= min_unit_count:
            yield placeholder


534
class BaseMultiModalProcessor(ABC):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    The HF processor and tokenizer are obtained from the
    :class:`InputProcessingContext`. Subclasses must implement
    :meth:`_get_dummy_mm_kwargs`.
    """

    def __init__(
        self,
        ctx: InputProcessingContext,
        metadata: MultiModalProcessingMetadata,
    ) -> None:
        super().__init__()

        self.ctx = ctx
        self.metadata = metadata
        # Model-level kwargs; per-request kwargs are merged on top of these
        # (and take precedence) in `_apply_hf_processor`.
        self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs
                                         or {})

    def _get_hf_processor(
        self,
        **mm_processor_kwargs: object,
    ) -> ProcessorMixin:
        # by default, we won't pass any kwargs to the processor initialization
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Alias of :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
        new_token_ids: list[int],
        *,
        # To avoid false positives from multi-input when detecting
        # whether placeholder tokens have been inserted, in case
        # the target sequence is a subset of the replacement tokens
        min_unit_count: int = 16,
    ) -> list[_PlaceholderInfo]:
        """Return every placeholder run in :code:`new_token_ids`."""
        return list(
            iter_placeholders(
                all_prompt_repls,
                new_token_ids,
                min_unit_count=min_unit_count,
            ))

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on :code:`prompt` together with the items in
        :code:`mm_data`.

        Image/video/audio values that look like precomputed embeddings are
        not sent to the HF processor. These are a 3-D tensor, or a non-empty
        list whose first tensor is 2-D. They are instead added to the output
        under ``"<modality>_embeds"``.

        Raises:
            RuntimeError: If the HF processor call fails.
        """
        # some mm_processor_kwargs may be used in processor initialization
        # instead of processor call
        processor_init_kwargs = {
            **self.init_mm_processor_kwargs,
            **mm_processor_kwargs,
        }
        hf_processor = self._get_hf_processor(**processor_init_kwargs)

        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()
        for k, v in mm_data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi).
                    # Check for an empty list before indexing `v[0]`.
                    passthrough_data[f"{k}_embeds"] = v
                else:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

        # filter mm_processor_kwargs used in processor call
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            self.init_mm_processor_kwargs,
            cast(Dict[str, Any], mm_processor_kwargs),
            hf_processor,
        )

        try:
            hf_inputs = hf_processor(
                text=prompt,  # type: ignore
                **processor_data,
                **mm_processor_kwargs,
                return_tensors="pt",
            )
        except Exception as exc:
            data = dict(text=prompt, **processor_data)

            raise RuntimeError(
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={mm_processor_kwargs}") from exc

        hf_inputs.update(passthrough_data)

        return hf_inputs

    def _bind_prompt_replacements(
        self,
        mm_data: MultiModalDataDict,
    ) -> list[_BoundPromptReplacement[Any]]:
        """
        Bind the prompt replacements of every modality in :attr:`metadata`
        that is also present in :code:`mm_data` to the tokenizer.
        """
        tokenizer = self._get_tokenizer()

        return [
            prompt_repl.bind(modality, tokenizer)
            for modality, metadata in self.metadata.items()
            if modality in mm_data for prompt_repl in metadata.prompt_repls
        ]

    def _apply_prompt_replacements(
        self,
        mm_data: MultiModalDataDict,
        hf_inputs: BatchFeature,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Insert placeholder tokens into :code:`token_ids` and locate them.

        Returns:
            The new token IDs, the corresponding text, and the placeholders
            found for the prompt replacements that were matched.
        """
        tokenizer = self._get_tokenizer()

        mm_items = to_multi_format(mm_data)
        token_matches = find_token_matches(token_ids, prompt_repls)
        match_counts = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        # ----
        # Every modality that has prompt replacements must be covered,
        # including modalities with no token match at all; otherwise `all`
        # would be vacuously true and the text fallback would be skipped.
        repl_modalities = {repl.modality for repl in prompt_repls}
        if all(
            match_counts.get(modality, 0) >= len(mm_items.get(modality, ()))
            for modality in repl_modalities
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_items,
                hf_inputs,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_items,
                hf_inputs,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids)

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        tokenizer = self._get_tokenizer()

        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
                                             mm_processor_kwargs)
        # The unpacking requires exactly one row of input IDs
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        all_prompt_repls = self._bind_prompt_replacements(mm_data)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids)
        if all_placeholders:
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                mm_data,
                hf_inputs,
                prompt_ids,
                all_prompt_repls,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

    @abstractmethod
    def _get_dummy_mm_kwargs(
        self,
        mm_counts: Mapping[str, int],
    ) -> MultiModalKwargs:
        """
        Build the input that corresponds to `mm_max_tokens` in
        :meth:`get_dummy_data`.
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
        """
        Build dummy inputs of length :code:`seq_len`.

        For each modality with a nonzero entry in :code:`mm_max_tokens`, the
        prompt contains a run of that modality's first prompt replacement's
        :code:`repl_unit`. The run uses as many whole units as fit in the
        token budget. The remainder of the prompt is padded with token ID 0.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        tokenizer = self._get_tokenizer()

        mm_placeholders = dict[str, _PlaceholderInfo]()
        offset = 0

        for modality, max_tokens in mm_max_tokens.items():
            if max_tokens == 0:
                continue

            metadata = self.metadata[modality]
            repl = metadata.prompt_repls[0].bind(modality, tokenizer)
            repl_token_ids = repl.repl_unit.token_ids

            placeholders = _PlaceholderInfo(
                modality=modality,
                start_idx=offset,
                unit=repl_token_ids,
                unit_count=max_tokens // len(repl_token_ids),
            )

            mm_placeholders[modality] = placeholders
            offset += placeholders.length

        prompt_token_ids = flatten_2d_lists(
            [p.unit * p.unit_count for p in mm_placeholders.values()])
        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
            multi_modal_placeholders={
                modality: [p.to_range()]
                for modality, p in mm_placeholders.items()
            },
        )