utils.py 9.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import mimetypes
5
import warnings
6
7
from collections import defaultdict
from collections.abc import Generator, Sequence
8
from itertools import groupby
9
from typing import TYPE_CHECKING, Any
10

11
import numpy as np
12
import numpy.typing as npt
13
from PIL import Image
14

15
16
from vllm.utils.import_utils import LazyLoader

17
from .hasher import MultiModalHasher
18
19
from .inputs import (
    BatchedTensorInputs,
20
    MultiModalFieldElem,
21
22
    MultiModalKwargsItem,
    MultiModalPlaceholderDict,
23
    MultiModalSharedField,
24
)
25
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
26

27
if not TYPE_CHECKING:
    # At runtime, bind `torch` through `LazyLoader`; presumably this defers the
    # real import until first use so importing this module stays cheap.
    torch = LazyLoader("torch", globals(), "torch")
else:
    # Static type checkers get the real module for annotations such as
    # `torch.types.Device` used in this file.
    import torch.types
31

32

33
34
35
def __getattr__(name: str):
    """
    Module-level attribute hook (PEP 562) for deprecated names.

    Accessing `MEDIA_CONNECTOR_REGISTRY` through this module still works but
    emits a `DeprecationWarning` pointing at its new home in
    `vllm.multimodal.media`. Any other unknown name raises `AttributeError`.
    """
    if name != "MEDIA_CONNECTOR_REGISTRY":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported lazily so the deprecated alias costs nothing unless it is used.
    from .media import MEDIA_CONNECTOR_REGISTRY as registry

    warnings.warn(
        "`vllm.multimodal.utils.MEDIA_CONNECTOR_REGISTRY` "
        "has been moved to `vllm.multimodal.media.MEDIA_CONNECTOR_REGISTRY`. "
        "The old name will be removed in v0.17.",
        DeprecationWarning,
        # Attribute the warning to the caller's access site.
        stacklevel=2,
    )
    return registry
48

49

50
51
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """
    Encode an audio clip as a base64 string.

    Args:
        audio: The audio samples to encode.
        sampling_rate: Sampling rate of `audio`; passed to the encoder
            together with the samples.
        format: Output audio format, forwarded as `audio_format`.

    Returns:
        The base64 text produced by `AudioMediaIO.encode_base64`.
    """
    payload = (audio, sampling_rate)
    return AudioMediaIO().encode_base64(payload, audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """
    Encode an audio clip as a `data:` URL.

    The MIME type is looked up from `format` in `mimetypes.types_map`;
    formats missing from that table fall back to the string `"audio"`.
    """
    encoded = encode_audio_base64(audio, sampling_rate, format=format)
    extension = f".{format.lower()}"
    # NOTE(review): the fallback "audio" is not a full `type/subtype` MIME
    # string; it is kept as-is for compatibility with existing consumers.
    media_type = mimetypes.types_map.get(extension, "audio")
    return f"data:{media_type};base64,{encoded}"
71
72


73
74
75
76
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Serialize a Pillow image to a base64 string.

    Args:
        image: The image to encode.
        image_mode: Mode the image is converted to before encoding
            (RGB unless overridden).
        format: Image file format, forwarded as `image_format`.

    Returns:
        The base64 text produced by `ImageMediaIO.encode_base64`.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    return encoder.encode_base64(image, image_format=format)
86
87


88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Serialize a Pillow image into a `data:` URL.

    The image is converted to `image_mode` (RGB by default) before being
    encoded. The MIME type is looked up from `format` in
    `mimetypes.types_map`, falling back to `"image"` for unknown formats.
    """
    encoded = encode_image_base64(image, image_mode=image_mode, format=format)
    extension = f".{format.lower()}"
    media_type = mimetypes.types_map.get(extension, "image")
    return f"data:{media_type};base64,{encoded}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Encode an array of video frames as a base64 string.

    Args:
        frames: The frames to encode, handed to `VideoMediaIO` unchanged.
        format: Forwarded to `VideoMediaIO.encode_base64` as `video_format`.

    Returns:
        The base64 text produced by `VideoMediaIO.encode_base64`.
    """
    # `VideoMediaIO` is layered on an `ImageMediaIO` with default settings.
    video_io = VideoMediaIO(ImageMediaIO())
    return video_io.encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Encode an array of video frames as a `data:` URL.

    JPEG payloads are labelled `video/jpeg`. Other formats use the MIME type
    from `mimetypes.types_map`, falling back to `"video"` when the format is
    unknown.
    """
    encoded = encode_video_base64(frames, format=format)

    fmt = format.lower()
    # `mimetypes` maps ".jpeg" to an *image* type, but this payload is video
    # (presumably a sequence of JPEG frames), so that case is pinned.
    media_type = (
        "video/jpeg"
        if fmt == "jpeg"
        else mimetypes.types_map.get(f".{fmt}", "video")
    )
    return f"data:{media_type};base64,{encoded}"
127
128


129
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Compute the order in which placeholders appear in the input sequence.

    Each placeholder in `mm_positions` is identified by its `(modality, idx)`
    pair. These pairs are returned sorted by the placeholder's `offset` (its
    starting index in the input sequence), smallest first. Because the sort
    is stable, placeholders with equal offsets keep their iteration order.

    Returns:
        A list of `(modality, idx)` such that `mm_positions[modality][idx]`
        retrieves the corresponding placeholder.
    """
    keyed = []
    for modality, placeholders in mm_positions.items():
        for idx, placeholder in enumerate(placeholders):
            keyed.append((placeholder.offset, (modality, idx)))

    keyed.sort(key=lambda entry: entry[0])
    return [position_key for _, position_key in keyed]
150
151


152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def _get_group_hash(elem: MultiModalFieldElem):
    """
    Return a hash of `elem.data` if `elem` belongs to a shared field.

    Items whose shared-field data differ must not be batched together, so
    this value becomes part of each item's grouping signature. For any other
    field kind the data do not affect grouping, and `None` is returned.
    """
    is_shared = isinstance(elem.field, MultiModalSharedField)
    return MultiModalHasher.hash_kwargs(data=elem.data) if is_shared else None


def _batch_mm_items(
    items: Sequence[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
):
    """
    Merge `items` into a single dictionary of batched keyword arguments.

    Elements are collected per key, in first-seen key order. Each key's
    elements are then combined by the `reduce_data` method of the *first*
    element's field.

    Args:
        items: The items to batch. Presumably they share the same field
            layout, as produced by `group_and_batch_mm_items`.
        device: Forwarded to `reduce_data`.
        pin_memory: Forwarded to `reduce_data`.

    Returns:
        A dict mapping each key to the result of `reduce_data` for that key.
    """
    elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
    for item in items:
        for key, elem in item.items():
            elems_by_key[key].append(elem)

    # Every list holds at least one element (a key is created only when an
    # element is appended), so indexing `[0]` is safe.
    return {
        key: key_elems[0].field.reduce_data(
            key_elems,
            device=device,
            pin_memory=pin_memory,
        )
        for key, key_elems in elems_by_key.items()
    }


def group_and_batch_mm_items(
    items: Sequence[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[int, BatchedTensorInputs]]:
    """
    Group consecutive items (possibly from different requests) into batches.

    Items must be split across groups if any of the following occurs,
    as the batch would otherwise be invalid:
    - They have different fields (e.g. mixed image and embedding inputs).
    - They have different values in `MultiModalSharedField`.

    Only *consecutive* items are merged, and batches are yielded in the
    order of `items`.

    Args:
        items: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(num_items, grouped_kwargs)`, where:
        - `num_items` is the number of items merged into this batch;
        - `grouped_kwargs` is a dictionary of keyword arguments to pass to
          the model.
    """
    # Two items can share a batch only if they expose the same field keys and,
    # for shared fields, the same data hash. Sorting by key makes the
    # signature independent of each item's key insertion order.
    group_ids = [
        tuple(
            (key, _get_group_hash(elem))
            for key, elem in sorted(item.items(), key=lambda kv: kv[0])
        )
        for item in items
    ]
    group_sizes = [sum(1 for _ in group) for _, group in groupby(group_ids)]

    start_idx = 0
    for group_size in group_sizes:
        end_idx = start_idx + group_size
        group_data = _batch_mm_items(
            items[start_idx:end_idx],
            device=device,
            pin_memory=pin_memory,
        )

        yield group_size, group_data

        start_idx = end_idx

    # Sanity check: the groups partition `items` exactly.
    assert start_idx == len(items)


228
def group_mm_kwargs_by_modality(
    mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """
    Group consecutive items (possibly from different requests) into batches.

    Items must be split across groups if any of the following occurs,
    as the batch would otherwise be invalid:
    - They have different fields (e.g. mixed image and embedding inputs).
    - They have different values in `MultiModalSharedField`.

    To simplify the implementation of `embed_multimodal`, we add another
    restriction that the items in a batch must belong to the same modality.

    Args:
        mm_kwargs: List of `(modality, item)`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`, where:
        - `modality` is the modality of the batch;
        - `num_items` is the number of items merged into this batch;
        - `grouped_kwargs` is a dictionary of keyword arguments to pass to
          the model.
    """
    # `groupby` merges only *adjacent* entries, so a modality that appears in
    # several separate runs of `mm_kwargs` produces several groups.
    for modality, group in groupby(mm_kwargs, key=lambda x: x[0]):
        items_lst = [item for _, item in group]

        for num_items, mm_kwargs_batch in group_and_batch_mm_items(
            items_lst,
            device=device,
            pin_memory=pin_memory,
        ):
            yield modality, num_items, mm_kwargs_batch
265
266


267
268
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Load an audio clip from `audio_url`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Returns:
        The result of `MediaConnector.fetch_audio`, annotated as the audio
        samples together with their sampling rate.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # An empty dict is treated the same as None: no per-modality overrides.
    io_kwargs = {"audio": audio_io_kwargs} if audio_io_kwargs else None
    # Allowing "/" lets any local path be read, hence the warning above.
    connector = MediaConnector(
        media_io_kwargs=io_kwargs,
        allowed_local_media_path="/",
    )
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Load an image from `image_url`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Returns:
        The Pillow image returned by `MediaConnector.fetch_image`.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # An empty dict is treated the same as None: no per-modality overrides.
    io_kwargs = {"image": image_io_kwargs} if image_io_kwargs else None
    # Allowing "/" lets any local path be read, hence the warning above.
    connector = MediaConnector(
        media_io_kwargs=io_kwargs,
        allowed_local_media_path="/",
    )
    return connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Load a video from `video_url`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Returns:
        The result of `MediaConnector.fetch_video`, annotated as a frame
        array plus a metadata dictionary.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # An empty dict is treated the same as None: no per-modality overrides.
    io_kwargs = {"video": video_io_kwargs} if video_io_kwargs else None
    # Allowing "/" lets any local path be read, hence the warning above.
    connector = MediaConnector(
        media_io_kwargs=io_kwargs,
        allowed_local_media_path="/",
    )
    return connector.fetch_video(video_url)