utils.py 15.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
from collections.abc import Generator, Set
7
from concurrent.futures import ThreadPoolExecutor
8
from itertools import groupby
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, TypeVar
11
from urllib.parse import ParseResult, urlparse
12
from urllib.request import url2pathname
13

14
import numpy as np
15
import numpy.typing as npt
16
import torch
17
from PIL import Image, UnidentifiedImageError
18

19
import vllm.envs as envs
20
from vllm.connections import HTTPConnection, global_http_connection
21
from vllm.logger import init_logger
22
from vllm.utils.registry import ExtensionManager
23

24
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
25
from .base import MediaIO
26
from .image import ImageEmbeddingMediaIO, ImageMediaIO
27
from .video import VideoMediaIO
28

29
if TYPE_CHECKING:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
else:
    # At runtime these names only appear in annotations, so they are aliased
    # to `Any` instead of importing `.inputs` eagerly (presumably to avoid a
    # circular import; `.inputs` is also imported lazily further below).
    BatchedTensorInputs = MultiModalKwargsItem = MultiModalPlaceholderDict = Any
39

40
41
logger = init_logger(__name__)

# Shared worker pool used by the async loaders to run blocking media decoding
# off the event loop; it is shut down when the interpreter exits.
global_thread_pool = ThreadPoolExecutor(envs.VLLM_MEDIA_LOADING_THREAD_COUNT)
atexit.register(global_thread_pool.shutdown)

# Generic media type produced by a `MediaIO[_M]` loader.
_M = TypeVar("_M")

# Registry of media connector implementations, keyed by name (e.g. "http").
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
50

51
52

@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Load multimodal media (audio, images, video) from URLs.

    Supported URL schemes:

    - `http` / `https`: downloaded through `connection`, optionally restricted
      to the hosts in `allowed_media_domains`.
    - `data`: RFC 2397 data URLs with a base64-encoded payload.
    - `file`: local files, only when `allowed_local_media_path` is set and only
      for files located under that directory.

    The async variants run decoding and local-file loading on
    `global_thread_pool` so the event loop is not blocked.
    """

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files from.
            allowed_media_domains: If set, only media URLs that belong to this
                                   domain can be used for multi-modal inputs.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a `data:` URL (RFC 2397) with `media_io`.

        Only base64 payloads are supported. Media-type parameters between the
        media type and the trailing `;base64` marker (e.g. `;charset=utf-8`)
        are accepted and not forwarded; only the bare media type is passed to
        `media_io.load_base64`.

        Raises:
            ValueError: If the URL lacks the `,` separating header and payload,
                or its header contains no `;`.
            NotImplementedError: If the payload is not base64-encoded.
        """
        data_spec, sep, data = url_spec.path.partition(",")
        if not sep:
            raise ValueError(
                "Invalid data URL: expected `data:<media-type>;base64,<data>`."
            )

        media_type, sep, params = data_spec.partition(";")
        if not sep:
            raise ValueError(
                "Invalid data URL: expected `data:<media-type>;base64,<data>`."
            )
        # The encoding marker is the last `;`-separated token; anything before
        # it (after the media type) is a media-type parameter.
        data_type = params.rpartition(";")[2]

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a local `file:` URL with `media_io`.

        Raises:
            RuntimeError: If no `allowed_local_media_path` was configured.
            ValueError: If the resolved file is not inside the allowed
                directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
        # Resolve both sides: the configured directory may be relative or go
        # through a symlink, in which case its unresolved form would never be
        # among the resolved file's parents and every file would be rejected.
        if allowed_local_media_path.resolve() not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path {allowed_local_media_path}`."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
        """Raise `ValueError` if the URL's host is not in the allow-list.

        An empty `allowed_media_domains` permits every host.
        """
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from `url` synchronously and decode it with `media_io`.

        Args:
            url: An `http(s)`, `data` or `file` URL.
            media_io: Decoder that turns the raw payload into media.
            fetch_timeout: Timeout forwarded to the HTTP client; unused for
                `data` and `file` URLs.

        Raises:
            ValueError: If the scheme is unsupported, the HTTP host is not
                allowed, a local file lies outside the allowed directory, or a
                data URL is malformed.
            RuntimeError: For `file` URLs when no local media path is allowed.
            NotImplementedError: For non-base64 `data` URLs.
        """
        url_spec = urlparse(url)

        # Match the scheme exactly: a prefix test would also send schemes such
        # as "httpx" to the HTTP client.
        if url_spec.scheme in ("http", "https"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Asynchronous counterpart of `load_from_url`.

        The HTTP download is awaited on the event loop; decoding (and data /
        file URL handling) runs on `global_thread_pool`. Raises the same
        exceptions as `load_from_url`.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        # Exact scheme match, consistent with `load_from_url`.
        if url_spec.scheme in ("http", "https"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.

        Returns whatever `AudioMediaIO` produces, typed as a
        `(waveform, sampling_rate)` tuple.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP(S), base64 data, or file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: Also raised in place of PIL's
                `UnidentifiedImageError` when the payload is not an image.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP(S), base64 data, or file
        URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP(S), base64 data, or file URL.

        Frames are decoded with an `ImageMediaIO` using `image_mode`
        (RGB by default).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP(S), base64 data, or file URL.

        Frames are decoded with an `ImageMediaIO` using `image_mode`
        (RGB by default).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an image embedding from base64-encoded `data` (not a URL).

        An empty media type is passed; presumably `ImageEmbeddingMediaIO`
        ignores it.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

    def fetch_audio_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an audio embedding from base64-encoded `data` (not a URL).

        An empty media type is passed; presumably `AudioEmbeddingMediaIO`
        ignores it.
        """
        audio_embedding_io = AudioEmbeddingMediaIO()

        return audio_embedding_io.load_base64("", data)

356

357
358
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Serialize an audio waveform together with its sampling rate to base64.

    Args:
        audio: The audio samples.
        sampling_rate: The sampling rate of `audio`.

    Returns:
        The base64 string produced by `AudioMediaIO.encode_base64`.
    """
    payload = (audio, sampling_rate)
    return AudioMediaIO().encode_base64(payload)
364
365


366
367
368
369
370
371
372
373
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """Serialize a pillow image to a base64 string.

    Args:
        image: The image to encode.
        image_mode: Mode the image is converted to before encoding
            (RGB by default).
        format: Image file format used for the encoded bytes
            (JPEG by default).

    Returns:
        The base64 string produced by `ImageMediaIO.encode_base64`.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    encoded = encoder.encode_base64(image, image_format=format)
    return encoded
379
380


381
def encode_video_base64(frames: npt.NDArray) -> str:
    """Serialize video frames to a base64 string.

    The frames are encoded by a `VideoMediaIO` wrapping a default-configured
    `ImageMediaIO`.
    """
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
385
386


387
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """Order the placeholders of every modality by their position in the input.

    Placeholders are sorted by `offset` (their starting index in the input
    sequence) in ascending order. The sort is stable, so placeholders with the
    same offset keep the dictionary's iteration order.

    Returns:
        A list of `(modality, idx)` pairs; each pair addresses one item as
        `mm_positions[modality][idx]`.
    """
    keyed = []
    for modality, placeholders in mm_positions.items():
        for idx, placeholder in enumerate(placeholders):
            keyed.append((placeholder.offset, modality, idx))

    # Sort on the offset alone so ties are not broken by modality name.
    keyed.sort(key=lambda entry: entry[0])

    return [(modality, idx) for _, modality, idx in keyed]
408
409


410
411
412
413
414
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: bool | None = None,
    multimodal_cpu_fields: Set[str] | None = None,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Batch runs of consecutive `MultiModalKwargsItem`s that share a modality.

    Only *adjacent* items with the same `modality` are merged (grouping is done
    with `itertools.groupby`), so an input that interleaves modalities yields
    several groups for the same modality.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
        merge_by_field_config: Deprecated and ignored; a non-`None` value only
            triggers a deprecation warning via `logger.warning_once`.
        multimodal_cpu_fields: Deprecated and ignored; a non-`None` value only
            triggers a deprecation warning via `logger.warning_once`.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    if merge_by_field_config is not None:
        logger.warning_once(
            "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
            "is deprecated and will be removed in v0.13."
        )
    if multimodal_cpu_fields is not None:
        logger.warning_once(
            "The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
            "is deprecated and will be removed in v0.13."
        )

    from vllm.multimodal.inputs import MultiModalKwargsItems

    for modality, run in groupby(mm_kwargs, key=lambda item: item.modality):
        batch = list(run)
        batched_data = MultiModalKwargsItems.from_seq(batch).get_data(
            device=device,
            pin_memory=pin_memory,
        )
        yield modality, len(batch), batched_data
451
452


453
454
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Fetch and decode audio from a URL, including any local file.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    connector = MediaConnector(
        media_io_kwargs={"audio": audio_io_kwargs} if audio_io_kwargs else None,
        # The whole filesystem is allowed: this helper is for offline use only.
        allowed_local_media_path="/",
    )
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Fetch and decode an image from a URL, including any local file.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    connector = MediaConnector(
        media_io_kwargs={"image": image_io_kwargs} if image_io_kwargs else None,
        # The whole filesystem is allowed: this helper is for offline use only.
        allowed_local_media_path="/",
    )
    return connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch and decode a video from a URL, including any local file.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    connector = MediaConnector(
        media_io_kwargs={"video": video_io_kwargs} if video_io_kwargs else None,
        # The whole filesystem is allowed: this helper is for offline use only.
        allowed_local_media_path="/",
    )
    return connector.fetch_video(video_url)