# utils.py
from functools import lru_cache
from pathlib import Path
from typing import Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
from urllib.request import url2pathname

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
21

22
23
24
# Module-level logger for this file.
logger = init_logger(__name__)

# Memoized wrapper around `get_tokenizer` (functools.lru_cache applied
# directly to the function, i.e. with lru_cache's default settings).
cached_get_tokenizer = lru_cache(get_tokenizer)
25

26
# Type variable for the value a `MediaIO[_M]` loader produces.
_M = TypeVar("_M")
27

28

29
class MediaConnector:
    """Load multi-modal media (audio, image, video) referenced by a URL.

    Three kinds of URL are dispatched on their scheme:

    - schemes beginning with ``http``: the bytes are fetched through
      ``self.connection`` and handed to ``MediaIO.load_bytes``;
    - ``data``: a base64 payload embedded in the URL, handed to
      ``MediaIO.load_base64``;
    - ``file``: a local file, handed to ``MediaIO.load_file``; only permitted
      when it lies underneath ``allowed_local_media_path``.
    """

    def __init__(
        self,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        """
        Args:
            connection: Connection used to download HTTP(S) URLs.
            allowed_local_media_path: Directory under which ``file`` URLs may
                be loaded. An empty string disables loading local files.

        Raises:
            ValueError: If ``allowed_local_media_path`` is given but does not
                exist or is not a directory.
        """
        super().__init__()

        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Decode a ``data:`` URL of the form ``<type>[;param...];base64,<data>``.

        Raises:
            ValueError: If the URL has no ``,`` separating header and payload.
            NotImplementedError: If the payload is not declared as base64.
        """
        data_spec, comma, data = url_spec.path.partition(",")
        if not comma:
            raise ValueError(
                "Invalid data URL: missing ',' between the media type "
                "and the data.")

        # The media type comes first; `base64`, when present, is always the
        # last `;`-separated token before the comma. Any parameters in
        # between (e.g. `charset=...`) are ignored.
        media_type, *params = data_spec.split(";")
        if not params or params[-1] != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Load a ``file:`` URL, restricted to ``allowed_local_media_path``.

        Raises:
            RuntimeError: If no allowed local media path was configured.
            ValueError: If the file is not located under the allowed path.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        # Decode percent-escapes (e.g. `%20`) so the URL path maps onto the
        # real filesystem path.
        filepath = Path(url2pathname(url_spec.path))

        # Resolve BOTH sides before comparing: `filepath.resolve().parents`
        # holds absolute, symlink-free paths, so a relative or symlinked
        # allowed directory would otherwise never match.
        allowed_root = allowed_local_media_path.resolve()
        if allowed_root not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def _load_non_http_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Handle the schemes that need no network access (shared by the
        synchronous and asynchronous loaders).

        Raises:
            ValueError: If the scheme is neither ``data`` nor ``file``.
        """
        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Load media from ``url`` using ``media_io`` to decode it.

        Args:
            url: HTTP(S), ``data`` or ``file`` URL.
            media_io: Decoder for the fetched bytes / base64 data / file.
            fetch_timeout: Timeout passed to the connection for HTTP fetches.

        Raises:
            ValueError: If the URL scheme is not supported.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        return self._load_non_http_url(url_spec, media_io)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Asynchronous counterpart of :meth:`load_from_url`.

        Only the HTTP fetch is awaited; ``data`` and ``file`` URLs are
        handled by the same synchronous helpers.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        return self._load_non_http_url(url_spec, media_io)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.

        The fetch timeout is taken from ``envs.VLLM_AUDIO_FETCH_TIMEOUT``.
        """
        audio_io = AudioMediaIO()

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.

        The fetch timeout is taken from ``envs.VLLM_AUDIO_FETCH_TIMEOUT``.
        """
        audio_io = AudioMediaIO()

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        return self.load_from_url(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        return await self.load_from_url_async(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Load video from a HTTP or base64 data URL.

        ``image_mode`` configures the ``ImageMediaIO`` used for frames and
        ``num_frames`` is forwarded to ``VideoMediaIO``.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Asynchronously load video from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
244
245


246
247
248
249
250
251
# Default connector, constructed with MediaConnector's default arguments.
global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""

# Module-level shortcuts bound to the global connector's synchronous
# fetch methods.
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
fetch_video = global_media_connector.fetch_video
252
253


254
255
256
257
258
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Encode audio as base64.

    The ``(audio, sampling_rate)`` pair is passed as a single tuple to
    ``AudioMediaIO.encode_base64``.
    """
    return AudioMediaIO().encode_base64((audio, sampling_rate))
261
262


263
264
265
266
267
268
269
270
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The image to encode.
        image_mode: Mode given to the ``ImageMediaIO`` doing the encoding.
        format: Image format name forwarded as ``image_format``.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    return encoder.encode_base64(image, image_format=format)
276
277


278
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode video frames as base64.

    Uses a ``VideoMediaIO`` built on a default-configured ``ImageMediaIO``.
    """
    video_io = VideoMediaIO(ImageMediaIO())
    return video_io.encode_base64(frames)
282
283


284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Given the outputs a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        feature_sample_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list.
        post_layer_norm: Post norm to apply to the output of the encoder.
        max_possible_layers: Total layers in the fully loaded visual encoder.

    Returns:
        The (optionally normalized) last-layer output when
        ``feature_sample_layers`` is None; otherwise the selected hidden
        states concatenated along the last dimension.
    """
    if feature_sample_layers is None:
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs contains a list
    # of hidden states in the same order as the encoder layers
    # that produced them.
    offset = max_possible_layers - len(encoder_outputs)
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in feature_sample_layers
    ]

    # Apply post-norm on the final hidden state if we are using it.
    # The last requested index must name the encoder's final layer, either
    # as -1 or as its absolute index; the number of sampled layers is
    # unrelated to that.
    uses_last_layer = feature_sample_layers[-1] in (max_possible_layers - 1,
                                                    -1)
    if post_layer_norm is not None and uses_last_layer:
        # Normalize only the selected final hidden state; `encoder_outputs`
        # is a list here and cannot be passed to the norm directly.
        hs_pool[-1] = post_layer_norm(hs_pool[-1])
    return torch.cat(hs_pool, dim=-1)


327
328
329
330
331
332
333
334
335
336
# Utilities for input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> list[_T]:
    """Repeat ``token`` and optionally surround the run with padding tokens.

    Args:
        token: The token (string or id) to repeat.
        repeat_count: How many copies of ``token`` to produce.
        pad_token_left: Token placed before the run; omitted when None.
        pad_token_right: Token placed after the run; omitted when None.

    Returns:
        A new list: ``[pad_token_left?] + [token] * repeat_count +
        [pad_token_right?]``.
    """
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement.insert(0, pad_token_left)
    if pad_token_right is not None:
        replacement.append(pad_token_right)

    return replacement


def repeat_and_pad_placeholder_tokens(
    tokenizer: AnyTokenizer,
    prompt: Optional[str],
    prompt_token_ids: list[int],
    *,
    placeholder_token_id: int,
    repeat_count: Union[int, list[int]],
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
    """Expand placeholder tokens in a prompt and in its token ids.

    Each of the first ``len(repeat_count)`` occurrences of the placeholder is
    replaced by that many copies of it, optionally wrapped in padding tokens.
    Occurrences beyond that are left untouched.

    Args:
        tokenizer: Used to decode the placeholder/pad ids into strings when
            ``prompt`` is given.
        prompt: Text prompt to rewrite, or None to only rewrite token ids.
        prompt_token_ids: Token ids of the prompt.
        placeholder_token_id: Id of the placeholder token to expand.
        repeat_count: Repeat count per placeholder; a single int is treated
            as a one-element list.
        pad_token_left: Optional token id inserted before each expanded run.
        pad_token_right: Optional token id inserted after each expanded run.

    Returns:
        ``(new_prompt, new_token_ids, placeholder_ranges)`` where each range
        gives the offset and length of the repeated placeholder tokens
        (padding excluded) inside ``new_token_ids``.
    """
    if isinstance(repeat_count, int):
        repeat_count = [repeat_count]

    if prompt is None:
        new_prompt = None
    else:
        placeholder_token_str = tokenizer.decode(placeholder_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))

        placeholder_token_count = prompt.count(placeholder_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if placeholder_token_count > 16:
            logger.warning(
                "Please follow the prompt format that is "
                "documented on HuggingFace which does not involve "
                "repeating %s tokens.", placeholder_token_str)
        if placeholder_token_count < len(repeat_count):
            logger.warning(
                "The number of multi-modal placeholder tokens in the prompt "
                "is less than the number of multi-modal inputs. Extra "
                "placeholder tokens will be treated as plain text")
            repeat_count = repeat_count[:placeholder_token_count]

        prompt_parts = prompt.split(placeholder_token_str,
                                    maxsplit=len(repeat_count))
        pieces = list[str]()
        for i, repeat_count_item in enumerate(repeat_count):
            replacement_str = "".join(
                repeat_and_pad_token(
                    placeholder_token_str,
                    repeat_count=repeat_count_item,
                    pad_token_left=pad_token_str_left,
                    pad_token_right=pad_token_str_right,
                ))
            # The image tokens are removed to be consistent with HuggingFace
            pieces.append(prompt_parts[i])
            pieces.append(replacement_str)
        pieces.append(prompt_parts[-1])
        new_prompt = "".join(pieces)

    new_token_ids = list[int]()
    placeholder_ranges = list[PlaceholderRange]()

    # With no repeat counts left there is nothing to expand; copy the ids
    # through unchanged instead of indexing into an empty list below.
    if not repeat_count:
        new_token_ids.extend(prompt_token_ids)
        return new_prompt, new_token_ids, placeholder_ranges

    placeholder_token_idx = 0
    for i, token in enumerate(prompt_token_ids):
        if token != placeholder_token_id:
            new_token_ids.append(token)
            continue

        curr_repeat_count = repeat_count[placeholder_token_idx]
        replacement_ids = repeat_and_pad_token(
            placeholder_token_id,
            repeat_count=curr_repeat_count,
            pad_token_left=pad_token_left,
            pad_token_right=pad_token_right,
        )

        # The range covers only the repeated placeholders, so skip over the
        # left padding token when one is inserted.
        offset = len(new_token_ids)
        if pad_token_left is not None:
            offset += 1
        placeholder_ranges.append({
            "offset": offset,
            "length": curr_repeat_count,
        })
        new_token_ids.extend(replacement_ids)
        placeholder_token_idx += 1

        # No need to further scan the list since we replaced all tokens
        if placeholder_token_idx >= len(repeat_count):
            new_token_ids.extend(prompt_token_ids[i + 1:])
            break

    return new_prompt, new_token_ids, placeholder_ranges


430
431
432
def consecutive_placeholder_ranges(
        num_items: int,
        item_size: int,
        initial_offset: int = 0) -> list[PlaceholderRange]:
    """Returns a list of consecutive PlaceholderRanges of a fixed size

    Args:
        num_items: Number of ranges to produce.
        item_size: Length of every range.
        initial_offset: Offset of the first range; each following range
            starts ``item_size`` after the previous one.
    """

    return [
        PlaceholderRange(offset=initial_offset + i * item_size,
                         length=item_size) for i in range(num_items)
    ]