base.py 6.84 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import Sequence
6
from pathlib import Path
7
from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
8

9
10
11
if TYPE_CHECKING:
    from vllm.sequence import SequenceGroupMetadata

12
from .inputs import MultiModalKwargs, PlaceholderRange
13

14
_T = TypeVar("_T")
15
16
17
18
19


class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.

    Note: This is only used in V0.
    """

    class IndexMap(NamedTuple):
        """Flattened index lists produced by ``index_map``."""

        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self) -> None:
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Args:
            seq_group: Object exposing ``multi_modal_data`` and
                ``multi_modal_placeholders`` (a mapping from modality to a
                sequence of placeholder ranges).
            positions: The token positions of the prompt being processed.

        Returns:
            A tuple of the sequence group's multi-modal data and a dict
            mapping each modality to its placeholder map. When either the
            data or the placeholders are empty, an empty ``MultiModalKwargs``
            and an empty dict are returned.

        Examples:

        ```
        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |.................................|

            images      = [A, B]
            src_ranges  = [(0, 4), (4, 8)]
            dest_ranges = [(0, 4), (5, 9)]

        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |  .....                          |

            images      = [A, B]
            src_ranges  = [(2, 4), (4, 6)]
            dest_ranges = [(0, 2), (3, 5)]

        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |     .........                   |

            images      = [B]
            src_ranges  = [(0, 4)]
            dest_ranges = [(0, 4)]

        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |          .......................|

            images      = []
            src_ranges  = []
            dest_ranges = []
        ```
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        if not seq_mm_data or not seq_mm_placeholders:
            return MultiModalKwargs({}), {}

        placeholder_maps: dict[str, MultiModalPlaceholderMap] = {}

        for modality, placeholders in seq_mm_placeholders.items():
            # Instantiate through ``cls`` so that calling this alternate
            # constructor on a subclass yields maps of that subclass.
            placeholder_map = cls()

            # An empty ``positions`` range leaves the map empty
            # (``dest_len`` stays 0) but the modality is still registered.
            if positions:
                placeholder_map.append_items_from_seq_group(
                    positions,
                    # Dummy, since we don't care about intersecting items
                    [None] * len(placeholders),
                    placeholders,
                )

            placeholder_maps[modality] = placeholder_map

        return seq_mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        Args:
            positions: The token positions being processed.
            multi_modal_items: One item per placeholder, in the same order.
            multi_modal_placeholders: Placeholder ranges; each must expose
                ``offset`` and ``length``.

        Returns:
            The items whose placeholder overlaps ``positions``, in input
            order.

        Raises:
            ValueError: If the items and placeholders differ in length.
        """
        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )

        intersecting_items: list[_T] = []

        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict.offset,
                placeholder_dict.offset + placeholder_dict.length,
            )
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices are relative to the start of ``positions``.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices are relative to the start of this item's
            # placeholder, shifted past the items already recorded.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The full placeholder length is counted even when only part of
            # it intersects ``positions``.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.

        The ranges of ``other`` are shifted by this instance's current
        ``src_len`` / ``dest_len`` before being appended.
        """
        # Materialise the shifted ranges before mutating any list: feeding a
        # lazy generator over ``other``'s lists into ``list.extend`` would
        # never terminate when ``other is self``.
        shifted_src = [
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges
        ]
        shifted_dest = [
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges
        ]

        self.src_ranges.extend(shifted_src)
        self.src_len += other.src_len
        self.dest_ranges.extend(shifted_dest)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the flattened source and destination index lists
                differ in length.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return self.IndexMap(src=src_indices, dest=dest_indices)
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219


class MediaIO(ABC, Generic[_T]):
    """
    Abstract interface for loading a media item of type ``_T`` from raw
    bytes, from a string with an accompanying media type, or from a file path.

    Every method is abstract; concrete subclasses must implement all three.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Load a media item from ``data`` given as raw bytes."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Load a media item from the string ``data`` (presumably
        base64-encoded, going by the method name) of the given
        ``media_type``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Load a media item from the file located at ``filepath``."""
        raise NotImplementedError