utils.py 8.68 KB
Newer Older
1
import base64
2
from functools import lru_cache
3
from io import BytesIO
4
from typing import List, Optional, Tuple, TypeVar, Union
5

6
7
8
import librosa
import numpy as np
import soundfile
9
10
from PIL import Image

11
from vllm.connections import global_http_connection
12
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
13
from vllm.logger import init_logger
14
from vllm.multimodal.base import MultiModalDataDict
15
16
17
18
19
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

# Module-level logger, keyed by this module's import name.
logger = init_logger(__name__)

# Memoized variant of `get_tokenizer`. `lru_cache` is applied directly to the
# callable (no explicit maxsize), so functools' default cache size applies and
# every argument passed to it must be hashable.
cached_get_tokenizer = lru_cache(get_tokenizer)
20
21
22
23
24
25
26
27
28
29
30
31
32
33


def _load_image_from_bytes(b: bytes):
    """Decode encoded image bytes into a PIL image with its data loaded."""
    stream = BytesIO(b)
    decoded = Image.open(stream)
    # Call load() explicitly so the pixel data is read before returning,
    # rather than leaving it to whichever caller touches the image first.
    decoded.load()
    return decoded


def _load_image_from_data_url(image_url: str):
    """Load a PIL image from the base64 payload of a data URL."""
    # Split on the first comma only: everything after it is assumed to be the
    # base64-encoded image. Unpacking raises ValueError if there is no comma.
    pieces = image_url.split(",", 1)
    _header, payload = pieces
    return load_image_from_base64(payload)


34
35
36
37
38
39
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
    """
    Load a PIL image from a HTTP or base64 data URL.

    URLs starting with 'http' (which also covers 'https') are downloaded;
    URLs starting with 'data:image' are decoded in place. Any other prefix
    raises ValueError.

    By default, the image is converted into RGB format.
    """
    if image_url.startswith('http'):
        downloaded = global_http_connection.get_bytes(
            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
        return _load_image_from_bytes(downloaded).convert(image_mode)

    if image_url.startswith('data:image'):
        return _load_image_from_data_url(image_url).convert(image_mode)

    raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                     "with either 'data:image' or 'http'.")
52
53


54
55
56
57
58
async def async_fetch_image(image_url: str,
                            *,
                            image_mode: str = "RGB") -> Image.Image:
    """
    Asynchronously load a PIL image from a HTTP or base64 data URL.

    URLs starting with 'http' (which also covers 'https') are downloaded
    with an awaited request; URLs starting with 'data:image' are decoded in
    place. Any other prefix raises ValueError.

    By default, the image is converted into RGB format.
    """
    if image_url.startswith('http'):
        downloaded = await global_http_connection.async_get_bytes(
            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
        return _load_image_from_bytes(downloaded).convert(image_mode)

    if image_url.startswith('data:image'):
        return _load_image_from_data_url(image_url).convert(image_mode)

    raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                     "with either 'data:image' or 'http'.")
74
75


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
    """
    Load audio from a HTTP or base64 data URL.

    URLs starting with 'http' are downloaded; URLs starting with
    'data:audio' have their base64 payload decoded. Any other prefix raises
    ValueError. The bytes are handed to ``librosa.load`` with ``sr=None`` and
    its result is returned unchanged.
    """
    is_remote = audio_url.startswith("http")
    if not is_remote and not audio_url.startswith("data:audio"):
        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
                         "with either 'data:audio' or 'http'.")

    if is_remote:
        raw_audio = global_http_connection.get_bytes(
            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
    else:
        # Everything after the first comma is the base64 payload.
        _header, encoded = audio_url.split(",", 1)
        raw_audio = base64.b64decode(encoded)

    return librosa.load(BytesIO(raw_audio), sr=None)


async def async_fetch_audio(
        audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
    """
    Asynchronously load audio from a HTTP or base64 data URL.

    URLs starting with 'http' are downloaded with an awaited request; URLs
    starting with 'data:audio' have their base64 payload decoded. Any other
    prefix raises ValueError. The bytes are handed to ``librosa.load`` with
    ``sr=None`` and its result is returned unchanged.
    """
    is_remote = audio_url.startswith("http")
    if not is_remote and not audio_url.startswith("data:audio"):
        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
                         "with either 'data:audio' or 'http'.")

    if is_remote:
        raw_audio = await global_http_connection.async_get_bytes(
            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
    else:
        # Everything after the first comma is the base64 payload.
        _header, encoded = audio_url.split(",", 1)
        raw_audio = base64.b64decode(encoded)

    return librosa.load(BytesIO(raw_audio), sr=None)


async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    """Fetch audio from a URL and wrap it under the "audio" key."""
    waveform, sample_rate = await async_fetch_audio(audio_url)
    parsed: MultiModalDataDict = {"audio": (waveform, sample_rate)}
    return parsed


116
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
    """Fetch an image from a URL and wrap it under the "image" key."""
    fetched = await async_fetch_image(image_url)
    parsed: MultiModalDataDict = {"image": fetched}
    return parsed


121
122
123
124
125
126
127
128
129
130
131
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Serialize audio to an in-memory WAV file and return it as base64."""
    with BytesIO() as wav_buffer:
        soundfile.write(wav_buffer, audio, sampling_rate, format="WAV")
        wav_bytes = wav_buffer.getvalue()

    encoded = base64.b64encode(wav_bytes)
    return encoded.decode('utf-8')


132
133
134
135
136
137
138
139
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    The image is first converted to `image_mode` (RGB by default), then
    saved in `format` (JPEG by default) to an in-memory buffer whose bytes
    are base64-encoded.
    """
    converted = image.convert(image_mode)
    with BytesIO() as out_buffer:
        converted.save(out_buffer, format)
        image_bytes = out_buffer.getvalue()

    encoded = base64.b64encode(image_bytes)
    return encoded.decode('utf-8')


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Decode a base64 payload and load the resulting bytes as an image."""
    raw_bytes = base64.b64decode(image)
    return _load_image_from_bytes(raw_bytes)
152
153


154
155
156
def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """
    Rescale the dimensions of an image by a constant factor.

    Width and height are each multiplied by `size_factor` and truncated to
    int. A non-negative `transpose` is interpreted as an `Image.Transpose`
    value and applied after resizing; a negative value (the default) skips
    the transpose step.
    """
    target_size = (int(image.width * size_factor),
                   int(image.height * size_factor))
    resized = image.resize(target_size)

    if transpose < 0:
        return resized
    return resized.transpose(Image.Transpose(transpose))
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191


# Utilities for input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    """
    Build `[pad_left?] + [token] * repeat_count + [pad_right?]`.

    A pad token is included only when it is not None, so falsy values such
    as 0 or "" are still emitted.
    """
    left_pad = [] if pad_token_left is None else [pad_token_left]
    right_pad = [] if pad_token_right is None else [pad_token_right]
    repeated = [token] * repeat_count

    return [*left_pad, *repeated, *right_pad]


def repeat_and_pad_placeholder_tokens(
    tokenizer: AnyTokenizer,
    prompt: Optional[str],
    prompt_token_ids: List[int],
    *,
    placeholder_token_id: int,
    repeat_count: Union[int, List[int]],
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
    """
    Expand placeholder tokens in a prompt and in its token IDs.

    The i-th occurrence of the placeholder is replaced by `repeat_count[i]`
    copies of itself, optionally wrapped in `pad_token_left` /
    `pad_token_right`. Occurrences beyond `len(repeat_count)` are kept
    unchanged.

    Args:
        tokenizer: Used only to decode the placeholder and pad token IDs
            into text when `prompt` is given.
        prompt: Text prompt, or None to skip the text-side expansion.
        prompt_token_ids: Token IDs of the prompt.
        placeholder_token_id: ID of the token to expand.
        repeat_count: Number of repetitions per placeholder occurrence. A
            single int is treated as a one-element list.
        pad_token_left: Optional token ID inserted before each expansion.
        pad_token_right: Optional token ID inserted after each expansion.

    Returns:
        `(new_prompt, new_token_ids)`; `new_prompt` is None when `prompt`
        is None.
    """
    if isinstance(repeat_count, int):
        repeat_count = [repeat_count]

    if prompt is None:
        new_prompt = None
    else:
        placeholder_token_str = tokenizer.decode(placeholder_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))

        placeholder_token_count = prompt.count(placeholder_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if placeholder_token_count > 16:
            logger.warning(
                "Please follow the prompt format that is "
                "documented on HuggingFace which does not involve "
                "repeating %s tokens.", placeholder_token_str)
        if placeholder_token_count < len(repeat_count):
            logger.warning(
                "The number of multi-modal placeholder tokens in the prompt "
                "is less than the number of multi-modal inputs. Extra "
                "placeholder tokens will be treated as plain text")
            # Slicing creates a new list, so the caller's list is untouched.
            # The truncated counts are also used for the token IDs below.
            repeat_count = repeat_count[:placeholder_token_count]

        # At this point the prompt holds at least len(repeat_count)
        # placeholders, so the split yields exactly len(repeat_count) + 1
        # parts; the last part carries any unexpanded placeholders verbatim.
        prompt_parts = prompt.split(placeholder_token_str,
                                    maxsplit=len(repeat_count))
        pieces: List[str] = []
        for prompt_part, repeat_count_item in zip(prompt_parts, repeat_count):
            replacement_str = "".join(
                repeat_and_pad_token(
                    placeholder_token_str,
                    repeat_count=repeat_count_item,
                    pad_token_left=pad_token_str_left,
                    pad_token_right=pad_token_str_right,
                ))
            # The split consumed the placeholder text; re-insert it in its
            # expanded (repeated and padded) form.
            pieces.append(prompt_part)
            pieces.append(replacement_str)
        pieces.append(prompt_parts[-1])
        new_prompt = "".join(pieces)

    new_token_ids: List[int] = []
    num_replacements = len(repeat_count)
    placeholder_token_idx = 0
    for i, token in enumerate(prompt_token_ids):
        # Check the bound before indexing `repeat_count`: once every count is
        # consumed (or when there were none to begin with) the remaining
        # tokens are copied as-is. Without this check an empty `repeat_count`
        # raised IndexError on the first placeholder token.
        if placeholder_token_idx >= num_replacements:
            new_token_ids.extend(prompt_token_ids[i:])
            break

        if token == placeholder_token_id:
            replacement_ids = repeat_and_pad_token(
                placeholder_token_id,
                repeat_count=repeat_count[placeholder_token_idx],
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            new_token_ids.extend(replacement_ids)
            placeholder_token_idx += 1
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids