utils.py 6.91 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import mimetypes
5
import warnings
6
from collections.abc import Generator
7
from itertools import groupby
8
from typing import TYPE_CHECKING, Any
9

10
import numpy as np
11
import numpy.typing as npt
12
from PIL import Image
13

14
from vllm.logger import init_logger
15
16
17
18
19
20
21
from vllm.utils.import_utils import LazyLoader

from .inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalPlaceholderDict,
22
)
23
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
24

25
if TYPE_CHECKING:
26
    import torch.types
27
else:
28
    torch = LazyLoader("torch", globals(), "torch")
29

30
31
logger = init_logger(__name__)

32

33
34
35
def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) for deprecated names.

    ``MEDIA_CONNECTOR_REGISTRY`` is still reachable through this module, but
    it is re-exported from ``.media`` and a ``DeprecationWarning`` is emitted.
    Any other unknown attribute raises ``AttributeError``, as normal module
    attribute lookup would.
    """
    if name != "MEDIA_CONNECTOR_REGISTRY":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .media import MEDIA_CONNECTOR_REGISTRY as registry

    # stacklevel=2 attributes the warning to the code that accessed the
    # attribute rather than to this hook.
    warnings.warn(
        "`vllm.multimodal.utils.MEDIA_CONNECTOR_REGISTRY` "
        "has been moved to `vllm.multimodal.media.MEDIA_CONNECTOR_REGISTRY`. "
        "The old name will be removed in v0.17.",
        DeprecationWarning,
        stacklevel=2,
    )
    return registry
48

49

50
51
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as base64.

    Args:
        audio: The audio samples.
        sampling_rate: Sampling rate that accompanies the samples.
        format: Audio container format handed to ``AudioMediaIO``.

    Returns:
        The base64 string produced by ``AudioMediaIO.encode_base64``.
    """
    payload = (audio, sampling_rate)
    return AudioMediaIO().encode_base64(payload, audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as a data URL.

    Args:
        audio: The audio samples.
        sampling_rate: Sampling rate that accompanies the samples.
        format: Audio container format (e.g. ``"WAV"``). Its lowercased form
            is also used to look up the media type in ``mimetypes.types_map``.

    Returns:
        A string of the form ``data:<mimetype>;base64,<payload>``.
    """
    audio_b64 = encode_audio_base64(audio, sampling_rate, format=format)
    ext = format.lower()
    # A data URL media type must be a full ``type/subtype`` pair (RFC 2397),
    # so unknown extensions fall back to ``audio/<ext>`` rather than a bare
    # ``audio``.
    mimetype = mimetypes.types_map.get(f".{ext}", f"audio/{ext}")
    return f"data:{mimetype};base64,{audio_b64}"
71
72


73
74
75
76
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str | None = None,
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The pillow image to encode.
        image_mode: Mode handed to ``ImageMediaIO``.
        format: Image format handed to ``ImageMediaIO.encode_base64``; when
            ``None``, the choice is presumably left to ``ImageMediaIO``.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    encoded = encoder.encode_base64(image, image_format=format)
    return encoded
86
87


88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Encode a pillow image as a data URL.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The pillow image to encode.
        image_mode: Mode forwarded to ``encode_image_base64``.
        format: Image format (e.g. ``"PNG"``). Its lowercased form is also
            used to look up the media type in ``mimetypes.types_map``.

    Returns:
        A string of the form ``data:<mimetype>;base64,<payload>``.
    """
    image_b64 = encode_image_base64(image, image_mode=image_mode, format=format)
    ext = format.lower()
    # A data URL media type must be a full ``type/subtype`` pair (RFC 2397),
    # so unknown extensions fall back to ``image/<ext>`` rather than a bare
    # ``image``.
    mimetype = mimetypes.types_map.get(f".{ext}", f"image/{ext}")
    return f"data:{mimetype};base64,{image_b64}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Encode video frames as base64.

    Args:
        frames: The video frames to encode.
        format: Video format handed to ``VideoMediaIO.encode_base64``.

    Returns:
        The base64 string produced by ``VideoMediaIO``. That object is built
        on top of a default ``ImageMediaIO``.
    """
    encoder = VideoMediaIO(ImageMediaIO())
    return encoder.encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Encode video frames as a data URL.

    Args:
        frames: The video frames to encode.
        format: Video format (e.g. ``"JPEG"``). ``"jpeg"`` (any case) is
            labelled ``video/jpeg``; other formats are looked up in
            ``mimetypes.types_map``.

    Returns:
        A string of the form ``data:<mimetype>;base64,<payload>``.
    """
    video_b64 = encode_video_base64(frames, format=format)
    ext = format.lower()

    if ext == "jpeg":
        # Label JPEG frame sequences explicitly as video; a table lookup of
        # ".jpeg" would yield the image type instead.
        mimetype = "video/jpeg"
    else:
        # A data URL media type must be a full ``type/subtype`` pair
        # (RFC 2397), so fall back to ``video/<ext>`` rather than ``video``.
        mimetype = mimetypes.types_map.get(f".{ext}", f"video/{ext}")

    return f"data:{mimetype};base64,{video_b64}"
127
128


129
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Compute the order of placeholders in a `MultiModalPlaceholderDict` by
    their `offset` (starting index in the input sequence), ascending.

    The sort is stable: placeholders with equal offsets keep the order in
    which they appear when iterating `mm_positions`.

    Returns:
        A list of `(modality, idx)` keys; each one addresses an item as
        `mm_positions[modality][idx]`.
    """
    entries = [
        (placeholder.offset, modality, position)
        for modality, placeholders in mm_positions.items()
        for position, placeholder in enumerate(placeholders)
    ]
    # Key on the offset alone so ties are not broken by modality name.
    entries.sort(key=lambda entry: entry[0])
    return [(modality, position) for _, modality, position in entries]
150
151


152
def group_mm_kwargs_by_modality(
    mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Batch runs of consecutive same-modality entries from `mm_kwargs`.

    Only *adjacent* entries are merged (this is `itertools.groupby`), so the
    same modality can appear in several yielded groups if its items are
    interleaved with other modalities. Each run is wrapped in a
    `MultiModalKwargsItems` and converted with its `get_data` method.

    Args:
        mm_kwargs: List of `(modality, MultiModalKwargsItem)` pairs.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    for modality, run in groupby(mm_kwargs, key=lambda pair: pair[0]):
        run_items = [kwargs_item for _, kwargs_item in run]
        batched = MultiModalKwargsItems({modality: run_items}).get_data(
            device=device,
            pin_memory=pin_memory,
        )
        yield modality, len(run_items), batched
178
179


180
181
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Fetch audio from a URL via a `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO. These
            are forwarded under the ``"audio"`` key; empty or ``None`` means
            no extra kwargs.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    io_kwargs = {"audio": audio_io_kwargs} if audio_io_kwargs else None
    # allowed_local_media_path="/" grants access to the whole local
    # filesystem, hence the warning above.
    connector = MediaConnector(
        media_io_kwargs=io_kwargs,
        allowed_local_media_path="/",
    )
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Fetch an image from a URL via a `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO. These
            are forwarded under the ``"image"`` key; empty or ``None`` means
            no extra kwargs.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    io_kwargs = {"image": image_io_kwargs} if image_io_kwargs else None
    # allowed_local_media_path="/" grants access to the whole local
    # filesystem, hence the warning above.
    connector = MediaConnector(
        media_io_kwargs=io_kwargs,
        allowed_local_media_path="/",
    )
    return connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video from a URL via a `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO. These
            are forwarded under the ``"video"`` key; empty or ``None`` means
            no extra kwargs.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    io_kwargs = {"video": video_io_kwargs} if video_io_kwargs else None
    # allowed_local_media_path="/" grants access to the whole local
    # filesystem, hence the warning above.
    connector = MediaConnector(
        media_io_kwargs=io_kwargs,
        allowed_local_media_path="/",
    )
    return connector.fetch_video(video_url)