utils.py 17.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
from collections.abc import Generator, Set
7
from concurrent.futures import ThreadPoolExecutor
8
from itertools import groupby
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, TypeVar
11
from urllib.parse import ParseResult, urlparse
12
from urllib.request import url2pathname
13

14
import numpy as np
15
import numpy.typing as npt
16
import torch
17
from PIL import Image, UnidentifiedImageError
18

19
import vllm.envs as envs
20
from vllm.connections import HTTPConnection, global_http_connection
21
from vllm.logger import init_logger
22
from vllm.utils.jsontree import json_map_leaves
23
from vllm.utils.registry import ExtensionManager
24

25
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
26
from .base import MediaIO
27
from .image import ImageEmbeddingMediaIO, ImageMediaIO
28
from .video import VideoMediaIO
29

30
if TYPE_CHECKING:
    # Real types are imported for static analysis only.
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
else:
    # At runtime the names only need to exist so that the annotations in
    # this module can be evaluated; alias them all to `Any`.
    BatchedTensorInputs = MultiModalKwargsItem = MultiModalPlaceholderDict = Any

logger = init_logger(__name__)

# Shared worker pool used to run blocking media decoding off the event loop.
# Its size is taken from `VLLM_MEDIA_LOADING_THREAD_COUNT`, and the pool is
# shut down when the interpreter exits.
global_thread_pool = ThreadPoolExecutor(max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT)
atexit.register(global_thread_pool.shutdown)

# Type of the object produced by a `MediaIO` implementation.
_M = TypeVar("_M")

# Registry in which media connector classes are registered by name.
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
51

52
53

@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Load multi-modal media from HTTP(S), `data:` and `file:` URLs.

    The class is registered in `MEDIA_CONNECTOR_REGISTRY` under the name
    `"http"`. Every loader delegates the actual decoding to a `MediaIO`
    object; this class is only responsible for retrieving the raw payload
    and enforcing the configured access restrictions.
    """

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                                      from. An empty string disables loading
                                      of `file:` URLs.
            allowed_media_domains: If set, only media URLs that belong to this
                                   domain can be used for multi-modal inputs.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )

            # `_load_file_url` compares this directory against the parents of
            # the *resolved* requested file. Without resolving it here as
            # well, a relative or symlinked directory would never match and
            # every file would be rejected.
            allowed_local_media_path_ = allowed_local_media_path_.resolve()
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode the payload of a `data:` URL with `media_io`.

        The URL path is expected to look like `<media-type>;base64,<data>`.

        Raises:
            ValueError: If the path lacks the `,` or `;` separator.
            NotImplementedError: If the encoding is anything but `base64`.
        """
        # `partition` splits at the first separator exactly like
        # `split(sep, 1)`, but lets us report a malformed URL explicitly
        # instead of failing with an opaque tuple-unpacking error.
        data_spec, comma, data = url_spec.path.partition(",")
        media_type, semicolon, data_type = data_spec.partition(";")

        if not comma or not semicolon:
            raise ValueError(
                "Invalid data URL: expected the form "
                "`data:<media-type>;base64,<data>`."
            )

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a local file referenced by a `file:` URL with `media_io`.

        Raises:
            RuntimeError: If no allowed local media directory is configured.
            ValueError: If the resolved file is not located below the allowed
                local media directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
        # Resolve before checking so that `..` segments and symlinks cannot
        # be used to escape the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path {allowed_local_media_path}`."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
        """Raise `ValueError` if the URL host is not in the allow-list.

        An empty allow-list disables the check.
        """
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from `url` and decode it with `media_io`.

        Args:
            url: An HTTP(S), `data:` or `file:` URL.
            media_io: Decoder used to turn the payload into a media object.
            fetch_timeout: Timeout passed to the HTTP client; only used for
                HTTP(S) URLs.

        Raises:
            ValueError: If the URL scheme is unsupported, or the host of an
                HTTP(S) URL is not in the configured allow-list.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Asynchronous counterpart of `load_from_url`.

        HTTP(S) payloads are downloaded with the async HTTP client. All
        decoding (and local file loading) is submitted to
        `global_thread_pool` so it does not run on the event loop thread.

        Raises:
            ValueError: If the URL scheme is unsupported, or the host of an
                HTTP(S) URL is not in the configured allow-list.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.

        The `"audio"` entry of `media_io_kwargs` is forwarded to
        `AudioMediaIO`, and `VLLM_AUDIO_FETCH_TIMEOUT` is used as the
        fetch timeout.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.

        The `"audio"` entry of `media_io_kwargs` is forwarded to
        `AudioMediaIO`, and `VLLM_AUDIO_FETCH_TIMEOUT` is used as the
        fetch timeout.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP, data or file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the payload cannot be identified as an image
                (converted from `UnidentifiedImageError`), or if the URL is
                rejected by `load_from_url`.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP, data or file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the payload cannot be identified as an image
                (converted from `UnidentifiedImageError`), or if the URL is
                rejected by `load_from_url_async`.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP, data or file URL.

        The `"image"` and `"video"` entries of `media_io_kwargs` are
        forwarded to `ImageMediaIO` and `VideoMediaIO` respectively;
        `image_mode` is passed to the `ImageMediaIO` used for the frames.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP, data or file URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from base64-encoded data.

        `data` is the encoded payload itself, not a URL; it is decoded
        with `ImageEmbeddingMediaIO.load_base64`.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

    def fetch_audio_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an audio embedding from base64-encoded data.

        `data` is the encoded payload itself, not a URL; it is decoded
        with `AudioEmbeddingMediaIO.load_base64`.
        """
        audio_embedding_io = AudioEmbeddingMediaIO()

        return audio_embedding_io.load_base64("", data)

357

358
359
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Return the base64 encoding of `audio` sampled at `sampling_rate`.

    The `(audio, sampling_rate)` pair is encoded by a default-constructed
    `AudioMediaIO`.
    """
    return AudioMediaIO().encode_base64((audio, sampling_rate))
365
366


367
368
369
370
371
372
373
374
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Return the base64 encoding of a pillow image.

    Args:
        image: The image to encode.
        image_mode: Mode given to the `ImageMediaIO` doing the encoding
            (RGB unless specified otherwise).
        format: Image format forwarded to the encoder as `image_format`.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    encoded = encoder.encode_base64(image, image_format=format)
    return encoded
380
381


382
def encode_video_base64(frames: npt.NDArray) -> str:
    """Return the base64 encoding of the video `frames`.

    Encoding is done by a `VideoMediaIO` wrapping a default-constructed
    `ImageMediaIO`.
    """
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
386
387


388
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Compute the order in which the placeholders of a
    `MultiModalPlaceholderDict` appear in the input sequence, i.e. sorted
    by ascending `offset`. Placeholders with equal offsets keep the order
    in which they are encountered while iterating over the dictionary.

    Returns:
        A list of `(modality, idx)` keys; each one addresses the item
        `mm_positions[modality][idx]`.
    """
    # Flatten to (offset, modality, idx) records, in dictionary order.
    records = []
    for modality, placeholders in mm_positions.items():
        for idx, placeholder in enumerate(placeholders):
            records.append((placeholder.offset, modality, idx))

    # Sort on the offset only; `list.sort` is stable, so ties stay in
    # their original order.
    records.sort(key=lambda record: record[0])

    return [(modality, idx) for _, modality, idx in records]
409
410


411
412
413
414
415
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
416
    merge_by_field_config: bool | None = None,
417
    multimodal_cpu_fields: Set[str] = frozenset(),
418
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
419
420
421
422
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together into the same `MultiModalKwargs` instance.

    Args:
423
424
425
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
426
427
428
429

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
430
431
432
433
    if merge_by_field_config is None:
        raise RuntimeError(
            "`group_mm_kwargs_by_modality` now requires "
            "`merge_by_field_config` arg, please update your model runner "
434
435
            "according to https://github.com/vllm-project/vllm/pull/25676."
        )
436
437
438
439
440
441
442
443
444
    if merge_by_field_config is False:
        logger.warning_once(
            "The legacy code for batching multi-modal kwargs is deprecated and "
            "will be removed in v0.12. Please update your model with "
            "`merge_by_field_config=True` to use the new code defined by "
            "`MultiModalFieldConfig`. You can refer to "
            "https://github.com/vllm-project/vllm/issues/26149 "
            "for some examples on how to do this."
        )
445

446
    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
447
448
449
450

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        items_lst = list(items)

451
452
453
        if merge_by_field_config:
            mm_kwargs_group: BatchedTensorInputs = dict(
                MultiModalKwargsItems.from_seq(items_lst).get_data(
454
455
456
                    pin_memory=pin_memory
                )
            )
457
458

            if device is not None:
459
460
461
462
463
464
465
466
467
468
469
                mm_kwargs_group = {
                    k: json_map_leaves(
                        lambda x: x.to(device=device, non_blocking=True)
                        if isinstance(x, torch.Tensor)
                        else x,
                        v,
                    )
                    if k not in multimodal_cpu_fields
                    else v
                    for k, v in mm_kwargs_group.items()
                }
470
471
472
473
474
475
476
477
478
479
480
        else:
            mm_kwargs_group = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(
                    [
                        MultiModalKwargsItems.from_seq([item]).get_data()
                        for item in items_lst
                    ],
                    pin_memory=pin_memory,
                ),
                device=device,
            )
481
482
483
484

        yield modality, len(items_lst), mm_kwargs_group


485
486
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Fetch audio through a throwaway `MediaConnector` that may read any
    local file.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    if audio_io_kwargs:
        media_io_kwargs = {"audio": audio_io_kwargs}
    else:
        media_io_kwargs = None

    # "/" as the allowed directory places no restriction on `file:` URLs.
    return MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    ).fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Fetch an image through a throwaway `MediaConnector` that may read any
    local file.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    if image_io_kwargs:
        media_io_kwargs = {"image": image_io_kwargs}
    else:
        media_io_kwargs = None

    # "/" as the allowed directory places no restriction on `file:` URLs.
    return MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    ).fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video through a throwaway `MediaConnector` that may read any
    local file.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    if video_io_kwargs:
        media_io_kwargs = {"video": video_io_kwargs}
    else:
        media_io_kwargs = None

    # "/" as the allowed directory places no restriction on `file:` URLs.
    return MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    ).fetch_video(video_url)