utils.py 15.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
from collections.abc import Iterable
7
from concurrent.futures import ThreadPoolExecutor
8
from itertools import groupby
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, TypeVar
11
from urllib.parse import ParseResult, urlparse
12
from urllib.request import url2pathname
13

14
import numpy as np
15
import numpy.typing as npt
16
import torch
17
from PIL import Image, UnidentifiedImageError
18

19
import vllm.envs as envs
20
from vllm.connections import HTTPConnection, global_http_connection
21
from vllm.logger import init_logger
22
from vllm.utils.jsontree import json_map_leaves
23
from vllm.utils.registry import ExtensionManager
24

25
26
from .audio import AudioMediaIO
from .base import MediaIO
27
from .image import ImageEmbeddingMediaIO, ImageMediaIO
28
from .video import VideoMediaIO
29

30
if not TYPE_CHECKING:
    # At runtime these names are only used in annotations in this module, so
    # they are aliased to `Any` instead of importing `.inputs` (presumably to
    # avoid an import cycle — unverified).
    BatchedTensorInputs = Any
    MultiModalKwargsItem = Any
    MultiModalPlaceholderDict = Any
else:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
40

41
42
logger = init_logger(__name__)

_M = TypeVar("_M")

# Shared worker pool used by `MediaConnector.load_from_url_async` to run the
# blocking decode/read steps off the event loop; shut down at interpreter exit.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT,
)
atexit.register(global_thread_pool.shutdown)

# Registry of media connector classes; `MediaConnector` registers as "http".
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
51

52
53

@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Fetch multi-modal media by URL and decode it with a `MediaIO`.

    Supported URL schemes:
    - HTTP(S): downloaded through `connection`, optionally restricted to
      `allowed_media_domains`.
    - `data:`: base64 payloads decoded in-process.
    - `file:`: local files, only beneath `allowed_local_media_path`.
    """

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                                      from. Stored in resolved (absolute,
                                      symlink-free) form.
            allowed_media_domains: If non-empty, HTTP(S) URLs are only
                                   accepted when their hostname is in this
                                   list.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                        exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )

            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )

            # `_load_file_url` checks membership in `filepath.resolve().parents`,
            # which are absolute with symlinks resolved. Canonicalize here too;
            # otherwise a relative or symlinked directory never matches and
            # every local file is rejected.
            allowed_local_media_path_ = allowed_local_media_path_.resolve()
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a `data:` URL of the form
        `data:<media_type>[;<param>=<value>]*;base64,<payload>`.

        Raises:
            ValueError: If the URL has no `,` separator or no `;` after the
                        media type.
            NotImplementedError: If the payload is not marked as base64.
        """
        data_spec, data = url_spec.path.split(",", 1)
        media_type, params = data_spec.split(";", 1)

        # The media type may carry its own parameters (e.g. `charset=...` or
        # `name=...`); per RFC 2397 the `base64` marker is the last
        # `;`-separated token, so only that token is inspected.
        if params.rsplit(";", 1)[-1] != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a `file:` URL that lies under `allowed_local_media_path`.

        Raises:
            RuntimeError: If no local media directory was configured.
            ValueError: If the resolved file is outside that directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        filepath = Path(url2pathname(url_spec.path))
        # Resolving collapses `..` and symlinks so the file cannot escape
        # the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
        """Raise `ValueError` if a domain allow-list is configured and the
        URL's hostname is not on it. An empty allow-list permits any host."""
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """Synchronously load media from `url` and decode it with `media_io`.

        HTTP(S) URLs are downloaded through `self.connection` after the
        domain allow-list check; `data:` and `file:` URLs are handled by
        `_load_data_url` and `_load_file_url` respectively.

        Raises:
            ValueError: If the scheme is not HTTP(S), `data` or `file`.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Async counterpart of `load_from_url`.

        The HTTP download is awaited on the connection; the decoding step
        (and data/file loading) runs in `global_thread_pool` so it does not
        block the event loop.

        Raises:
            ValueError: If the scheme is not HTTP(S), `data` or `file`.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If PIL cannot identify the image data (re-raised
                        from `UnidentifiedImageError`).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If PIL cannot identify the image data (re-raised
                        from `UnidentifiedImageError`).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP or base64 data URL.

        Frames are decoded with `image_mode` (RGB by default).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP or base64 data URL.

        Frames are decoded with `image_mode` (RGB by default).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode a base64-encoded image embedding into a tensor.

        Note: `data` is the raw base64 payload, not a URL; no fetching is done.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

345

346
347
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Serialize an `(audio, sampling_rate)` pair to a base64 string
    using a default-configured `AudioMediaIO`."""
    return AudioMediaIO().encode_base64((audio, sampling_rate))
353
354


355
356
357
358
359
360
361
362
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Serialize a pillow image to a base64 string.

    The image is converted to `image_mode` (RGB unless overridden) and
    written in the given `format` (JPEG unless overridden) before encoding.
    """
    return ImageMediaIO(image_mode=image_mode).encode_base64(
        image, image_format=format
    )
368
369


370
def encode_video_base64(frames: npt.NDArray) -> str:
    """Serialize video frames to a base64 string via `VideoMediaIO`,
    using a default-configured `ImageMediaIO` for the individual frames."""
    frame_codec = ImageMediaIO()
    return VideoMediaIO(frame_codec).encode_base64(frames)
374
375


376
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Compute the order in which multi-modal placeholders appear in the input.

    Every placeholder in `mm_positions` is addressed by `(modality, idx)`;
    the result lists those addresses sorted by the placeholder's `offset`
    (its starting index in the input sequence), ascending. Placeholders with
    equal offsets keep their dictionary/list iteration order.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    keys: list[tuple[str, int]] = []
    offsets = []
    for modality, placeholders in mm_positions.items():
        for idx, placeholder in enumerate(placeholders):
            keys.append((modality, idx))
            offsets.append(placeholder.offset)

    # `sorted` is stable, so ties retain the insertion order built above.
    order = sorted(range(len(keys)), key=offsets.__getitem__)
    return [keys[i] for i in order]
397
398


399
400
401
402
403
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
404
    merge_by_field_config: bool | None = None,
405
406
407
408
409
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together into the same `MultiModalKwargs` instance.

    Args:
410
411
412
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
413
414
415
416

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
417
418
419
420
    if merge_by_field_config is None:
        raise RuntimeError(
            "`group_mm_kwargs_by_modality` now requires "
            "`merge_by_field_config` arg, please update your model runner "
421
422
            "according to https://github.com/vllm-project/vllm/pull/25676."
        )
423
424
425
426
427
428
429
430
431
    if merge_by_field_config is False:
        logger.warning_once(
            "The legacy code for batching multi-modal kwargs is deprecated and "
            "will be removed in v0.12. Please update your model with "
            "`merge_by_field_config=True` to use the new code defined by "
            "`MultiModalFieldConfig`. You can refer to "
            "https://github.com/vllm-project/vllm/issues/26149 "
            "for some examples on how to do this."
        )
432

433
    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
434
435
436
437

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        items_lst = list(items)

438
439
440
        if merge_by_field_config:
            mm_kwargs_group: BatchedTensorInputs = dict(
                MultiModalKwargsItems.from_seq(items_lst).get_data(
441
442
443
                    pin_memory=pin_memory
                )
            )
444
445
446

            if device is not None:
                mm_kwargs_group = json_map_leaves(
447
448
449
                    lambda x: x.to(device=device, non_blocking=True)
                    if isinstance(x, torch.Tensor)
                    else x,
450
451
452
453
454
455
456
457
458
459
460
461
462
                    mm_kwargs_group,
                )
        else:
            mm_kwargs_group = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(
                    [
                        MultiModalKwargsItems.from_seq([item]).get_data()
                        for item in items_lst
                    ],
                    pin_memory=pin_memory,
                ),
                device=device,
            )
463
464
465
466

        yield modality, len(items_lst), mm_kwargs_group


467
468
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Fetch audio through a freshly constructed `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    connector = MediaConnector(
        media_io_kwargs={"audio": audio_io_kwargs} if audio_io_kwargs else None
    )
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Fetch an image through a freshly constructed `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    if image_io_kwargs:
        media_io_kwargs = {"image": image_io_kwargs}
    else:
        media_io_kwargs = None
    return MediaConnector(media_io_kwargs=media_io_kwargs).fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video through a freshly constructed `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    connector = MediaConnector(
        media_io_kwargs={"video": video_io_kwargs} if video_io_kwargs else None
    )
    return connector.fetch_video(video_url)