"vllm/vscode:/vscode.git/clone" did not exist on "5c765aec65d0f978cc2ad42164a5da2d3e0cf071"
utils.py 16.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
import mimetypes
7
from collections.abc import Generator
8
from concurrent.futures import ThreadPoolExecutor
9
from itertools import groupby
10
from pathlib import Path
11
from typing import TYPE_CHECKING, Any, TypeVar
12
from urllib.request import url2pathname
13

14
import numpy as np
15
import numpy.typing as npt
16
import torch
17
from PIL import Image, UnidentifiedImageError
18
from urllib3.util import Url, parse_url
19

20
import vllm.envs as envs
21
from vllm.connections import HTTPConnection, global_http_connection
22
from vllm.logger import init_logger
23
from vllm.utils.registry import ExtensionManager
24

25
26
27
28
29
30
31
32
from .media import (
    AudioEmbeddingMediaIO,
    AudioMediaIO,
    ImageEmbeddingMediaIO,
    ImageMediaIO,
    MediaIO,
    VideoMediaIO,
)
33

34
if TYPE_CHECKING:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
else:
    # At runtime these names are only needed so that annotations in this
    # module can be evaluated; alias them to `Any` instead of importing
    # `.inputs` (presumably to avoid a circular import -- TODO confirm).
    BatchedTensorInputs = Any
    MultiModalKwargsItem = Any
    MultiModalPlaceholderDict = Any
44

45
46
logger = init_logger(__name__)

# Shared worker pool used by the async loaders in this module to run blocking
# media decoding outside the event loop. Its size comes from
# `VLLM_MEDIA_LOADING_THREAD_COUNT`.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
)
# Shut the pool down when the interpreter exits.
atexit.register(global_thread_pool.shutdown)

# Type of the object produced by a `MediaIO` implementation.
_M = TypeVar("_M")

# Registry in which media connector classes are registered by name.
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
55

56
57

@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Load multi-modal media from HTTP(S), `data:` and `file:` URLs."""

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files from.
                                      An empty string disables `file:` URLs.
            allowed_media_domains: If set, only media URLs that belong to this
                                   domain can be used for multi-modal inputs.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: Url,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """
        Decode the payload of a `data:` URL with `media_io`.

        Raises:
            ValueError: If the URL is not of the form
                `data:<media type>;<encoding>,<data>`.
            NotImplementedError: If the encoding is not `base64`.
        """
        url_spec_path = url_spec.path or ""

        # Split explicitly so that a malformed URL yields a readable error
        # instead of a tuple-unpacking failure.
        data_spec, comma, data = url_spec_path.partition(",")
        media_type, semicolon, data_type = data_spec.partition(";")
        if not comma or not semicolon:
            raise ValueError(
                "Invalid data URL: expected the form "
                "`data:<media type>;base64,<data>`."
            )

        # media_type starts with a leading "/" (e.g., "/video/jpeg")
        media_type = media_type.lstrip("/")

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: Url,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """
        Load a local file referenced by a `file:` URL with `media_io`.

        Raises:
            RuntimeError: If no allowed local media directory was configured.
            ValueError: If the file does not live under the allowed directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        url_spec_path = url_spec.path or ""
        url_spec_netloc = url_spec.netloc or ""
        filepath = Path(url2pathname(url_spec_netloc + url_spec_path))

        # Compare resolved paths on both sides: the file path is resolved to
        # defeat `..` and symlink escapes, so the allowed directory must be
        # resolved too, otherwise a relative or symlinked allowed directory
        # could never match.
        allowed_root = allowed_local_media_path.resolve()
        if allowed_root not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path {allowed_local_media_path}`."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec: Url) -> None:
        """
        Raise `ValueError` if a domain allow-list is configured and the URL's
        hostname is not in it. An empty allow-list permits every hostname.
        """
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """
        Load media from `url` and decode it with `media_io`.

        Args:
            url: An HTTP(S), `data:` or `file:` URL.
            media_io: Decoder used to turn the fetched content into media.
            fetch_timeout: Timeout passed to the HTTP client; only used for
                HTTP(S) URLs.

        Raises:
            ValueError: If the URL scheme is not supported, or the URL is
                rejected by the domain / local path restrictions.
        """
        url_spec = parse_url(url)

        if url_spec.scheme and url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """
        Asynchronous counterpart of `load_from_url`.

        The HTTP download is awaited on the event loop; decoding (and the
        whole `data:` / `file:` loading path) runs in `global_thread_pool`.

        Raises:
            ValueError: If the URL scheme is not supported, or the URL is
                rejected by the domain / local path restrictions.
        """
        url_spec = parse_url(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme and url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.

        The `"audio"` entry of `media_io_kwargs` is forwarded to `AudioMediaIO`.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.

        The `"audio"` entry of `media_io_kwargs` is forwarded to `AudioMediaIO`.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image
                (converted from `UnidentifiedImageError`).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image
                (converted from `UnidentifiedImageError`).
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an image embedding from base64-encoded data.

        Unlike the other `fetch_*` methods, `data` is the payload itself,
        not a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

    def fetch_audio_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an audio embedding from base64-encoded data.

        Unlike the other `fetch_*` methods, `data` is the payload itself,
        not a URL.
        """
        audio_embedding_io = AudioEmbeddingMediaIO()

        return audio_embedding_io.load_base64("", data)

366

367
368
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """
    Encode audio as base64.

    Args:
        audio: The audio samples.
        sampling_rate: Sampling rate of `audio`.
        format: Audio container format, forwarded to `AudioMediaIO` as
            `audio_format`.

    Returns:
        The base64 string produced by `AudioMediaIO.encode_base64`.
    """
    audio_io = AudioMediaIO()
    return audio_io.encode_base64((audio, sampling_rate), audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """
    Build a base64 `data:` URL from audio samples.

    The MIME type is looked up from the lower-cased `format` used as a file
    extension; unknown formats fall back to the bare type `audio`.
    """
    payload = encode_audio_base64(audio, sampling_rate, format=format)
    extension = "." + format.lower()
    mimetype = mimetypes.types_map.get(extension, "audio")
    return f"data:{mimetype};base64,{payload}"
388
389


390
391
392
393
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str | None = None,
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The image to encode.
        image_mode: Image mode passed to `ImageMediaIO`.
        format: Image format, forwarded to `ImageMediaIO.encode_base64` as
            `image_format`; `None` is passed through unchanged.
    """
    image_io = ImageMediaIO(image_mode=image_mode)
    return image_io.encode_base64(image, image_format=format)
403
404


405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Build a base64 `data:` URL from a pillow image.

    The image is converted to `image_mode` (RGB by default) before encoding.
    The MIME type is looked up from the lower-cased `format` used as a file
    extension; unknown formats fall back to the bare type `image`.
    """
    payload = encode_image_base64(image, image_mode=image_mode, format=format)
    extension = "." + format.lower()
    mimetype = mimetypes.types_map.get(extension, "image")
    return f"data:{mimetype};base64,{payload}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Encode video frames as base64.

    Args:
        frames: The video frames to encode.
        format: Video format, forwarded to `VideoMediaIO.encode_base64` as
            `video_format`.
    """
    image_io = ImageMediaIO()
    video_io = VideoMediaIO(image_io)
    return video_io.encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """
    Build a base64 `data:` URL from video frames.

    The `jpeg` format (case-insensitive) is labelled `video/jpeg`; any other
    format is looked up by file extension and falls back to the bare type
    `video` when unknown.
    """
    payload = encode_video_base64(frames, format=format)

    suffix = format.lower()
    mimetype = (
        "video/jpeg"
        if suffix == "jpeg"
        else mimetypes.types_map.get("." + suffix, "video")
    )

    return f"data:{mimetype};base64,{payload}"
444
445


446
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Items with equal `offset` keep the order in which they appear when
    iterating `mm_positions` (the sort is stable).

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    flat_items = (
        (modality, idx, item)
        for modality, items in mm_positions.items()
        for idx, item in enumerate(items)
    )

    sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)

    return [(modality, idx) for modality, idx, _ in sorted_flat_items]
467
468


469
470
471
472
473
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together and batch each group's data.

    Only *adjacent* items are merged: if the same modality appears again
    after a different one, it starts a new group.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    # Imported lazily; the module only imports `.inputs` names under
    # `TYPE_CHECKING` at the top level.
    from vllm.multimodal.inputs import MultiModalKwargsItems

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        # `groupby` yields a one-shot iterator; materialize it so the group
        # can be both batched and counted.
        items_lst = list(items)
        mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst)
        mm_kwargs_data = mm_kwargs_items.get_data(
            device=device,
            pin_memory=pin_memory,
        )

        yield modality, len(items_lst), mm_kwargs_data
497
498


499
500
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Load audio from a URL using a throwaway `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
    # "/" as the allowed directory permits `file:` URLs anywhere on the
    # filesystem -- hence the warning above.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Load an image from a URL using a throwaway `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
    # "/" as the allowed directory permits `file:` URLs anywhere on the
    # filesystem -- hence the warning above.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Load a video from a URL using a throwaway `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
    # "/" as the allowed directory permits `file:` URLs anywhere on the
    # filesystem -- hence the warning above.
    media_connector = MediaConnector(
        media_io_kwargs=media_io_kwargs,
        allowed_local_media_path="/",
    )
    return media_connector.fetch_video(video_url)