inputs.py 23.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import ABC, abstractmethod
4
from collections import UserDict, defaultdict
5
6
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
7
from functools import partial
8
from itertools import accumulate
9
10
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                    Union, cast, final)
11
12
13
14
15

import numpy as np
import torch
import torch.types
from PIL.Image import Image
16
from transformers import BatchFeature
17
from typing_extensions import NotRequired, TypeAlias
18

19
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
20

21
22
23
if TYPE_CHECKING:
    from .hasher import MultiModalHashDict

# Type variable for a single data item; used to parametrize `ModalityData`.
_T = TypeVar("_T")

HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
"""

HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
                               list[np.ndarray], list[torch.Tensor]]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""

39
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
40
"""
41
42
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
43
44
"""

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

VideoItem: TypeAlias = Union[
    HfVideoItem,
    torch.Tensor,
]
"""
A single video item. This is either:

- a :class:`transformers.image_utils.VideoInput`, which can be passed to a
  HuggingFace :code:`VideoProcessor`; or
- a 3-D tensor or a batch of 2-D tensors, which are treated as video
  embeddings and passed directly to the model without HF processing.
"""

AudioItem: TypeAlias = Union[
    HfAudioItem,
    tuple[np.ndarray, float],
    torch.Tensor,
]
"""
A single audio item. This is one of:

- an audio item which can be passed to a HuggingFace :code:`AudioProcessor`;
- a tuple :code:`(audio, sampling_rate)`, where the sampling rate is
  different from that expected by the model; these are resampled to the
  model's sampling rate before being processed by HF; or
- a 3-D tensor or a batch of 2-D tensors, which are treated as audio
  embeddings and passed directly to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
"""


class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: :code:`AAAA BBBB What is in these images?`

        Images A and B will have:

        .. code-block::

            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""


134
135
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
                      tuple[torch.Tensor, ...]]
136
137
138
139
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between :data:`NestedTensors` objects.

    Tensors are compared with :func:`torch.equal`. A list or tuple is only
    equal to another sequence of the same kind (list vs. tuple) and the same
    length whose elements are pairwise equal. Anything else is compared with
    ``==``.
    """
    if isinstance(a, torch.Tensor):
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    if isinstance(b, torch.Tensor):
        # `a` is not a tensor at this point, so they cannot be equal
        return False

    if isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
        same_kind = ((isinstance(a, list) and isinstance(b, list))
                     or (isinstance(a, tuple) and isinstance(b, tuple)))
        # The length check is required because `zip` silently stops at the
        # shorter sequence, which would otherwise make `[x]` equal `[x, y]`.
        return (same_kind and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""


@dataclass(frozen=True)
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in :class:`MultiModalKwargs`.
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in :class:`MultiModalKwargs`,
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in :class:`MultiModalKwargs`,
    i.e. the value of the keyword argument to be passed to the model.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Compare two elements by modality, key, data and field type.

        The data is compared with :func:`nested_tensors_equal`. Only the
        *type* of :attr:`field` is compared, not its parameters.
        """
        if not isinstance(other, self.__class__):
            return False

        return ((self.modality, self.key) == (other.modality, other.key)
                and nested_tensors_equal(self.data, other.data)
                and type(self.field) is type(other.field))


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    :class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps a piece of data into a
        :class:`MultiModalFieldElem` bound to this field, ``modality`` and
        ``key``.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct :class:`MultiModalFieldElem` instances to represent
        the provided data.

        This is the inverse of :meth:`reduce_data`.
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Merge the raw data of multiple elements into a single value.

        Called by :meth:`reduce_data` after the elements have been checked
        to share the same field type.
        """
        raise NotImplementedError

    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
        """
        Merge the data from multiple instances of :class:`MultiModalFieldElem`.

        This is the inverse of :meth:`build_elems`.

        Raises:
            ValueError: If the elements belong to different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        return self._reduce_data([item.data for item in elems])


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A field whose data has one entry per multi-modal item, obtained by
    iterating over its first dimension.

    See also:
        :func:`MultiModalFieldConfig.batched`
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()
            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                return torch.stack(batch)

        # Not stackable (empty, non-tensor or mismatched shapes): keep a list
        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A field whose data is laid out along its first dimension, with each
    multi-modal item owning one slice of it.

    See also:
        :func:`MultiModalFieldConfig.flat`
        :func:`MultiModalFieldConfig.flat_from_sizes`
    """
    # For each multi-modal item, the slice of the data that belongs to it.
    slices: Sequence[slice]

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(data[s]) for s in self.slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()
            first_shape = batch[0].shape
            # Concatenation only requires the trailing dimensions to match
            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                return torch.concat(batch)

        # Otherwise flatten one level: a list of the sub-entries of each slice
        return [e for elem in batch for e in elem]


@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    A field whose data is shared by every multi-modal item in the batch.

    See also:
        :func:`MultiModalFieldConfig.shared`
    """
    # The number of multi-modal items which share the data.
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # Every item refers to the same (frozen) element
        return [field_factory(data)] * self.batch_size

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        # The data is assumed to be shared, so the first entry stands for
        # the whole batch.
        return batch[0]


class MultiModalFieldConfig:
    """
    Associates a :class:`BaseMultiModalField` with the modality of the
    multi-modal items that use the corresponding keyword argument.

    Instances are created via :meth:`batched`, :meth:`flat`,
    :meth:`flat_from_sizes` or :meth:`shared`.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

            .. code-block::

                Input:
                    Data: [[AAAA]
                        [BBBB]
                        [CCCC]]

                Output:
                    Element 1: [AAAA]
                    Element 2: [BBBB]
                    Element 3: [CCCC]
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str, slices: Sequence[slice]):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice that is used to extract
                the data corresponding to it.

        Example:

            .. code-block::

                Given:
                    slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

                Input:
                    Data: [AAABBBBCC]

                Output:
                    Element 1: [AAA]
                    Element 2: [BBBB]
                    Element 3: [CC]
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: A 1-D tensor holding, for each multi-modal item,
                the size of the slice that is used to extract the data
                corresponding to it.

        Raises:
            ValueError: If ``size_per_item`` is not a 1-D tensor.

        Example:

            .. code-block::

                Given:
                    size_per_item: [3, 4, 2]

                Input:
                    Data: [AAABBBBCC]

                Output:
                    Element 1: [AAA]
                    Element 2: [BBBB]
                    Element 3: [CC]

        See also:
            :func:`MultiModalFieldConfig.flat`
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Running offsets [0, s0, s0 + s1, ...]; each consecutive pair of
        # offsets delimits the slice belonging to one item.
        slice_idxs = [0, *accumulate(size_per_item)]
        slices = [
            slice(start, end)
            for start, end in zip(slice_idxs[:-1], slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(modality, slices)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

            .. code-block::

                Given:
                    batch_size: 4

                Input:
                    Data: [XYZ]

                Output:
                    Element 1: [XYZ]
                    Element 2: [XYZ]
                    Element 3: [XYZ]
                    Element 4: [XYZ]
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split ``batch`` into per-item elements using :attr:`field`."""
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of :class:`MultiModalFieldElem`
    corresponding to a data item in :class:`MultiModalDataItems`.
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item that maps each element's ``key`` to the element."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    @property
    def modality(self) -> str:
        """
        The modality shared by all elements of this item.

        The item must be non-empty and its elements must all have the same
        modality; this is checked with ``assert``.
        """
        modalities = {elem.modality for elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))


# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.

    The metadata :code:`items` enables us to obtain the keyword arguments
    corresponding to each data item in :class:`MultiModalDataItems`, via
    :meth:`get_item` and :meth:`get_items`.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: BatchFeature,
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Construct a new :class:`MultiModalKwargs` from HF processor outputs,
        splitting each configured field into per-item elements.

        Raises:
            ValueError: If the fields of one modality yield different
                numbers of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            # Regroup the per-key elements into one item per data item
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """
        Construct a new :class:`MultiModalKwargs` from multiple items.

        Elements sharing a key are merged by the ``reduce_data`` method of
        the field of the first such element.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(elems)
            for key, elems in elems_by_key.items() if len(elems) > 0
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        super().__init__(data)

        # Items grouped by modality, backing `get_item(s)`/`get_item_count`
        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        """The modalities of the items passed at construction time."""
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        Empty lists are returned as (empty) lists.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not stacked or not is_list_of(stacked, torch.Tensor, check="all"):
            # Only non-empty lists of tensors (not lists) can be stacked;
            # `torch.stack` raises on an empty sequence.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of ``batched_inputs`` to ``device``."""
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __eq__(self, other: object) -> bool:
        """
        Equal when both hold equal items per modality and the same keys
        with equal (per :func:`nested_tensors_equal`) values.
        """
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        """
        Raises:
            RuntimeError: If this instance was constructed without items.
            KeyError: If there are no items for ``modality``.
        """
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]

MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    :class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: NotRequired[Optional["MultiModalHashDict"]]
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """


class MultiModalEncDecInputs(MultiModalInputs):
    """
    The outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`,
    in a form that can be handed to vLLM internals.

    On top of the fields of :class:`MultiModalInputs`, this also carries
    the processed encoder prompt.
    """

    encoder_prompt: str
    """Text of the processed encoder prompt."""

    encoder_prompt_token_ids: list[int]
    """Token IDs of the processed encoder prompt."""

    encoder_token_type_ids: NotRequired[list[int]]
    """Token type IDs of the encoder prompt."""