"vllm/model_executor/models/gemma2.py" did not exist on "f1c0fc391909e55fce5f109893f3c483f69a091f"
utils.py 9.56 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import mimetypes
5
6
from collections import defaultdict
from collections.abc import Generator, Sequence
7
from itertools import groupby
8
from typing import TYPE_CHECKING, Any
9

10
import numpy as np
11
import numpy.typing as npt
12
from PIL import Image
13
from typing_extensions import deprecated
14

15
from vllm.inputs import MultiModalPlaceholders
16
17
from vllm.utils.import_utils import LazyLoader

18
from .hasher import MultiModalHasher
19
20
from .inputs import (
    BatchedTensorInputs,
21
    MultiModalFieldElem,
22
    MultiModalKwargsItem,
23
    MultiModalSharedField,
24
)
25
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
26

27
if TYPE_CHECKING:
    # Static type checkers import the real module so that annotations such
    # as `torch.types.Device` can be resolved.
    import torch.types
else:
    # At runtime `torch` is bound to a `LazyLoader` object rather than being
    # imported directly (presumably deferring the actual import until first
    # use - confirm against `vllm.utils.import_utils.LazyLoader`).
    torch = LazyLoader("torch", globals(), "torch")
31

32
33
34

def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """
    Encode audio as base64.

    Args:
        audio: The audio samples to encode.
        sampling_rate: The sampling rate that accompanies `audio`.
        format: The audio format, forwarded to `AudioMediaIO.encode_base64`
            as `audio_format`.

    Returns:
        The base64 string produced by `AudioMediaIO.encode_base64`.
    """
    audio_io = AudioMediaIO()
    return audio_io.encode_base64((audio, sampling_rate), audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """
    Encode audio as a data URL.

    Args:
        audio: The audio samples to encode.
        sampling_rate: The sampling rate that accompanies `audio`.
        format: The audio format; also used (lower-cased) to look up the
            MIME type of the resulting URL.

    Returns:
        A string of the form `data:<mimetype>;base64,<payload>`.
    """
    audio_b64 = encode_audio_base64(audio, sampling_rate, format=format)

    # A media type requires both a type and a subtype, so extensions missing
    # from `mimetypes.types_map` fall back to `audio/<format>` instead of a
    # bare `audio`, which would make the data URL malformed.
    subtype = format.lower()
    mimetype = mimetypes.types_map.get(f".{subtype}", f"audio/{subtype}")
    return f"data:{mimetype};base64,{audio_b64}"
54
55


56
57
58
59
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The pillow image to encode.
        image_mode: The image mode handed to `ImageMediaIO`.
        format: The image format, forwarded to `ImageMediaIO.encode_base64`
            as `image_format`.

    Returns:
        The base64 string produced by `ImageMediaIO.encode_base64`.
    """
    image_io = ImageMediaIO(image_mode=image_mode)
    return image_io.encode_base64(image, image_format=format)
69
70


71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Encode a pillow image as a data URL.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The pillow image to encode.
        image_mode: The image mode used when encoding the image.
        format: The image format; also used (lower-cased) to look up the
            MIME type of the resulting URL.

    Returns:
        A string of the form `data:<mimetype>;base64,<payload>`.
    """
    image_b64 = encode_image_base64(image, image_mode=image_mode, format=format)

    # A media type requires both a type and a subtype, so extensions missing
    # from `mimetypes.types_map` fall back to `image/<format>` instead of a
    # bare `image`, which would make the data URL malformed.
    subtype = format.lower()
    mimetype = mimetypes.types_map.get(f".{subtype}", f"image/{subtype}")
    return f"data:{mimetype};base64,{image_b64}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Encode video frames as base64.

    Args:
        frames: The array of video frames to encode.
        format: The video format, forwarded to `VideoMediaIO.encode_base64`
            as `video_format`.

    Returns:
        The base64 string produced by `VideoMediaIO.encode_base64`.
    """
    # The video IO is constructed on top of a default image IO.
    image_io = ImageMediaIO()
    video_io = VideoMediaIO(image_io)
    return video_io.encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Encode video frames as a data URL.

    Args:
        frames: The array of video frames to encode.
        format: The video format; also used (lower-cased) to pick the MIME
            type of the resulting URL.

    Returns:
        A string of the form `data:<mimetype>;base64,<payload>`.
    """
    video_b64 = encode_video_base64(frames, format=format)

    subtype = format.lower()
    if subtype == "jpeg":
        # `mimetypes` would report `image/jpeg` here; the payload is a video,
        # so the type is set explicitly.
        mimetype = "video/jpeg"
    else:
        # A media type requires both a type and a subtype, so unknown
        # extensions fall back to `video/<format>` instead of a bare `video`.
        mimetype = mimetypes.types_map.get(f".{subtype}", f"video/{subtype}")

    return f"data:{mimetype};base64,{video_b64}"
110
111


112
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholders,
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholders`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Args:
        mm_positions: Mapping from modality to its list of placeholder
            items; each item is read through its `offset` attribute.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    keyed_offsets = (
        (item.offset, (modality, idx))
        for modality, items in mm_positions.items()
        for idx, item in enumerate(items)
    )

    # Sort on the offset alone: `sorted` is stable, so items that share an
    # offset keep the order in which they were enumerated.
    ordered = sorted(keyed_offsets, key=lambda pair: pair[0])

    return [key for _, key in ordered]
133
134


135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def _get_group_hash(elem: MultiModalFieldElem):
    """
    Return a hash of `elem.data` when the element's field is a
    `MultiModalSharedField`, and `None` for any other field type.
    """
    if isinstance(elem.field, MultiModalSharedField):
        return MultiModalHasher.hash_kwargs(data=elem.data)

    return None


def _batch_mm_items(
    items: Sequence[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
):
    """
    Merge the elements of several items into one dictionary keyed by field
    name.

    Args:
        items: The items whose elements are collected per key.
        device: Forwarded to `reduce_data` of each key's field.
        pin_memory: Forwarded to `reduce_data` of each key's field.

    Returns:
        A dictionary mapping each key to the result of reducing all of the
        elements collected under that key.
    """
    elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
    for item in items:
        for key, elem in item.items():
            elems_by_key[key].append(elem)

    batched = {}
    for key, key_elems in elems_by_key.items():
        # The field of the first element performs the reduction for the key.
        reducer = key_elems[0].field
        batched[key] = reducer.reduce_data(
            key_elems,
            device=device,
            pin_memory=pin_memory,
        )

    return batched


def group_and_batch_mm_items(
    items: Sequence[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[int, BatchedTensorInputs]]:
    """
    Group consecutive items (possibly from different requests) into batches.

    Items must be split across groups if any of the following occurs,
    as the batch would otherwise be invalid:
    - They have different fields (e.g. mixed image and embedding inputs).
    - They have different values in `MultiModalSharedField`.

    Args:
        items: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(num_items, grouped_kwargs)`, where:
        - `num_items` is the number of items in the batch;
        - `grouped_kwargs` is a dictionary of keyword arguments to pass
          to the model.
    """
    # An item's group identity is its sorted keys, each paired with the
    # hash of the element's shared data (`None` for non-shared fields).
    group_ids = []
    for item in items:
        keyed_elems = sorted(item.items(), key=lambda kv: kv[0])
        group_ids.append(
            tuple((key, _get_group_hash(elem)) for key, elem in keyed_elems)
        )

    # `groupby` only merges adjacent equal ids, so each run is one batch.
    group_sizes = [len(list(run)) for _, run in groupby(group_ids)]

    start_idx = 0
    for group_size in group_sizes:
        end_idx = start_idx + group_size
        group_data = _batch_mm_items(
            items[start_idx:end_idx],
            device=device,
            pin_memory=pin_memory,
        )

        yield group_size, group_data

        start_idx = end_idx

    assert start_idx == len(items)


211
def group_and_batch_mm_kwargs(
    mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """
    Group consecutive items (possibly from different requests) into batches.

    Items must be split across groups if any of the following occurs,
    as the batch would otherwise be invalid:
    - They have different fields (e.g. mixed image and embedding inputs).
    - They have different values in `MultiModalSharedField`.

    To simplify the implementation of `embed_multimodal`, we add another
    restriction that the items in a batch must belong to the same modality.

    Args:
        mm_kwargs: List of `(modality, item)`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`, where:
        - `modality` is the modality of the batch;
        - `num_items` is the corresponding number of items;
        - `grouped_kwargs` is a dictionary of keyword arguments to pass
          to the model.
    """
    # `groupby` only merges adjacent entries, so a modality that reappears
    # later in `mm_kwargs` starts a new run (and therefore new batches).
    for modality, group in groupby(mm_kwargs, key=lambda x: x[0]):
        items_lst = [item for _, item in group]

        for num_items, mm_kwargs_batch in group_and_batch_mm_items(
            items_lst,
            device=device,
            pin_memory=pin_memory,
        ):
            yield modality, num_items, mm_kwargs_batch
248
249


250
251
252
253
254
255
256
257
258
259
260
261
262
@deprecated(
    "`group_mm_kwargs_by_modality` has been renamed to `group_and_batch_mm_kwargs`. "
    "The old name will be removed in v0.19."
)
def group_mm_kwargs_by_modality(
    mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """
    Deprecated alias that forwards all arguments unchanged to
    `group_and_batch_mm_kwargs` and returns its result.
    """
    batches = group_and_batch_mm_kwargs(
        mm_kwargs,
        device=device,
        pin_memory=pin_memory,
    )
    return batches


263
264
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Fetch audio from a URL through a `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Returns:
        The result of `MediaConnector.fetch_audio` (presumably the audio
        samples and their sampling rate - confirm against `MediaConnector`).

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # Empty or missing kwargs are passed on as `None`, not as an empty dict.
    media_io_kwargs = {"audio": audio_io_kwargs} if audio_io_kwargs else None

    # "/" as the allowed local path is what the warning above refers to.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Fetch an image from a URL through a `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Returns:
        The image returned by `MediaConnector.fetch_image`.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # Empty or missing kwargs are passed on as `None`, not as an empty dict.
    media_io_kwargs = {"image": image_io_kwargs} if image_io_kwargs else None

    # "/" as the allowed local path is what the warning above refers to.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video from a URL through a `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Returns:
        The result of `MediaConnector.fetch_video` (presumably the frame
        array and a metadata dictionary - confirm against `MediaConnector`).

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # Empty or missing kwargs are passed on as `None`, not as an empty dict.
    media_io_kwargs = {"video": video_io_kwargs} if video_io_kwargs else None

    # "/" as the allowed local path is what the warning above refers to.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_video(video_url)