# processing.py
import re
from abc import ABC, abstractmethod
3
from collections import UserDict
4
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
8

9
import numpy as np
10
import torch
11
from PIL.Image import Image
12
from transformers import BatchFeature, ProcessorMixin
13
from typing_extensions import assert_never
14

15
from vllm.inputs import DummyData, InputProcessingContext
16
from vllm.logger import init_logger
17
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
18
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
19

20
from .audio import resample_audio
21
22
23
24
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

logger = init_logger(__name__)

# A prompt is represented either as raw text or as a list of token IDs.
# `_S` lets generic helpers return the same representation they were given;
# `_PromptSeq` is the plain (non-generic) union of the two.
_S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]

@dataclass
class PromptReplacement:
    """
    Describes a sequence to find in the prompt and what to replace it with
    for each item of a given modality.
    """

    modality: str
    """The modality for which the replacement is made."""

    target: _PromptSeq
    """The text or token sequence to find and replace."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Given the index of the processed item within :attr:`modality`, output the
    replacement text or token sequence.

    For convenience, you can pass in the replacement instead of a function
    if it does not depend on the input.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
        """
        Attach ``tokenizer`` so that the target and replacements can be
        converted between text and token IDs on demand.
        """
        return _BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=self.target,
            _replacement=self.replacement,
        )


def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    For :class:`MistralTokenizer`, the flag is forwarded to the wrapped
    tokenizer as both its ``bos`` and ``eos`` arguments.
    """
    if isinstance(tokenizer, MistralTokenizer):
        return tokenizer.tokenizer.encode(text,
                                          bos=add_special_tokens,
                                          eos=add_special_tokens)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)


@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized version of :func:`_encode`.

    The same list object is handed out for repeated calls with equal
    arguments, so callers must not mutate the result.
    """
    encoded = _encode(tokenizer, text, add_special_tokens=add_special_tokens)
    return encoded


def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Turn ``token_ids`` back into text.

    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`; the call is
    forwarded unchanged to ``tokenizer.decode``.
    """
    decoded = tokenizer.decode(token_ids,
                               skip_special_tokens=skip_special_tokens)
    return decoded


@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized version of :func:`_decode`.

    ``token_ids`` is taken as a tuple because :func:`functools.lru_cache`
    requires hashable arguments; it is turned back into a list for decoding.
    """
    ids_list = list(token_ids)
    return _decode(tokenizer, ids_list, skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    modality: str

114

115
class _HasModalityProp(Protocol):
116

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
    @property
    def modality(self) -> str:
        ...


# Bound that accepts any object with a ``modality`` string, whether it is a
# plain attribute (e.g. a dataclass/NamedTuple field) or a read-only property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group ``values`` by their ``modality`` via :func:`full_groupby`, yielding
    ``(modality, items)`` pairs.
    """
    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence bound to a tokenizer, accessible both as text and as
    token IDs.

    At least one representation must be provided. The other is computed on
    first access and then stored on this instance.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        if self._text is None and self._token_ids is None:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The sequence as text, decoding :attr:`token_ids` if needed."""
        if self._text is None:
            assert self._token_ids is not None
            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))

        return self._text

    @property
    def token_ids(self) -> list[int]:
        """The sequence as token IDs, encoding :attr:`text` if needed."""
        if self._token_ids is None:
            assert self._text is not None
            # `_cached_encode` returns the same list object for every call
            # with equal arguments. Copy it so that mutating this sequence
            # can never corrupt the shared cache entry.
            self._token_ids = list(_cached_encode(self.tokenizer, self._text))

        return self._token_ids


@dataclass
class _BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer, so that its target and
    replacements are accessible both as text and as token IDs.
    """

    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Bound replacements keyed by item index. A replacement that does not
        # depend on the item index is stored once under the key `None`.
        self._replacement_cache = dict[Optional[int], _BoundPromptSequence]()
        # Bound target, built on first access of `target`.
        self._bound_target: Optional[_BoundPromptSequence] = None

    @property
    def target(self) -> _BoundPromptSequence:
        """
        The target sequence to search for.

        It is bound only once, so its text and token IDs are converted at
        most once per instance.
        """
        if self._bound_target is None:
            target = self._target
            self._bound_target = _BoundPromptSequence(
                tokenizer=self.tokenizer,
                _text=target if isinstance(target, str) else None,
                _token_ids=target if isinstance(target, list) else None,
            )

        return self._bound_target

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Return the bound replacement for the ``item_idx``-th item of
        :attr:`modality`.

        Results are memoized. A callable replacement is invoked at most once
        per item index, and a constant replacement is bound only once.
        """
        replacement = self._replacement
        cache_key = item_idx if callable(replacement) else None

        cached = self._replacement_cache.get(cache_key)
        if cached is not None:
            return cached

        if callable(replacement):
            replacement = replacement(item_idx)

        bound_replacement = _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=replacement if isinstance(replacement, str) else None,
            _token_ids=replacement if isinstance(replacement, list) else None,
        )
        self._replacement_cache[cache_key] = bound_replacement

        return bound_replacement


class ImageSize(NamedTuple):
    """Width and height of an image, in that order."""

    width: int
    height: int


class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    As :class:`MultiModalDataDict`, but normalized such that each entry
    corresponds to a list.
    """

    @staticmethod
    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
        """
        multi_data = MultiModalDataItems()

        for k, v in data.items():
            # NOTE(review): for image/audio/video, a `torch.Tensor` value is
            # stored as-is rather than wrapped in a list; presumably it
            # already stacks several items (or holds embeddings) — confirm.
            # yapf: disable
            if k == "video":
                # Special case since even a single item can be a list
                multi_data[k] = (  # type: ignore[index]
                    v if is_list_of(v, (list, torch.Tensor)) else [v]
                )
            elif k in ("image", "audio"):
                multi_data[k] = (  # type: ignore[index]
                    v if isinstance(v, (list, torch.Tensor)) else [v]
                )
            else:
                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
            # yapf: enable

        return multi_data

    # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
    # `self.images` doesn't update this dictionary, which may be confusing
    # We annotate the getter methods as `Sequence` to prevent others from
    # trying to update the list in this way
    @property
    def images(self) -> Sequence[ImageItem]:
        return self.get("image", [])

    @property
    def videos(self) -> Sequence[VideoItem]:
        return self.get("video", [])

    @property
    def audios(self) -> Sequence[AudioItem]:
        return self.get("audio", [])

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the size of the ``item_idx``-th image."""
        image = self.images[item_idx]

        if isinstance(image, Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # NOTE(review): assumes channel-first (C, H, W) arrays/tensors;
            # a channel-last (H, W, C) NumPy image would be misread — confirm
            # against callers.
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)

    def get_audio_with_sr(
        self,
        item_idx: int,
        *,
        default_sr: float,
    ) -> tuple[np.ndarray, float]:
        """
        Return the ``item_idx``-th audio item as an ``(audio, sampling_rate)``
        pair.

        Items stored without a sampling rate are assigned ``default_sr``.
        """
        audio = self.audios[item_idx]

        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), default_sr
        if isinstance(audio, np.ndarray):
            return audio, default_sr

        assert_never(audio)

    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
        """
        Resample every audio item in place to ``new_sr``.

        If :code:`drop_sr=True`, the audio items in this dictionary are updated
        to be NumPy arrays which implicitly means that their sampling rate is
        the same as the model's expected sampling rate; otherwise, they remain
        as :code:`(audio, new_sr)` tuples.
        """
        if not self.audios:
            return

        new_audios = list[Any]()
        for item_idx in range(len(self.audios)):
            audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
            audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)

            new_audios.append(audio if drop_sr else (audio, new_sr))

        self["audio"] = new_audios


class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int
306
307


308
309
310
311
def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    Matches are found left to right and never overlap: after a match, the
    search resumes at the end of that match.

    Note that empty matches are ignored.
    """
    prompt_len = len(token_ids)
    match_len = len(match_ids)

    if match_len == 0:
        return

    start_idx = 0
    while start_idx <= prompt_len - match_len:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1


@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    A location in a prompt where the target of :attr:`prompt_repl` was found.

    Subclasses define how the half-open span ``[start_idx, end_idx)`` is
    read from the underlying match object.
    """

    prompt_repl: _BoundPromptReplacement

    @property
    def modality(self) -> str:
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index at which the match starts (inclusive)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index at which the match ends (exclusive)."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match of a replacement target within a list of token IDs."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        return self.match.start_idx

    @property
    def end_idx(self) -> int:
        return self.match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match of a replacement target within prompt text."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        return self.match.start()

    @property
    def end_idx(self) -> int:
        return self.match.end()



class _PlaceholderInfo(NamedTuple):
    """Location of the placeholder tokens found for one multi-modal item."""

    modality: str
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Convert to a :class:`PlaceholderRange` (offset and length)."""
        return PlaceholderRange(
            offset=self.start_idx,
            length=self.length,
        )




def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches are grouped by replacement, in the order of
    :code:`prompt_repls`; empty targets produce no matches.
    """
    return [
        _PromptReplacementTokenMatch(prompt_repl, match)
        for prompt_repl in prompt_repls
        for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
    ]


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Targets are matched literally, not as regular expressions. As in
    :func:`iter_token_matches`, empty targets are ignored; otherwise they
    would produce a zero-width match at every position of the prompt.
    """
    text_matches = list[_PromptReplacementTextMatch]()

    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        if not target_text:
            continue

        pattern = re.escape(target_text)
        text_matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return text_matches


def _resolve_matches(
    prompt: _PromptSeq,
    matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches cover the same position of
            :code:`prompt`.
    """
    # The match (if any) covering each position of the prompt
    seen_matches: list[Optional[_PromptReplacementMatch]]
    seen_matches = [None] * len(prompt)

    for match in matches:
        for idx in range(match.start_idx, match.end_idx):
            if seen_matches[idx] is not None:
                raise ValueError("Found overlapping matches "
                                 f"({seen_matches[idx]} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            seen_matches[idx] = match

    return sorted(matches, key=lambda x: x.start_idx)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Split :code:`prompt` around :code:`matches`, substituting each match with
    the replacement for the next item of its modality.

    Matches beyond the number of items of their modality are left untouched.
    This includes matches for modalities absent from :code:`mm_item_counts`.
    Concatenating the returned pieces in order yields the new prompt.
    """
    out_seqs = list[_S]()
    prev_end_idx = 0
    next_idx_by_modality = dict[str, int]()

    for match in _resolve_matches(prompt, matches):
        modality = match.modality

        item_idx = next_idx_by_modality.get(modality, 0)
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        replacement = match.prompt_repl.get_replacement(item_idx)

        if isinstance(prompt, str):
            repl_text = replacement.text
            out_seqs.append(prompt[prev_end_idx:match.start_idx] + repl_text)
        else:
            repl_ids = replacement.token_ids
            out_seqs.append(prompt[prev_end_idx:match.start_idx] + repl_ids)

        prev_end_idx = match.end_idx
        next_idx_by_modality[modality] = item_idx + 1

    out_seqs.append(prompt[prev_end_idx:])

    return out_seqs


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements of :code:`matches` to the token IDs in
    :code:`prompt`.

    If there are no matches, :code:`prompt` itself is returned (not a copy).
    """
    if not matches:
        return prompt

    token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)


def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements of :code:`matches` to the text in :code:`prompt`.

    If there are no matches, :code:`prompt` is returned unchanged.
    """
    if not matches:
        return prompt

    texts = _replace_matches(prompt, matches, mm_item_counts)

    return "".join(texts)


def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Yield the placeholder token runs of :code:`modality` in :code:`prompt`,
    stopping after :code:`modal_item_count` items.

    At each position, each replacement in :code:`modality_repls` is tried in
    order for the next item, and the first match is taken. Scanning then
    resumes right after that match, so placeholders never overlap. Empty
    replacements are ignored.
    """
    if modal_item_count == 0:
        return

    prompt_len = len(prompt)
    item_index = 0

    start_idx = 0
    while start_idx < prompt_len:
        found = False

        for repl_info in modality_repls:
            replacement = repl_info.get_replacement(item_index)
            repl_tokens = replacement.token_ids
            repl_len = len(repl_tokens)
            end_idx = start_idx + repl_len

            if repl_len == 0 or end_idx > prompt_len:
                continue

            if prompt[start_idx:end_idx] == repl_tokens:
                yield _PlaceholderInfo(
                    modality=modality,
                    start_idx=start_idx,
                    replacement=repl_tokens,
                )

                item_index += 1
                if item_index >= modal_item_count:
                    return

                # Exclude overlapping matches
                start_idx = end_idx
                found = True
                break

        if not found:
            start_idx += 1


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.

    Modalities are processed in the order of :code:`mm_item_counts`.
    Modalities without any replacement in :code:`prompt_repls` are skipped.

    Note that empty matches are ignored.
    """
    repls_by_modality = dict(full_groupby_modality(prompt_repls))

    for modality, modal_item_count in mm_item_counts.items():
        if modality in repls_by_modality:
            yield from _iter_modality_placeholders(
                prompt,
                modality,
                repls_by_modality[modality],
                modal_item_count,
            )


class ProcessorInputs(NamedTuple):
    """
    Arguments to :meth:`BaseMultiModalProcessor.apply`, in the same order as
    its parameters.
    """
    prompt_text: str
    mm_data: MultiModalDataDict
    mm_processor_kwargs: Mapping[str, object]


class BaseMultiModalProcessor(ABC):
    """
    Abstract base class to process multi-modal inputs to be used in vLLM.

    Subclasses must implement :meth:`_get_prompt_replacements` and
    :meth:`_get_dummy_mm_inputs`.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def _get_hf_processor(self) -> ProcessorMixin:
        """
        Subclasses can add keyword arguments to this method to accept
        additional kwargs from model config or user inputs.
        """
        # NOTE(review): `_apply_hf_processor` calls this method with
        # `**mm_processor_kwargs`, which this base signature does not accept;
        # presumably those kwargs are filtered upstream — confirm.
        return self.ctx.get_hf_processor()

    def _get_tokenizer(self) -> AnyTokenizer:
        return self.ctx.tokenizer

    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Note:
            Even when the HF processor already performs replacement for us,
            we still use this replacement information to determine
            the placeholder token positions for each multi-modal item.
        """
        raise NotImplementedError

    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> list[_PlaceholderInfo]:
        """Locate the placeholder tokens of each item in `new_token_ids`."""
        return list(
            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))

    def _get_processor_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Split `mm_items` into data for the HF processor and embedding inputs
        that bypass it.

        Returns:
            A `(processor_data, passthrough_data)` tuple.
        """
        processor_data = dict[str, Any]()
        passthrough_data = dict[str, Any]()

        for k, v in mm_items.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if k in ("image", "video", "audio"):
                if isinstance(v, torch.Tensor) and v.ndim == 3:
                    # Pass through embedding inputs (single)
                    passthrough_data[f"{k}_embeds"] = [v]
                elif (is_list_of(v, torch.Tensor) and len(v) > 0
                      and v[0].ndim == 2):
                    # Pass through embedding inputs (multi)
                    passthrough_data[f"{k}_embeds"] = v
                else:
                    # Map keys to plural form, e.g.: image -> images
                    processor_data[f"{k}s"] = v
            else:
                processor_data[k] = v

        return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        prompt: str,
        processor_data: Mapping[str, object],
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        return self.ctx.call_hf_processor(
            hf_processor,
            prompt,
            processor_data,
            mm_processor_kwargs,
        )

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_items: MultiModalDataItems,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on `prompt` and `mm_items`.

        Embedding inputs are merged into the result without being processed.
        """
        # some mm_processor_kwargs may be used in processor initialization
        # instead of processor call
        hf_processor = self._get_hf_processor(**mm_processor_kwargs)

        processor_data, passthrough_data = self._get_processor_data(mm_items)

        hf_inputs = self._call_hf_processor(
            hf_processor,
            prompt=prompt,
            processor_data=processor_data,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        hf_inputs.update(passthrough_data)

        return hf_inputs

    def _bind_prompt_replacements(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> list[_BoundPromptReplacement]:
        tokenizer = self._get_tokenizer()

        return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Insert the replacements for each multi-modal item into `token_ids`.

        Returns:
            The new token IDs, their decoded text, and the placeholders
            located in the new token IDs.
        """
        tokenizer = self._get_tokenizer()

        token_matches = find_token_matches(token_ids, prompt_repls)
        mm_match_counts = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        # ----
        # Iterate over the modalities that have items (not over the matches
        # found): a modality with no token matches at all must also trigger
        # the text-based fallback.
        if all(
            mm_match_counts.get(modality, 0) >= item_count
            for modality, item_count in mm_item_counts.items()
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_item_counts,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_item_counts,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        # Each replacement appears once per match; keep only its first
        # occurrence (by identity) so the placeholder search does not retry
        # the same replacement over and over.
        unique_repls = list({id(repl): repl
                             for repl in matched_repls}.values())

        placeholders = self._find_placeholders(unique_repls, token_ids,
                                               mm_item_counts)

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        mm_items = MultiModalDataItems.from_dict(mm_data)

        hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
                                             mm_processor_kwargs)
        # The HF processor output holds a batch of exactly one prompt
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
                                                     mm_processor_kwargs)
        all_prompt_repls = self._bind_prompt_replacements(prompt_repls)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        # NOTE(review): this decision is all-or-nothing; if placeholders are
        # found for one modality, none are inserted for the others.
        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids, mm_item_counts)

        if all_placeholders:
            tokenizer = self._get_tokenizer()
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                all_prompt_repls,
                mm_item_counts,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )

    @abstractmethod
    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the multi-modal portion of the input which, after processing,
        results in `mm_max_tokens` in :meth:`get_dummy_data`.
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
        """
        Build dummy data from :meth:`_get_dummy_mm_inputs`, with the prompt
        padded with token ID 0 up to `seq_len` tokens.

        A warning is logged when a modality's placeholder count differs from
        `mm_max_tokens`, or when the processed prompt exceeds `seq_len`.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        processor_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(*processor_inputs)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        total_placeholders_by_modality = dict[str, int]()
        for modality, placeholders in placeholders_by_modality.items():
            num_placeholders = sum(item["length"] for item in placeholders)
            max_tokens = mm_max_tokens[modality]

            if num_placeholders != max_tokens:
                logger.warning(
                    "The processed dummy data has a total of %d placeholder "
                    "tokens for the '%s' modality, which is not the expected "
                    "%d tokens.", num_placeholders, modality, max_tokens)

            total_placeholders_by_modality[modality] = num_placeholders

        total_len = len(prompt_token_ids)
        if total_len > seq_len:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, total_placeholders_by_modality)

        # Pad with token ID 0 (no-op when the prompt is already too long)
        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

        return DummyData(
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=placeholders_by_modality,
        )