utils.py 16.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import atexit
6
from collections.abc import Iterable
7
from concurrent.futures import ThreadPoolExecutor
8
from itertools import groupby
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
11
from urllib.parse import ParseResult, urlparse
12

13
import numpy as np
14
import numpy.typing as npt
15
import torch
16
from PIL import Image, UnidentifiedImageError
17
from typing_extensions import deprecated
18

19
import vllm.envs as envs
20
from vllm.connections import HTTPConnection, global_http_connection
21
22
23
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
24

25
26
from .audio import AudioMediaIO
from .base import MediaIO
27
from .image import ImageEmbeddingMediaIO, ImageMediaIO
28
from .video import VideoMediaIO
29

30
# Generic media type produced by a `MediaIO[_M]` loader.
_M = TypeVar("_M")

if TYPE_CHECKING:
    from .inputs import (BatchedTensorInputs, MultiModalKwargs,
                         MultiModalKwargsItem, MultiModalPlaceholderDict)
else:
    # At runtime these names are only needed for annotations in this module,
    # so they are bound to `Any` placeholders instead of the real types.
    BatchedTensorInputs = MultiModalKwargs = Any
    MultiModalKwargsItem = MultiModalPlaceholderDict = Any
40

41
42
43
44
# Process-wide worker pool that `MediaConnector.load_from_url_async` uses to
# run media decoding off the event loop. It is shut down (waiting for pending
# work, which is `shutdown`'s default) when the interpreter exits.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT, )
atexit.register(global_thread_pool.shutdown, wait=True)

45

46
class MediaConnector:
    """Load audio, images, video and image embeddings from HTTP(S),
    ``data:`` and ``file:`` URLs."""

    def __init__(
        self,
        media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                                      from. It is stored as a resolved
                                      (absolute) path. An empty string
                                      disables loading from ``file:`` URLs.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[
            str, Any]] = media_io_kwargs if media_io_kwargs else {}
        self.connection = connection

        allowed_local_media_path_: Optional[Path]
        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")

            # `_load_file_url` checks membership in `filepath.resolve().parents`,
            # which are absolute and symlink-free. Resolve the allowed
            # directory the same way, so that a relative or symlinked
            # directory still matches the files inside it.
            allowed_local_media_path_ = allowed_local_media_path_.resolve()
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Decode media embedded in a ``data:`` URL.

        The URL path is expected to look like ``<media type>;base64,<data>``.

        Raises:
            NotImplementedError: If the payload encoding is not ``base64``.
        """
        data_spec, data = url_spec.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Load media from a ``file:`` URL inside the allowed directory.

        Raises:
            RuntimeError: If no ``--allowed-local-media-path`` is configured.
            ValueError: If the resolved file is not located under the allowed
                directory.
        """
        # Local import: only file URLs need this conversion.
        from urllib.request import url2pathname

        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        # URL paths are percent-encoded (e.g. a space is `%20`) and use URL
        # syntax for drive letters on Windows. Convert to a native path
        # before touching the filesystem.
        filepath = Path(url2pathname(url_spec.path))
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Load media from a URL whose scheme starts with ``http``, or from a
        ``data:`` or ``file:`` URL.

        Args:
            url: The URL to load from.
            media_io: Decoder that turns the fetched content into media.
            fetch_timeout: Timeout passed to the HTTP connection. Only used
                for ``http*`` schemes.

        Raises:
            ValueError: If the URL scheme is not supported.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Asynchronous counterpart of `load_from_url`.

        Downloads use the connection's async API. Decoding runs in
        `global_thread_pool` so that it does not block the event loop.

        Raises:
            ValueError: If the URL scheme is not supported.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)
            future = loop.run_in_executor(global_thread_pool,
                                          media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_data_url, url_spec,
                                          media_io)
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_file_url, url_spec,
                                          media_io)
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If PIL cannot identify the image data.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If PIL cannot identify the image data.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from a HTTP or base64 data URL.

        Frames are decoded with an `ImageMediaIO` using `image_mode`
        (RGB by default).
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from a HTTP or base64 data URL.

        Frames are decoded with an `ImageMediaIO` using `image_mode`
        (RGB by default).
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from base64-encoded `data`.

        `data` is the encoded payload itself, not a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

305

306
307
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: float,
) -> str:
    """Encode `audio` together with its `sampling_rate` as a base64 string.

    Encoding is delegated to a default-configured `AudioMediaIO`.
    """
    return AudioMediaIO().encode_base64((audio, sampling_rate))
313
314


315
316
317
318
319
320
321
322
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image as a base64 string.

    Args:
        image: The image to encode.
        image_mode: Mode the image is converted to before encoding
            (RGB by default).
        format: Image file format used for encoding (JPEG by default).
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    return encoder.encode_base64(image, image_format=format)
328
329


330
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode an array of video frames as a base64 string.

    The frames are encoded by a `VideoMediaIO` that wraps a default
    `ImageMediaIO`.
    """
    video_encoder = VideoMediaIO(ImageMediaIO())
    return video_encoder.encode_base64(frames)
334
335


336
337
338
339
340
341
def argsort_mm_positions(
        mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
    """
    Return the keys of a `MultiModalPlaceholderDict` ordered by each item's
    `offset` (its starting index in the input sequence), ascending.

    Items with equal offsets keep their original order: dictionary order of
    the modalities first, then position within each modality's list.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    entries = [(placeholder.offset, modality, position)
               for modality, placeholders in mm_positions.items()
               for position, placeholder in enumerate(placeholders)]

    # Sort on the offset alone; list.sort is stable, so ties stay in
    # insertion order instead of being broken by modality name.
    entries.sort(key=lambda entry: entry[0])

    return [(modality, position) for _, modality, position in entries]
354
355


356
357
358
359
# Temporary back-compatibility for plugins that define model runner
@deprecated("`group_mm_inputs_by_modality` is superseded by "
            "`group_mm_kwargs_by_modality` and will be removed in v0.13. "
            "Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality(
        mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
    """Split `mm_inputs` into runs of consecutive inputs with the same
    modality.

    Inputs that span more than one modality always form a group of their
    own. Inputs with no modality are grouped under the empty-string key.
    An empty `mm_inputs` yields an empty list.
    """
    if not mm_inputs:
        return []

    def _group_key(mm_input: MultiModalKwargs) -> Union[str, int]:
        modalities = mm_input.modalities
        count = len(modalities)

        if count == 1:
            return next(iter(modalities))

        if count > 1:
            # Multi-modality inputs are keyed by object identity, so they
            # are never merged with their neighbours.
            return id(mm_input)

        # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
        # this is used to make InternVL with legacy pipeline still work with v1.
        return ""

    return [list(run) for _, run in groupby(mm_inputs, key=_group_key)]
382
383


384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` that share
    the same modality, and batch each group.

    Only *consecutive* items are grouped (via `itertools.groupby`). A modality
    that appears in several non-adjacent runs therefore yields several groups.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: Forwarded to `MultiModalKwargs.as_kwargs` for every group.
        pin_memory: Forwarded to `MultiModalKwargs.batch` for every group.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`. `num_items` is the
        number of items in the group, and `grouped_kwargs` is the output of
        `MultiModalKwargs.as_kwargs` for the batched group.
    """
    # Imported lazily; presumably to avoid a circular import with
    # `vllm.multimodal.inputs` (TODO confirm).
    from vllm.multimodal.inputs import MultiModalKwargs

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        items_lst = list(items)

        # mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
        #                                               pin_memory=pin_memory)

        # if device is not None:
        #     mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
        #                                       mm_kwargs_group.data)

        # TODO: Once V0 is removed, we can use the merging logic above
        # to avoid creating an extra batch dimension (except for fields
        # that are meant to be stacked anyway).
        # We will also need to update each model to remove `flatten_bn`.
        mm_kwargs_group = MultiModalKwargs.as_kwargs(
            MultiModalKwargs.batch(
                [MultiModalKwargs.from_items([item]) for item in items_lst],
                pin_memory=pin_memory,
            ),
            device=device,
        )

        yield modality, len(items_lst), mm_kwargs_group


426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                vision_model: torch.nn.Module) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding.

    The first dimension of `image_input` is zero-padded up to a multiple of
    the tensor-parallel world size and cut into equal contiguous shards, one
    per rank. Each rank runs `vision_model` on its own shard. The outputs are
    all-gathered along dimension 0 and truncated back to the original number
    of rows.

    Args:
        image_input (torch.Tensor): Image input tensor, sharded along its
            first dimension.
        vision_model (torch.nn.Module): Vision model run on each shard.

    Returns:
        torch.Tensor: Output image embeddings.
    """
    total_rows = image_input.shape[0]
    world_size = get_tensor_model_parallel_world_size()

    # Ceiling division so that every rank receives the same shard size.
    shard_size, leftover = divmod(total_rows, world_size)
    if leftover:
        shard_size += 1
    pad_rows = shard_size * world_size - total_rows

    # `F.pad` takes (before, after) pairs starting from the LAST dimension,
    # so only the final pair affects dimension 0 (zero rows appended).
    pad_spec = (0, 0) * (image_input.dim() - 1) + (0, pad_rows)
    padded_input = torch.nn.functional.pad(image_input, pad_spec)

    shard_start = get_tensor_model_parallel_rank() * shard_size
    local_shard = padded_input[shard_start:shard_start + shard_size, ...]

    local_embeddings = vision_model(local_shard)
    all_embeddings = tensor_model_parallel_all_gather(local_embeddings, dim=0)
    return all_embeddings[:total_rows, ...]
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502


def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, Union[int, float]]:
    """
    Fetch audio from a URL with a newly created `MediaConnector`.

    The connector has no allowed local media path, so ``file:`` URLs are
    rejected.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    io_kwargs: Optional[dict[str, dict[str, Any]]]
    if audio_io_kwargs:
        io_kwargs = {"audio": audio_io_kwargs}
    else:
        io_kwargs = None

    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: Optional[dict[str, Any]] = None,
) -> Image.Image:
    """
    Fetch an image from a URL with a newly created `MediaConnector`.

    The connector has no allowed local media path, so ``file:`` URLs are
    rejected.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    io_kwargs: Optional[dict[str, dict[str, Any]]]
    if image_io_kwargs:
        io_kwargs = {"image": image_io_kwargs}
    else:
        io_kwargs = None

    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video from a URL with a newly created `MediaConnector`.

    The connector has no allowed local media path, so ``file:`` URLs are
    rejected.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    io_kwargs: Optional[dict[str, dict[str, Any]]]
    if video_io_kwargs:
        io_kwargs = {"video": video_io_kwargs}
    else:
        io_kwargs = None

    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_video(video_url)