utils.py 16.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
import mimetypes
7
from collections.abc import Generator
8
from concurrent.futures import ThreadPoolExecutor
9
from itertools import groupby
10
from pathlib import Path
11
from typing import TYPE_CHECKING, Any, TypeVar
12
from urllib.parse import ParseResult, urlparse
13
from urllib.request import url2pathname
14

15
import numpy as np
16
import numpy.typing as npt
17
import torch
18
from PIL import Image, UnidentifiedImageError
19

20
import vllm.envs as envs
21
from vllm.connections import HTTPConnection, global_http_connection
22
from vllm.logger import init_logger
23
from vllm.utils.registry import ExtensionManager
24

25
26
27
28
29
30
31
32
from .media import (
    AudioEmbeddingMediaIO,
    AudioMediaIO,
    ImageEmbeddingMediaIO,
    ImageMediaIO,
    MediaIO,
    VideoMediaIO,
)
33

34
# These names are only needed for annotations. Import the real types for
# static type checkers and fall back to `Any` at runtime.
if TYPE_CHECKING:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
else:
    BatchedTensorInputs = Any
    MultiModalKwargsItem = Any
    MultiModalPlaceholderDict = Any
44

45
46
logger = init_logger(__name__)

# Shared worker pool; `MediaConnector.load_from_url_async` submits blocking
# media loading calls to it via `loop.run_in_executor`.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
)
# Shut the pool down when the interpreter exits.
atexit.register(global_thread_pool.shutdown)

# Type of the object produced by a `MediaIO` implementation.
_M = TypeVar("_M")

MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
55

56
57

@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Load multi-modal media from HTTP(S), `data:` and `file:` URLs."""

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files from.
            allowed_media_domains: If set, only media URLs that belong to this
                                   domain can be used for multi-modal inputs.

        Raises:
            ValueError: If `allowed_local_media_path` is set but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )
        else:
            # An empty string disables loading from `file:` URLs.
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a `data:<media type>[;param...];base64,<payload>` URL.

        Raises:
            ValueError: If the URL has no `,` separating the header from the
                payload, or the header carries no encoding marker at all.
            NotImplementedError: If the payload is not declared as base64.
        """
        data_spec, comma, data = url_spec.path.partition(",")
        if not comma:
            raise ValueError(
                "Invalid data URL: expected the form "
                "`data:<media type>;base64,<data>` but found no `,` separator."
            )

        # Per RFC 2397 the header is the media type, optional `;name=value`
        # parameters, then `;base64` as the final token.
        media_type, *params = data_spec.split(";")
        if not params:
            raise ValueError(
                "Invalid data URL: expected the form "
                "`data:<media type>;base64,<data>` but found no `;base64` marker."
            )

        if params[-1] != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a local file referenced by a `file:` URL.

        Raises:
            RuntimeError: If no allowed local media path was configured.
            ValueError: If the file is not located under the allowed path.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
        # Resolve both sides before comparing. The parents of a resolved path
        # are absolute and symlink-free, so an unresolved (relative or
        # symlinked) allowed directory would never be found among them.
        if allowed_local_media_path.resolve() not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path {allowed_local_media_path}`."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
        """Raise `ValueError` if a domain allow-list is set and excludes the URL host."""
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from an HTTP(S), data or file URL using `media_io`.

        Args:
            url: The URL to load.
            media_io: Loader that turns bytes, base64 data or a file into media.
            fetch_timeout: Timeout passed to the HTTP connection; only used
                for HTTP(S) URLs.

        Raises:
            ValueError: If the URL scheme is not one of the supported ones.
        """
        url_spec = urlparse(url)

        # Match the scheme exactly so that look-alikes such as `httpx` are
        # rejected here rather than being handed to the HTTP client.
        if url_spec.scheme in ("http", "https"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Asynchronous counterpart of `load_from_url`.

        The download is awaited on the running event loop, while decoding is
        submitted to `global_thread_pool`.

        Raises:
            ValueError: If the URL scheme is not one of the supported ones.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        # Match the scheme exactly so that look-alikes such as `httpx` are
        # rejected here rather than being handed to the HTTP client.
        if url_spec.scheme in ("http", "https"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP or base64 data URL.

        `image_mode` (RGB by default) is forwarded to the `ImageMediaIO`
        that the video loader wraps.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from base64-encoded data.

        No URL is fetched: `data` is passed directly to
        `ImageEmbeddingMediaIO.load_base64` with an empty media type.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

    def fetch_audio_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an audio embedding from base64-encoded data.

        No URL is fetched: `data` is passed directly to
        `AudioEmbeddingMediaIO.load_base64` with an empty media type.
        """
        audio_embedding_io = AudioEmbeddingMediaIO()

        return audio_embedding_io.load_base64("", data)

361

362
363
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as base64.

    Args:
        audio: The audio samples.
        sampling_rate: Sampling rate paired with `audio`.
        format: Audio format passed to `AudioMediaIO.encode_base64`.
    """
    audio_io = AudioMediaIO()
    return audio_io.encode_base64((audio, sampling_rate), audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Return the audio as a base64 `data:` URL.

    The MIME type is looked up in `mimetypes.types_map` from the extension
    implied by `format`; unknown formats fall back to the bare type `audio`.
    """
    payload = encode_audio_base64(audio, sampling_rate, format=format)
    extension = "." + format.lower()
    mime = mimetypes.types_map.get(extension, "audio")
    return f"data:{mime};base64,{payload}"
383
384


385
386
387
388
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str | None = None,
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The image to encode.
        image_mode: Image mode used to construct the `ImageMediaIO`.
        format: Image format passed to `ImageMediaIO.encode_base64`.
    """
    image_io = ImageMediaIO(image_mode=image_mode)
    return image_io.encode_base64(image, image_format=format)
398
399


400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Return a pillow image as a base64 `data:` URL.

    The image is converted to `image_mode` (RGB by default) before encoding.
    The MIME type is looked up in `mimetypes.types_map` from the extension
    implied by `format`; unknown formats fall back to the bare type `image`.
    """
    payload = encode_image_base64(image, image_mode=image_mode, format=format)
    extension = "." + format.lower()
    mime = mimetypes.types_map.get(extension, "image")
    return f"data:{mime};base64,{payload}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Encode video frames as base64.

    Args:
        frames: The video frames to encode.
        format: Video format passed to `VideoMediaIO.encode_base64`.
    """
    image_io = ImageMediaIO()
    video_io = VideoMediaIO(image_io)
    return video_io.encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Return video frames as a base64 `data:` URL.

    The `JPEG` format is labelled `video/jpeg`. Any other format is looked up
    in `mimetypes.types_map`, falling back to the bare type `video`.
    """
    payload = encode_video_base64(frames, format=format)

    mime = (
        "video/jpeg"
        if format.lower() == "jpeg"
        else mimetypes.types_map.get("." + format.lower(), "video")
    )
    return f"data:{mime};base64,{payload}"
439
440


441
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Items with equal offsets keep the order in which they are encountered
    while iterating `mm_positions`, because `sorted` is stable.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    flat_items = (
        (modality, idx, item)
        for modality, items in mm_positions.items()
        for idx, item in enumerate(items)
    )

    sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)

    return [(modality, idx) for modality, idx, _ in sorted_flat_items]
462
463


464
465
466
467
468
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together and yield the data of each group.

    Only adjacent items are grouped (the list is not sorted first), so a
    modality that appears in several separate runs is yielded once per run.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    # Imported here rather than at module level, as in the original code.
    from vllm.multimodal.inputs import MultiModalKwargsItems

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        # `groupby` yields a one-shot iterator; materialize it so the group
        # can be both counted and passed on.
        items_lst = list(items)

        mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst)
        mm_kwargs_data = mm_kwargs_items.get_data(
            device=device,
            pin_memory=pin_memory,
        )

        yield modality, len(items_lst), mm_kwargs_data
492
493


494
495
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Load audio from a URL using a `MediaConnector` created for this call.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
    # Allowing "/" makes every local file reachable through `file:` URLs.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Load an image from a URL using a `MediaConnector` created for this call.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
    # Allowing "/" makes every local file reachable through `file:` URLs.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Load a video from a URL using a `MediaConnector` created for this call.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
    # Allowing "/" makes every local file reachable through `file:` URLs.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_video(video_url)