# "vllm/vscode:/vscode.git/clone" did not exist on "fc6acc88caac881a92d84bcef5bc734d67035786"
# utils.py 16.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
import mimetypes
7
from collections.abc import Generator
8
from concurrent.futures import ThreadPoolExecutor
9
from itertools import groupby
10
from pathlib import Path
11
from typing import TYPE_CHECKING, Any, TypeVar
12
from urllib.request import url2pathname
13

14
import numpy as np
15
import numpy.typing as npt
16
import torch
17
from PIL import Image, UnidentifiedImageError
18
from urllib3.util import Url, parse_url
19

20
import vllm.envs as envs
21
from vllm.connections import HTTPConnection, global_http_connection
22
from vllm.logger import init_logger
23
from vllm.utils.registry import ExtensionManager
24

25
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
26
from .base import MediaIO
27
from .image import ImageEmbeddingMediaIO, ImageMediaIO
28
from .video import VideoMediaIO
29

30
if TYPE_CHECKING:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
else:
    # At runtime these names only need to exist for the annotations in this
    # module, so alias them to `Any` instead of importing `.inputs` eagerly
    # (presumably to avoid an import cycle; see the lazy import in
    # `group_mm_kwargs_by_modality`).
    BatchedTensorInputs = MultiModalKwargsItem = MultiModalPlaceholderDict = Any
40

41
42
logger = init_logger(__name__)

# Shared worker pool used by `MediaConnector.load_from_url_async` to decode
# media off the event loop; it is shut down when the interpreter exits.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT,
)
atexit.register(global_thread_pool.shutdown)

# Media object type produced by a `MediaIO[_M]` loader.
_M = TypeVar("_M")

# Registry that media connector classes register themselves in under a name,
# e.g. `@MEDIA_CONNECTOR_REGISTRY.register("http")`.
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
51

52
53

@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Fetch and decode multi-modal media referenced by URL.

    HTTP(S) URLs are downloaded through `connection` (optionally restricted
    to `allowed_media_domains`), `data:` URLs must carry a base64 payload,
    and `file:` URLs are only accepted for files located under
    `allowed_local_media_path`. Decoding is delegated to the `MediaIO`
    object built for each modality.
    """

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                from. It is stored resolved (absolute, symlinks followed) so
                that it can be compared with resolved file paths.
            allowed_media_domains: If set, only HTTP(S) media URLs whose
                hostname is in this list can be used for multi-modal inputs.

        Raises:
            ValueError: If `allowed_local_media_path` is set but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )

            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )

            # `_load_file_url` checks membership in the *resolved* parents of
            # the requested file, so the allowed directory must be resolved
            # too. Otherwise a relative (e.g. "./media") or symlinked
            # directory would never match and every local file is rejected.
            allowed_local_media_path_ = allowed_local_media_path_.resolve()
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: Url,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a `data:<media_type>;base64,<payload>` URL.

        Raises:
            ValueError: If the URL lacks the `,` or `;` separator.
            NotImplementedError: If the payload encoding is not base64.
        """
        url_spec_path = url_spec.path or ""

        # `partition` splits on the first separator like `split(sep, 1)`, but
        # lets us report a malformed URL instead of an unpacking error.
        data_spec, comma, data = url_spec_path.partition(",")
        if not comma:
            raise ValueError(
                "Invalid data URL: expected a ',' separating the media type "
                "from the payload."
            )
        media_type, semicolon, data_type = data_spec.partition(";")
        if not semicolon:
            raise ValueError(
                "Invalid data URL: expected ';base64' after the media type."
            )

        # media_type starts with a leading "/" (e.g., "/video/jpeg")
        media_type = media_type.lstrip("/")

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: Url,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a `file:` URL located under `allowed_local_media_path`.

        Raises:
            RuntimeError: If no local media directory was configured.
            ValueError: If the resolved file is not inside that directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        url_spec_path = url_spec.path or ""
        url_spec_netloc = url_spec.netloc or ""
        filepath = Path(url2pathname(url_spec_netloc + url_spec_path))
        # Resolve before the containment check so `..` segments and symlinks
        # cannot escape the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path {allowed_local_media_path}`."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec: Url) -> None:
        """Reject URLs whose hostname is not in `allowed_media_domains`.

        The check is skipped when `allowed_media_domains` is empty.

        Raises:
            ValueError: If the hostname is not allowed.
        """
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from an HTTP(S), `data:` or `file:` URL.

        Args:
            url: The URL to load.
            media_io: Decoder that turns the fetched content into media.
            fetch_timeout: Timeout passed to the HTTP client; only used for
                HTTP(S) URLs.

        Raises:
            ValueError: If the scheme is unsupported, or the URL fails the
                domain / local-path checks.
            RuntimeError: For `file:` URLs when no local media directory was
                configured.
        """
        url_spec = parse_url(url)

        if url_spec.scheme and url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Asynchronous counterpart of `load_from_url`.

        HTTP(S) content is downloaded with the connection's async client. All
        decoding, as well as `data:` and `file:` loading, runs in
        `global_thread_pool` so the event loop is not blocked.
        """
        url_spec = parse_url(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme and url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL, using `VLLM_AUDIO_FETCH_TIMEOUT` for HTTP(S).
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL, using `VLLM_AUDIO_FETCH_TIMEOUT`
        for HTTP(S).
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP(S), base64 `data:` or `file:` URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image (the
                underlying `UnidentifiedImageError` is chained as the cause).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP(S), base64 `data:` or
        `file:` URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image (the
                underlying `UnidentifiedImageError` is chained as the cause).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP(S), base64 `data:` or `file:` URL.

        By default, the frames are converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP(S), base64 `data:` or `file:`
        URL.

        By default, the frames are converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an image embedding from base64-encoded `data`.

        No URL is fetched: `data` is handed directly to
        `ImageEmbeddingMediaIO.load_base64`.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

    def fetch_audio_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an audio embedding from base64-encoded `data`.

        No URL is fetched: `data` is handed directly to
        `AudioEmbeddingMediaIO.load_base64`.
        """
        audio_embedding_io = AudioEmbeddingMediaIO()

        return audio_embedding_io.load_base64("", data)

362

363
364
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Return `audio`, sampled at `sampling_rate`, as a base64 string.

    `format` (default `"WAV"`) is forwarded to `AudioMediaIO.encode_base64`
    as `audio_format`.
    """
    payload = (audio, sampling_rate)
    return AudioMediaIO().encode_base64(payload, audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as a data URL.

    Args:
        audio: Audio samples, forwarded to `encode_audio_base64`.
        sampling_rate: Sampling rate of `audio`.
        format: Audio format (default `"WAV"`); also used to pick the MIME
            type of the resulting URL.

    Returns:
        A `data:<mimetype>;base64,<payload>` URL.
    """
    audio_b64 = encode_audio_base64(audio, sampling_rate, format=format)
    extension = format.lower()
    # Unknown formats fall back to `audio/<format>`: a bare `audio` is not a
    # valid `type/subtype` MIME type.
    mimetype = mimetypes.types_map.get(f".{extension}", f"audio/{extension}")
    return f"data:{mimetype};base64,{audio_b64}"
384
385


386
387
388
389
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str | None = None,
) -> str:
    """
    Return a pillow `image` as a base64 string.

    The image is converted to `image_mode` (RGB unless overridden) before
    being encoded; `format` is forwarded as `image_format`.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    encoded = encoder.encode_base64(image, image_format=format)
    return encoded
399
400


401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Encode a pillow image as a data URL.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The image to encode.
        image_mode: Mode forwarded to `encode_image_base64` (default `"RGB"`).
        format: Image format (default `"PNG"`); also used to pick the MIME
            type of the resulting URL.

    Returns:
        A `data:<mimetype>;base64,<payload>` URL.
    """
    image_b64 = encode_image_base64(image, image_mode=image_mode, format=format)
    extension = format.lower()
    # Unknown formats fall back to `image/<format>`: a bare `image` is not a
    # valid `type/subtype` MIME type.
    mimetype = mimetypes.types_map.get(f".{extension}", f"image/{extension}")
    return f"data:{mimetype};base64,{image_b64}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Return video `frames` as a base64 string.

    The frames are encoded through a `VideoMediaIO` backed by a default
    `ImageMediaIO`; `format` (default `"JPEG"`) is forwarded as
    `video_format`.
    """
    encoder = VideoMediaIO(ImageMediaIO())
    return encoder.encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Encode video frames as a data URL.

    Args:
        frames: The frames to encode, forwarded to `encode_video_base64`.
        format: Video format (default `"JPEG"`); also used to pick the MIME
            type of the resulting URL.

    Returns:
        A `data:<mimetype>;base64,<payload>` URL.
    """
    video_b64 = encode_video_base64(frames, format=format)
    extension = format.lower()

    if extension == "jpeg":
        # `mimetypes` would report `image/jpeg` here. A dedicated `video/jpeg`
        # type is used instead, presumably so the loader can tell a JPEG
        # frame sequence apart from a single image.
        mimetype = "video/jpeg"
    else:
        # Unknown formats fall back to `video/<format>`: a bare `video` is not
        # a valid `type/subtype` MIME type.
        mimetype = mimetypes.types_map.get(f".{extension}", f"video/{extension}")

    return f"data:{mimetype};base64,{video_b64}"
440
441


442
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Compute the visiting order of the items in a `MultiModalPlaceholderDict`
    so that they come out by ascending `offset` (starting index in the input
    sequence).

    The sort is stable, so items with equal offsets keep their original
    modality/list order.

    Returns:
        A list of `(modality, idx)` pairs; each one addresses an item as
        `mm_positions[modality][idx]`.
    """
    keyed = [
        (placeholder.offset, modality, idx)
        for modality, placeholders in mm_positions.items()
        for idx, placeholder in enumerate(placeholders)
    ]
    # Sort on the offset only (not the whole tuple) to keep ties stable.
    keyed.sort(key=lambda entry: entry[0])

    return [(modality, idx) for _offset, modality, idx in keyed]
463
464


465
466
467
468
469
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Batch runs of consecutive `MultiModalKwargsItem`s sharing a modality.

    Only *adjacent* items are grouped (via `itertools.groupby`), so a modality
    that reappears after a different one starts a new group.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`, where
        `grouped_kwargs` is what `MultiModalKwargsItems.get_data` returns for
        that group.
    """
    # Imported lazily, presumably to avoid an import cycle at module load.
    from vllm.multimodal.inputs import MultiModalKwargsItems

    runs = groupby(mm_kwargs, key=lambda kwargs_item: kwargs_item.modality)
    for modality, run in runs:
        batch = list(run)
        batched_data = MultiModalKwargsItems.from_seq(batch).get_data(
            device=device,
            pin_memory=pin_memory,
        )
        yield modality, len(batch), batched_data
493
494


495
496
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Fetch and decode audio from `audio_url`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # An allowed local path of "/" lets `file:` URLs point anywhere on disk.
    connector = MediaConnector(
        media_io_kwargs={"audio": audio_io_kwargs} if audio_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Fetch and decode an image from `image_url`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # An allowed local path of "/" lets `file:` URLs point anywhere on disk.
    connector = MediaConnector(
        media_io_kwargs={"image": image_io_kwargs} if image_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch and decode a video from `video_url`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # An allowed local path of "/" lets `file:` URLs point anywhere on disk.
    connector = MediaConnector(
        media_io_kwargs={"video": video_io_kwargs} if video_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_video(video_url)