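# utils.py (13.3 KB)
# NOTE: the header above and the bare line numbers below are residue from a
# web code viewer, not part of the original source.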
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from itertools import groupby
5
from pathlib import Path
6
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
7
from urllib.parse import ParseResult, urlparse
8

9
import numpy as np
10
import numpy.typing as npt
11
import torch
12
13
from PIL import Image

14
import vllm.envs as envs
15
from vllm.connections import HTTPConnection, global_http_connection
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
19

20
21
from .audio import AudioMediaIO
from .base import MediaIO
22
from .image import ImageEmbeddingMediaIO, ImageMediaIO
23
24
from .inputs import PlaceholderRange
from .video import VideoMediaIO
25

26
# Generic media type produced by a `MediaIO[_M]` decoder.
_M = TypeVar("_M")

if TYPE_CHECKING:
    from .hasher import MultiModalHashDict
    from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
else:
    # These names are only used in annotations. At runtime they fall back to
    # `Any`, so the imports above are never executed outside type checking.
    MultiModalHashDict = Any
    MultiModalKwargs = Any
    MultiModalPlaceholderDict = Any
35

36

37
class MediaConnector:
    """Load media (audio, images, video, image embeddings) from URLs.

    Supported URL schemes are HTTP(S), ``data:`` (base64 only) and ``file:``
    (only for files inside ``--allowed-local-media-path``).
    """

    def __init__(
        self,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        """
        Args:
            connection: HTTP client used to fetch ``http``/``https`` URLs.
            allowed_local_media_path: Directory from which ``file:`` URLs may
                be loaded. An empty string disables ``file:`` URLs.

        Raises:
            ValueError: If `allowed_local_media_path` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")

            # Resolve once so that relative paths and symlinked directories
            # compare correctly against the *resolved* parents of requested
            # files in `_load_file_url`. Without this, e.g. "./media" never
            # matches and every local file is rejected.
            allowed_local_media_path_ = allowed_local_media_path_.resolve()
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Decode a ``data:<media_type>;base64,<payload>`` URL with `media_io`.

        Raises:
            NotImplementedError: If the data URL is not base64-encoded.
            ValueError: If the URL lacks the ``,`` or ``;`` separators
                (raised by tuple unpacking).
        """
        data_spec, data = url_spec.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Load a ``file:`` URL located inside `allowed_local_media_path`.

        Raises:
            RuntimeError: If local file access is disabled.
            ValueError: If the file is not inside the allowed directory.
        """
        # Local import: only needed when loading file URLs.
        from urllib.request import url2pathname

        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        # The path of a file URL is percent-encoded (RFC 8089), e.g. "%20"
        # for a space. url2pathname decodes it and converts it to a native
        # filesystem path (including Windows drive letters).
        filepath = Path(url2pathname(url_spec.path))
        # Resolve symlinks and ".." before the containment check so that the
        # path cannot escape the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def _load_non_http_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Dispatch a parsed ``data:`` or ``file:`` URL (shared by the sync
        and async loaders).

        Raises:
            ValueError: If the URL scheme is not supported.
        """
        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Load media from an HTTP(S), ``data:`` or ``file:`` URL.

        Args:
            url: The URL to load.
            media_io: Decoder that turns the raw payload into a media object.
            fetch_timeout: Forwarded as ``timeout`` to the HTTP connection;
                only used for HTTP(S) URLs.

        Raises:
            ValueError: If the URL scheme is not supported.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            data = self.connection.get_bytes(url, timeout=fetch_timeout)
            return media_io.load_bytes(data)

        return self._load_non_http_url(url_spec, media_io)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Asynchronous variant of `load_from_url`.

        Only the HTTP(S) download is awaited. ``data:`` and ``file:`` URLs
        are decoded synchronously.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            data = await self.connection.async_get_bytes(
                url, timeout=fetch_timeout)
            return media_io.load_bytes(data)

        return self._load_non_http_url(url_spec, media_io)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.

        Returns the ``(audio, sampling_rate)`` pair decoded by `AudioMediaIO`.
        """
        audio_io = AudioMediaIO()

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.

        Returns the ``(audio, sampling_rate)`` pair decoded by `AudioMediaIO`.
        """
        audio_io = AudioMediaIO()

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP(S), base64 data or file URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        return self.load_from_url(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP(S), base64 data or file
        URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        return await self.load_from_url_async(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Load video from an HTTP(S), base64 data or file URL.

        Frames are decoded with `image_mode` (RGB by default), and
        `num_frames` is forwarded to `VideoMediaIO`.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Asynchronously load video from an HTTP(S), base64 data or file URL.

        Frames are decoded with `image_mode` (RGB by default), and
        `num_frames` is forwarded to `VideoMediaIO`.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from base64-encoded `data`.

        Note that `data` is the base64 payload itself, not a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

264

265
global_media_connector = MediaConnector()
"""The global [`MediaConnector`][vllm.multimodal.utils.MediaConnector]
instance used by vLLM."""

# Module-level shortcuts bound to the global connector's synchronous fetchers.
fetch_audio, fetch_image, fetch_video = (
    global_media_connector.fetch_audio,
    global_media_connector.fetch_image,
    global_media_connector.fetch_video,
)
272
273


274
275
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: float,
) -> str:
    """Serialize an audio clip together with its sampling rate to base64.

    The ``(audio, sampling_rate)`` pair is encoded by `AudioMediaIO`.
    """
    return AudioMediaIO().encode_base64((audio, sampling_rate))
281
282


283
284
285
286
287
288
289
290
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Serialize a pillow image to a base64 string.

    The image is first converted to `image_mode` (RGB unless overridden)
    and then saved in the given `format` (JPEG by default).
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    return encoder.encode_base64(image, image_format=format)
296
297


298
def encode_video_base64(frames: npt.NDArray) -> str:
    """Serialize video frames to a base64 string.

    The frames are encoded by a `VideoMediaIO` that wraps a
    default-configured `ImageMediaIO`.
    """
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
302
303


304
def merge_and_sort_multimodal_metadata(
    mm_positions: MultiModalPlaceholderDict,
    mm_hashes: Optional[MultiModalHashDict],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
    """Merge the placeholder ranges of all modalities, ordered by offset.

    All `PlaceholderRange` objects in `mm_positions` are combined into a
    single list sorted in ascending order of their offset (start index in the
    input sequence). If a `MultiModalHashDict` is given, its hashes are
    reordered in exactly the same way.

    Returns:
        list[str]: The modality of each item, in order of position in the
        input sequence.
        list[PlaceholderRange]: All PlaceholderRanges from `mm_positions`,
        sorted by offset.
        Optional[list[str]]: The hashes from `mm_hashes` in the same order,
        or None when no hashes were provided.
    """
    modalities = list(mm_positions.keys())

    assert len(modalities) > 0, "No modalities found in the mm_positions."

    if len(modalities) == 1:
        # Ranges (and hashes) of a single modality are already ordered by
        # offset, so they can be returned without sorting.
        (only_modality, ) = modalities
        ranges = list(mm_positions[only_modality])
        hashes = mm_hashes[only_modality] if mm_hashes else None
        return [only_modality] * len(ranges), ranges, hashes

    # Flatten every modality into (modality, placeholder, hash) triples.
    entries = []
    for modality in modalities:
        ranges = list(mm_positions[modality])
        if mm_hashes and modality in mm_hashes:
            hash_list: list[Optional[str]] = list(mm_hashes[modality])
        else:
            hash_list = [None] * len(ranges)
        entries.extend((modality, placeholder, hash_value)
                       for placeholder, hash_value in zip(ranges, hash_list))

    # Stable sort by the starting offset of each placeholder.
    entries = sorted(entries, key=lambda entry: entry[1].offset)

    sorted_modalities = [modality for modality, _, _ in entries]
    merged_placeholders = [placeholder for _, placeholder, _ in entries]
    if mm_hashes is None:
        merged_hashes = None
    else:
        # NOTE(review): items without a hash become the string "None" here.
        merged_hashes = [str(hash_value) for _, _, hash_value in entries]

    return sorted_modalities, merged_placeholders, merged_hashes
361
362
363


def group_mm_inputs_by_modality(
        mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
    """Batch consecutive `MultiModalKwargs` that share a single modality.

    Runs of adjacent inputs with the same (single) modality are collected
    into one list for batching. An input carrying more than one modality is
    always placed in a list of its own.

    Args:
        mm_inputs: List of MultiModalKwargs.

    Returns:
        list[list[vllm.multimodal.MultiModalKwargs]]: List of list of
        `MultiModalKwargs`, each inner list contains consecutive
        `MultiModalKwargs` with same modality.
    """
    if not mm_inputs:
        return []

    def _group_key(mm_input: MultiModalKwargs) -> Union[str, int]:
        modalities = mm_input.modalities

        if len(modalities) > 1:
            # Key multi-modality inputs by object identity, so that distinct
            # neighbouring inputs are never merged into one group.
            return id(mm_input)

        if len(modalities) == 1:
            return next(iter(modalities))

        # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
        # this is used to make InternVL with legacy pipeline still work with v1.
        return ""

    grouped = groupby(mm_inputs, key=_group_key)
    return [list(members) for _, members in grouped]
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428


def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                vision_model: torch.nn.Module) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding.

    The input is split along its first dimension across the tensor-parallel
    ranks, zero-padded so that every rank gets an equally sized slice. Each
    rank runs `vision_model` on its own slice. The outputs are then
    all-gathered along dim 0, and the padded rows are dropped.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.

    Returns:
        torch.Tensor: Output image embeddings
    """
    total = image_input.shape[0]
    world_size = get_tensor_model_parallel_world_size()
    # Ceil division: number of items handled by each rank.
    per_rank = -(-total // world_size)
    padding = per_rank * world_size - total

    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # only the final pair pads dim 0; all other dimensions are left as-is.
    pad_spec = (0, 0) * (image_input.dim() - 1) + (0, padding)
    padded = torch.nn.functional.pad(image_input, pad_spec)

    start = get_tensor_model_parallel_rank() * per_rank
    local_embeddings = vision_model(padded[start:start + per_rank, ...])

    gathered = tensor_model_parallel_all_gather(local_embeddings, dim=0)
    return gathered[:total, ...]