utils.py 6.84 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import mimetypes
5
import warnings
6
from collections.abc import Generator
7
from itertools import groupby
8
from typing import TYPE_CHECKING, Any
9

10
import numpy as np
11
import numpy.typing as npt
12
from PIL import Image
13

14
15
16
17
18
19
20
from vllm.utils.import_utils import LazyLoader

from .inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalPlaceholderDict,
21
)
22
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
23

24
if TYPE_CHECKING:
25
    import torch.types
26
else:
27
    torch = LazyLoader("torch", globals(), "torch")
28

29

30
31
32
def __getattr__(name: str):
    """Resolve deprecated module attributes on access.

    Module-level `__getattr__` hook: it is consulted only when normal
    attribute lookup on this module fails.

    Args:
        name: The attribute name being looked up on this module.

    Returns:
        `MEDIA_CONNECTOR_REGISTRY` (imported from `.media`) when that name
        is requested.

    Raises:
        AttributeError: For any other name.
    """
    if name == "MEDIA_CONNECTOR_REGISTRY":
        # Imported here, on access, rather than at module import time.
        from .media import MEDIA_CONNECTOR_REGISTRY

        warnings.warn(
            "`vllm.multimodal.utils.MEDIA_CONNECTOR_REGISTRY` "
            "has been moved to `vllm.multimodal.media.MEDIA_CONNECTOR_REGISTRY`. "
            "The old name will be removed in v0.17.",
            DeprecationWarning,
            # Attribute the warning to the caller's frame, not this hook.
            stacklevel=2,
        )

        return MEDIA_CONNECTOR_REGISTRY

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
45

46

47
48
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as base64.

    Args:
        audio: The audio samples to encode.
        sampling_rate: The sampling rate passed alongside `audio`.
        format: Forwarded to `AudioMediaIO.encode_base64` as `audio_format`.

    Returns:
        The string produced by `AudioMediaIO.encode_base64`.
    """
    audio_io = AudioMediaIO()
    # The media IO takes the samples and sampling rate as a single tuple.
    return audio_io.encode_base64((audio, sampling_rate), audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Build a `data:` URL carrying the base64-encoded audio.

    The MIME type is looked up in `mimetypes.types_map` using the
    lower-cased `format` as the file extension; when the extension is not
    in the map, the placeholder `"audio"` is used instead.
    """
    audio_b64 = encode_audio_base64(audio, sampling_rate, format=format)

    extension = "." + format.lower()
    mimetype = mimetypes.types_map.get(extension, "audio")

    return f"data:{mimetype};base64,{audio_b64}"
68
69


70
71
72
73
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The pillow image to encode.
        image_mode: Passed to `ImageMediaIO` as `image_mode`.
        format: Forwarded to `ImageMediaIO.encode_base64` as `image_format`.

    Returns:
        The string produced by `ImageMediaIO.encode_base64`.
    """
    image_io = ImageMediaIO(image_mode=image_mode)
    return image_io.encode_base64(image, image_format=format)
83
84


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Build a `data:` URL carrying the base64-encoded pillow image.

    By default, the image is converted into RGB format before being encoded.
    The MIME type is looked up in `mimetypes.types_map` using the
    lower-cased `format` as the file extension; when the extension is not
    in the map, the placeholder `"image"` is used instead.
    """
    image_b64 = encode_image_base64(image, image_mode=image_mode, format=format)

    extension = "." + format.lower()
    mimetype = mimetypes.types_map.get(extension, "image")

    return f"data:{mimetype};base64,{image_b64}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Encode video frames as base64.

    Args:
        frames: The video frames to encode.
        format: Forwarded to `VideoMediaIO.encode_base64` as `video_format`.

    Returns:
        The string produced by `VideoMediaIO.encode_base64`.
    """
    # The video IO is built on top of a default-configured image IO.
    image_io = ImageMediaIO()
    video_io = VideoMediaIO(image_io)
    return video_io.encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Build a `data:` URL carrying the base64-encoded video frames.

    For the `"jpeg"` format (compared case-insensitively) the MIME type is
    fixed to `"video/jpeg"`. Any other format is looked up in
    `mimetypes.types_map` by extension, with `"video"` as the placeholder
    when the extension is not in the map.
    """
    video_b64 = encode_video_base64(frames, format=format)

    format_key = format.lower()
    mimetype = (
        "video/jpeg"
        if format_key == "jpeg"
        else mimetypes.types_map.get("." + format_key, "video")
    )

    return f"data:{mimetype};base64,{video_b64}"
124
125


126
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Items with equal `offset` keep their relative order from iterating
    `mm_positions` (the sort is stable).

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    # Flatten to (modality, index-within-modality, item) so that each item
    # remembers where it came from after sorting.
    flat_items = (
        (modality, idx, item)
        for modality, items in mm_positions.items()
        for idx, item in enumerate(items)
    )

    sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)

    return [(modality, idx) for modality, idx, _ in sorted_flat_items]
147
148


149
def group_mm_kwargs_by_modality(
    mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together and batch each group's data.

    Only adjacent entries are merged: if the same modality appears again
    after a different one, it is yielded as a separate group.

    Args:
        mm_kwargs: List of `(modality, MultiModalKwargsItem)` tuples.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`, where
        `grouped_kwargs` is the result of `MultiModalKwargsItems.get_data`
        for that group.
    """
    for modality, group in groupby(mm_kwargs, key=lambda x: x[0]):
        # Materialize the group: its length is needed as well as its items.
        items_lst = [item for _, item in group]
        mm_kwargs_items = MultiModalKwargsItems({modality: items_lst})
        mm_kwargs_data = mm_kwargs_items.get_data(
            device=device,
            pin_memory=pin_memory,
        )

        yield modality, len(items_lst), mm_kwargs_data
175
176


177
178
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Fetch audio from a URL through a `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Returns:
        The result of `MediaConnector.fetch_audio` for `audio_url`.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # A missing or empty kwargs dict is passed on as `None`; otherwise the
    # kwargs are nested under the "audio" key.
    media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        # The filesystem root is the allowed local media path (see Warning).
        allowed_local_media_path="/",
    )
    return media_connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Fetch an image from a URL through a `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Returns:
        The result of `MediaConnector.fetch_image` for `image_url`.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # A missing or empty kwargs dict is passed on as `None`; otherwise the
    # kwargs are nested under the "image" key.
    media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        # The filesystem root is the allowed local media path (see Warning).
        allowed_local_media_path="/",
    )
    return media_connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video from a URL through a `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Returns:
        The result of `MediaConnector.fetch_video` for `video_url`.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # A missing or empty kwargs dict is passed on as `None`; otherwise the
    # kwargs are nested under the "video" key.
    media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        # The filesystem root is the allowed local media path (see Warning).
        allowed_local_media_path="/",
    )
    return media_connector.fetch_video(video_url)