utils.py 16 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from itertools import groupby
5
from pathlib import Path
6
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
7
from urllib.parse import ParseResult, urlparse
8

9
import numpy as np
10
import numpy.typing as npt
11
import torch
12
from PIL import Image, UnidentifiedImageError
13

14
import vllm.envs as envs
15
from vllm.connections import HTTPConnection, global_http_connection
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
19

20
21
from .audio import AudioMediaIO
from .base import MediaIO
22
from .image import ImageEmbeddingMediaIO, ImageMediaIO
23
24
from .inputs import PlaceholderRange
from .video import VideoMediaIO
25

26
# Generic type of the object produced by a `MediaIO` loader.
_M = TypeVar("_M")

# These names are only needed for annotations. Importing them solely under
# TYPE_CHECKING keeps them out of the runtime import graph; at runtime they
# fall back to `Any` so that annotations can still be evaluated.
if TYPE_CHECKING:
    from .hasher import MultiModalHashDict
    from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
else:
    MultiModalHashDict = Any
    MultiModalKwargs = Any
    MultiModalPlaceholderDict = Any
35

36

37
class MediaConnector:
    """Fetch multimodal media from HTTP(S), data and local file URLs.

    The raw content is decoded by a `MediaIO` implementation for the
    requested modality. Per-modality keyword arguments for those decoders
    are taken from `media_io_kwargs`.
    """

    def __init__(
        self,
        media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                                      from. An empty string disables loading
                                      from `file://` URLs.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[
            str, Any]] = media_io_kwargs if media_io_kwargs else {}
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")

            # `_load_file_url` compares this directory against the parents
            # of a *resolved* file path, so it has to be resolved as well.
            # Otherwise a relative or symlinked directory never matches.
            allowed_local_media_path_ = allowed_local_media_path_.resolve()
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Decode the payload of a `data:` URL with `media_io`.

        Args:
            url_spec: Parsed URL whose path has the form
                `<media type>;<encoding>,<payload>`.
            media_io: Decoder that receives the media type and the payload.

        Raises:
            ValueError: If the `,` or `;` separator is missing.
            NotImplementedError: If the encoding is not `base64`.
        """
        data_spec, comma, data = url_spec.path.partition(",")
        if not comma:
            raise ValueError(
                "Invalid data URL: missing ',' between the header and "
                "the data.")

        media_type, semicolon, data_type = data_spec.partition(";")
        if not semicolon:
            raise ValueError(
                "Invalid data URL: the header must have the form "
                "'<media type>;base64'.")

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Load a local file referenced by a `file:` URL with `media_io`.

        Raises:
            RuntimeError: If no allowed local media directory is configured.
            ValueError: If the resolved file path is not located below the
                allowed local media directory.
        """
        # Function-scope import keeps this block self-contained.
        from urllib.request import url2pathname

        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        # URL paths are percent-encoded (e.g. "%20" for a space); convert
        # them back to a filesystem path before touching the disk.
        filepath = Path(url2pathname(url_spec.path))

        # Resolve first so that ".." segments and symlinks cannot be used to
        # escape the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def _load_non_http_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Handle data and file URLs; shared by the sync and async loaders.

        Raises:
            ValueError: If the scheme is neither `data` nor `file`.
        """
        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Load media from a HTTP, data or file URL.

        Args:
            url: The URL to load.
            media_io: Decoder used to turn the fetched content into media.
            fetch_timeout: Timeout forwarded to the HTTP connection; only
                used for HTTP URLs.

        Raises:
            ValueError: If the URL scheme is not supported.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            data = self.connection.get_bytes(url, timeout=fetch_timeout)
            return media_io.load_bytes(data)

        return self._load_non_http_url(url_spec, media_io)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Asynchronously load media from a HTTP, data or file URL.

        Only the HTTP download is awaited; data and file URLs are decoded
        synchronously.

        Args:
            url: The URL to load.
            media_io: Decoder used to turn the fetched content into media.
            fetch_timeout: Timeout forwarded to the HTTP connection; only
                used for HTTP URLs.

        Raises:
            ValueError: If the URL scheme is not supported.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            data = await self.connection.async_get_bytes(
                url, timeout=fetch_timeout)
            return media_io.load_bytes(data)

        return self._load_non_http_url(url_spec, media_io)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from a HTTP, data or file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP, data or file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from a HTTP, data or file URL.

        `image_mode` (RGB by default) is passed to the image decoder that
        the video decoder is built on.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from a HTTP, data or file URL.

        `image_mode` (RGB by default) is passed to the image decoder that
        the video decoder is built on.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from base64-encoded data.

        Unlike the other `fetch_*` methods, `data` is handed directly to the
        base64 loader (with an empty media type) and is not parsed as a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

289

290
291
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: float,
) -> str:
    """Encode audio as base64.

    Args:
        audio: The audio samples.
        sampling_rate: Sampling rate of `audio`.

    Returns:
        The base64 string produced by `AudioMediaIO.encode_base64` for the
        `(audio, sampling_rate)` pair.
    """
    audio_io = AudioMediaIO()
    return audio_io.encode_base64((audio, sampling_rate))
297
298


299
300
301
302
303
304
305
306
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The image to encode.
        image_mode: Image mode passed to `ImageMediaIO`.
        format: Image file format passed to the encoder as `image_format`.

    Returns:
        The base64 string produced by `ImageMediaIO.encode_base64`.
    """
    image_io = ImageMediaIO(image_mode=image_mode)
    return image_io.encode_base64(image, image_format=format)
312
313


314
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode video frames as base64.

    Args:
        frames: Array holding the video frames.

    Returns:
        The base64 string produced by `VideoMediaIO.encode_base64`, using an
        `ImageMediaIO` with default settings for the individual frames.
    """
    image_io = ImageMediaIO()
    video_io = VideoMediaIO(image_io)
    return video_io.encode_base64(frames)
318
319


320
def merge_and_sort_multimodal_metadata(
    mm_positions: MultiModalPlaceholderDict,
    mm_hashes: Optional[MultiModalHashDict],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
    """Given a MultiModalPlaceholderDict, merge all PlaceholderRange
    objects from all available modalities into a single list of
    PlaceholderRange, sorted by their offset (starting index in the input
    sequence) in the ascending order.

    Optionally if a `MultiModalHashDict` is given, same operation will be
    applied to the object and the sorted list of hashes will be returned.

    Note:
        With more than one modality and `mm_hashes` given, every hash is
        passed through `str()`. A modality that is missing from `mm_hashes`
        therefore contributes the string `"None"` for each of its items.

    Args:
        mm_positions: Placeholder ranges keyed by modality. Must contain at
            least one modality.
        mm_hashes: Hashes keyed by modality, or None.

    Returns:
        list[str]: List of item modalities in order of their positions in the
        input sequence.
        list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
        mm_positions.
        Optional[list[str]]: Sorted list of all hashes from mm_hashes if given,
        None otherwise.
    """
    modalities = list(mm_positions.keys())

    assert len(modalities) > 0, "No modalities found in the mm_positions."

    # For single modality, placeholder ranges and hashes are already sorted
    # so we can return the list directly.
    if len(modalities) == 1:
        modality = modalities[0]
        placeholder_list = list(mm_positions[modality])
        single_hashes = None if not mm_hashes else mm_hashes[modality]

        return [modality] * len(placeholder_list), placeholder_list, \
            single_hashes

    # Create a list of (modality, placeholder, hash) tuples for all placeholders
    all_items = []
    for modality in modalities:
        placeholder_list = list(mm_positions[modality])

        hash_list: list[Optional[str]]
        if mm_hashes and modality in mm_hashes:
            hash_list = list(mm_hashes[modality])
        else:
            # No hashes for this modality: pair every placeholder with None.
            hash_list = [None] * len(placeholder_list)

        all_items.extend(
            (modality, placeholder, hash_value)
            for placeholder, hash_value in zip(placeholder_list, hash_list))

    # Sort all items by offset. The sort is stable, so items sharing an
    # offset keep their modality iteration order.
    all_items.sort(key=lambda item: item[1].offset)

    # Split into separate lists
    sorted_modalities = [item[0] for item in all_items]
    merged_placeholders = [item[1] for item in all_items]
    merged_hashes = [str(item[2])
                     for item in all_items] if mm_hashes is not None else None

    return sorted_modalities, merged_placeholders, merged_hashes
377
378
379


def group_mm_inputs_by_modality(
        mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
    """Group consecutive MultiModalKwargs from mm_inputs with the same modality
    together into the same list for batching purpose. For MultiModalKwargs with
    multiple modalities, put them into their own list.

    Args:
        mm_inputs: List of MultiModalKwargs.

    Returns:
        list[list[vllm.multimodal.MultiModalKwargs]]: List of list of
        `MultiModalKwargs`, each inner list contains consecutive
        `MultiModalKwargs` with same modality.
    """
    if not mm_inputs:
        return []

    def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]:
        modalities = mm_input.modalities

        # If the input has multiple modalities, return a id as the unique key
        # for the mm_input input.
        if len(modalities) > 1:
            return id(mm_input)

        if len(modalities) == 1:
            return next(iter(modalities))

        # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
        # this is used to make InternVL with legacy pipeline still work with v1.
        return ""

    # groupby only merges *adjacent* items with equal keys, which is exactly
    # the "consecutive" grouping described above; the input is not sorted.
    return [
        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
    ]
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444


def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                vision_model: torch.nn.Module) -> torch.Tensor:
    """Run a vision model on a slice of the input per tensor-parallel rank.

    The input is split along its first dimension into equally sized slices,
    one per rank. Each rank runs the vision model on its own slice and the
    per-rank outputs are all-gathered along dimension 0.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.

    Returns:
        torch.Tensor: Output image embeddings
    """
    total_chunks = image_input.shape[0]
    world_size = get_tensor_model_parallel_world_size()

    # Ceiling division so that every rank receives the same slice length.
    chunks_per_rank = (total_chunks + world_size - 1) // world_size
    extra_chunks = chunks_per_rank * world_size - total_chunks

    # The pad spec lists (before, after) pairs starting from the LAST
    # dimension: leave all trailing dimensions alone and append
    # `extra_chunks` entries at the end of dimension 0.
    trailing_dims = image_input.dim() - 1
    pad_spec = (0, 0) * trailing_dims + (0, extra_chunks)
    padded_input = torch.nn.functional.pad(image_input, pad_spec)

    # Select the slice owned by this rank.
    rank = get_tensor_model_parallel_rank()
    start = rank * chunks_per_rank
    local_input = padded_input[start:start + chunks_per_rank, ...]

    local_embeddings = vision_model(local_input)
    gathered_embeddings = tensor_model_parallel_all_gather(local_embeddings,
                                                           dim=0)

    # Keep only the first `total_chunks` rows to discard outputs for padding
    # (assumes the model emits one row per input chunk - TODO confirm).
    return gathered_embeddings[:total_chunks, ...]
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492


def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, Union[int, float]]:
    """Load audio from a URL using a freshly created `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    # The connector expects kwargs keyed by modality; None or an empty dict
    # means "no overrides".
    per_modality_kwargs = ({
        "audio": audio_io_kwargs
    } if audio_io_kwargs else None)

    connector = MediaConnector(media_io_kwargs=per_modality_kwargs)
    return connector.fetch_audio(audio_url)


def fetch_image(
    image_url: str,
    image_io_kwargs: Optional[dict[str, Any]] = None,
) -> Image.Image:
    """Load an image from a URL using a freshly created `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    # The connector expects kwargs keyed by modality; None or an empty dict
    # means "no overrides".
    per_modality_kwargs = ({
        "image": image_io_kwargs
    } if image_io_kwargs else None)

    connector = MediaConnector(media_io_kwargs=per_modality_kwargs)
    return connector.fetch_image(image_url)


def fetch_video(
    video_url: str,
    video_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """Load a video from a URL using a freshly created `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    # The connector expects kwargs keyed by modality; None or an empty dict
    # means "no overrides".
    per_modality_kwargs = ({
        "video": video_io_kwargs
    } if video_io_kwargs else None)

    connector = MediaConnector(media_io_kwargs=per_modality_kwargs)
    return connector.fetch_video(video_url)