"Dockerfile.arm" did not exist on "45ac4ff270b267765457159c0b75e1bb7ebf6d79"
utils.py 17.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
import mimetypes
7
from collections.abc import Generator, Set
8
from concurrent.futures import ThreadPoolExecutor
9
from itertools import groupby
10
from pathlib import Path
11
from typing import TYPE_CHECKING, Any, TypeVar
12
from urllib.parse import ParseResult, urlparse
13
from urllib.request import url2pathname
14

15
import numpy as np
16
import numpy.typing as npt
17
import torch
18
from PIL import Image, UnidentifiedImageError
19

20
import vllm.envs as envs
21
from vllm.connections import HTTPConnection, global_http_connection
22
from vllm.logger import init_logger
23
from vllm.utils.registry import ExtensionManager
24

25
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
26
from .base import MediaIO
27
from .image import ImageEmbeddingMediaIO, ImageMediaIO
28
from .video import VideoMediaIO
29

30
if TYPE_CHECKING:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalPlaceholderDict,
    )
else:
    # Outside type checking these names only appear in annotations here, so
    # `Any` placeholders stand in for them (presumably to avoid importing
    # `.inputs` eagerly / an import cycle — confirm).
    BatchedTensorInputs = MultiModalKwargsItem = MultiModalPlaceholderDict = Any

logger = init_logger(__name__)

# Worker pool that `MediaConnector.load_from_url_async` uses to run blocking
# media loading off the event loop; it is shut down at interpreter exit.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
)
atexit.register(global_thread_pool.shutdown)

# Media type produced by a `MediaIO[_M]` loader.
_M = TypeVar("_M")

# Registry of media connector implementations keyed by name.
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
51

52
53

@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Load multi-modal inputs (audio, images, video, embeddings) from URLs.

    Supported URL schemes are ``http``/``https`` (downloaded through
    `connection`), ``data`` (base64 payloads only) and ``file`` (only for
    files beneath `allowed_local_media_path`).
    """

    # URL schemes that are downloaded through the HTTP connection.
    _HTTP_SCHEMES: tuple[str, ...] = ("http", "https")

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files from.
                An empty string disables `file:` URLs.
            allowed_media_domains: If set, only media URLs whose hostname is in
                                   this list can be used for multi-modal inputs.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {}
        )
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist."
                )

            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory."
                )
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a base64 ``data:`` URL with `media_io`.

        The URL path must look like ``<media_type>[;<param>...];base64,<data>``.
        Extra parameters such as ``;charset=...`` may precede the final
        ``;base64`` marker (RFC 2397); only the bare media type is passed to
        `media_io.load_base64`.

        Raises:
            ValueError: If the ``,`` or ``;`` separator is missing.
            NotImplementedError: If the payload is not base64-encoded.
        """
        data_spec, comma, data = url_spec.path.partition(",")
        if not comma:
            raise ValueError("Invalid data URL: missing ',' before the payload.")

        media_type, semicolon, params = data_spec.partition(";")
        if not semicolon:
            raise ValueError("Invalid data URL: missing ';base64' encoding marker.")

        # The encoding marker is always the last parameter of the data URL.
        data_type = params.rpartition(";")[2]
        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a ``file:`` URL with `media_io`, restricted to the allowed dir.

        Raises:
            RuntimeError: If no `allowed_local_media_path` was configured.
            ValueError: If the resolved file is not inside the allowed directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        # netloc + path so that both `file:///abs/x` and `file://dir/x` work.
        filepath = Path(url2pathname(url_spec.netloc + url_spec.path))

        # `Path.resolve().parents` are absolute and symlink-free, so the allowed
        # root must be resolved too; otherwise a relative or symlinked
        # `--allowed-local-media-path` would never match any file.
        allowed_root = allowed_local_media_path.resolve()
        if allowed_root not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path {allowed_local_media_path}`."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
        """Reject URLs whose hostname is not in `allowed_media_domains`.

        No-op when `allowed_media_domains` is empty. `ParseResult.hostname` is
        lower-cased by `urllib`, so entries are matched exactly against the
        lower-cased hostname.

        Raises:
            ValueError: If the hostname is not an allowed domain.
        """
        if (
            self.allowed_media_domains
            and url_spec.hostname not in self.allowed_media_domains
        ):
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from `url` synchronously using `media_io`.

        Args:
            url: An ``http(s)://``, ``data:`` or ``file:`` URL.
            media_io: Loader that turns the fetched content into media.
            fetch_timeout: Timeout forwarded to the HTTP connection; unused
                for ``data:`` and ``file:`` URLs.

        Raises:
            ValueError: For an unsupported scheme, a disallowed HTTP host, a
                malformed data URL, or a file outside the allowed directory.
            NotImplementedError: For a non-base64 data URL.
            RuntimeError: For a file URL without `allowed_local_media_path`.
        """
        url_spec = urlparse(url)

        if url_spec.scheme in self._HTTP_SCHEMES:
            self._assert_url_in_allowed_media_domains(url_spec)
            data = self.connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Asynchronous counterpart of `load_from_url`.

        HTTP content is downloaded with the connection's async client; the
        blocking loading work (`load_bytes`, data/file URL handling) runs on
        `global_thread_pool` so it does not block the event loop. Raises the
        same exceptions as `load_from_url`.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme in self._HTTP_SCHEMES:
            self._assert_url_in_allowed_media_domains(url_spec)
            data = await self.connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            return await loop.run_in_executor(
                global_thread_pool, media_io.load_bytes, data
            )

        if url_spec.scheme == "data":
            return await loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )

        if url_spec.scheme == "file":
            return await loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.

        Uses `media_io_kwargs["audio"]` and `VLLM_AUDIO_FETCH_TIMEOUT`.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.

        Uses `media_io_kwargs["audio"]` and `VLLM_AUDIO_FETCH_TIMEOUT`.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP, base64 data or file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If PIL cannot identify the image (re-raised from
                `UnidentifiedImageError`), or for any URL error raised by
                `load_from_url`.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP, base64 data or file URL.

        By default, the image is converted into RGB format. Raises the same
        exceptions as `fetch_image`.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP, base64 data or file URL.

        Frames are handled by an `ImageMediaIO` built with `image_mode`
        (``"RGB"`` by default) and `media_io_kwargs["image"]`; the video loader
        receives `media_io_kwargs["video"]`.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP, base64 data or file URL.

        Frames are handled by an `ImageMediaIO` built with `image_mode`
        (``"RGB"`` by default) and `media_io_kwargs["image"]`; the video loader
        receives `media_io_kwargs["video"]`.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from a base64-encoded string (not a URL).
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

    def fetch_audio_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an audio embedding from a base64-encoded string (not a URL).
        """
        audio_embedding_io = AudioEmbeddingMediaIO()

        return audio_embedding_io.load_base64("", data)

357

358
359
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Serialize audio samples to a base64 string.

    Args:
        audio: Audio samples to encode.
        sampling_rate: Sampling rate of `audio`.
        format: Container format, forwarded to `AudioMediaIO.encode_base64`
            as `audio_format` (``"WAV"`` by default).

    Returns:
        The base64-encoded audio.
    """
    encoder = AudioMediaIO()
    payload = (audio, sampling_rate)
    return encoder.encode_base64(payload, audio_format=format)


def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as a ``data:`` URL.

    The MIME type is looked up in `mimetypes.types_map` from the lower-cased
    `format` used as a file extension, falling back to ``"audio"``.
    """
    encoded = encode_audio_base64(audio, sampling_rate, format=format)
    extension = f".{format.lower()}"
    mime = mimetypes.types_map.get(extension, "audio")
    return f"data:{mime};base64,{encoded}"
379
380


381
382
383
384
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str | None = None,
) -> str:
    """Serialize a PIL image to a base64 string.

    Args:
        image: The image to encode.
        image_mode: Mode given to `ImageMediaIO`; with the default ``"RGB"``
            the image is converted to RGB before being encoded.
        format: Image format forwarded as `image_format`; ``None`` leaves the
            choice to `ImageMediaIO.encode_base64`.

    Returns:
        The base64-encoded image.
    """
    return ImageMediaIO(image_mode=image_mode).encode_base64(
        image, image_format=format
    )
394
395


396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """Encode a PIL image as a ``data:`` URL.

    With the default `image_mode` the image is converted to RGB before being
    encoded. The MIME type comes from `mimetypes.types_map` for the lower-cased
    `format` extension, falling back to ``"image"``.
    """
    encoded = encode_image_base64(image, image_mode=image_mode, format=format)
    extension = f".{format.lower()}"
    mime = mimetypes.types_map.get(extension, "image")
    return f"data:{mime};base64,{encoded}"


def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Serialize video frames to a base64 string.

    Args:
        frames: The frames to encode.
        format: Forwarded to `VideoMediaIO.encode_base64` as `video_format`
            (``"JPEG"`` by default).

    Returns:
        The base64-encoded video.
    """
    frame_io = ImageMediaIO()
    return VideoMediaIO(frame_io).encode_base64(frames, video_format=format)


def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Encode video frames as a ``data:`` URL.

    JPEG frame sequences are labelled ``video/jpeg`` (presumably the type the
    video loader expects for them — confirm); other formats use
    `mimetypes.types_map`, falling back to ``"video"``.
    """
    encoded = encode_video_base64(frames, format=format)
    suffix = format.lower()
    mime = (
        "video/jpeg"
        if suffix == "jpeg"
        else mimetypes.types_map.get("." + suffix, "video")
    )
    return f"data:{mime};base64,{encoded}"
435
436


437
def argsort_mm_positions(
    mm_positions: MultiModalPlaceholderDict,
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    The sort is stable: items with equal offsets keep the order in which
    they are visited (modalities in dict order, then index within each list).

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    entries = [
        (modality, idx, placeholder.offset)
        for modality, placeholders in mm_positions.items()
        for idx, placeholder in enumerate(placeholders)
    ]
    entries.sort(key=lambda entry: entry[2])
    return [(modality, idx) for modality, idx, _ in entries]
458
459


460
461
462
463
464
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: bool | None = None,
    multimodal_cpu_fields: Set[str] | None = None,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """Group consecutive `MultiModalKwargsItem`s that share a modality.

    Only *adjacent* items are merged (via `itertools.groupby`), so one
    modality can produce several groups if modalities are interleaved. Since
    this is a generator, the deprecation warnings below are emitted on the
    first iteration, not at call time.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
        merge_by_field_config: Deprecated and ignored; passing any value only
            logs a warning.
        multimodal_cpu_fields: Deprecated and ignored; passing any value only
            logs a warning.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`, where
        `grouped_kwargs` is `MultiModalKwargsItems.get_data` of the group.
    """
    if merge_by_field_config is not None:
        logger.warning_once(
            "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
            "is deprecated and will be removed in v0.14."
        )
    if multimodal_cpu_fields is not None:
        logger.warning_once(
            "The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
            "is deprecated and will be removed in v0.14."
        )

    from vllm.multimodal.inputs import MultiModalKwargsItems

    for modality, group in groupby(mm_kwargs, key=lambda item: item.modality):
        batch = list(group)
        batched = MultiModalKwargsItems.from_seq(batch).get_data(
            device=device,
            pin_memory=pin_memory,
        )
        yield modality, len(batch), batched
501
502


503
504
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Returns:
        The pair returned by `MediaConnector.fetch_audio` (presumably
        `(samples, sampling_rate)` — confirm).

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # "/" as the allowed root lets `file:` URLs reach any absolute path.
    connector = MediaConnector(
        media_io_kwargs={"audio": audio_io_kwargs} if audio_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Returns:
        The PIL image returned by `MediaConnector.fetch_image`.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # "/" as the allowed root lets `file:` URLs reach any absolute path.
    connector = MediaConnector(
        media_io_kwargs={"image": image_io_kwargs} if image_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Returns:
        The `(array, dict)` pair returned by `MediaConnector.fetch_video`.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    # "/" as the allowed root lets `file:` URLs reach any absolute path.
    connector = MediaConnector(
        media_io_kwargs={"video": video_io_kwargs} if video_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_video(video_url)