utils.py 3.18 KB
Newer Older
1
2
import base64
from io import BytesIO
3
from typing import Union
4
5
6

from PIL import Image

7
from vllm.connections import global_http_connection
8
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
9
from vllm.multimodal.base import MultiModalDataDict
10
11
12
13
14
15
16
17
18
19
20
21
22
23


def _load_image_from_bytes(b: bytes):
    """Decode encoded image bytes into a fully loaded PIL image."""
    stream = BytesIO(b)
    decoded = Image.open(stream)
    # Image.open only identifies the file; load() forces the pixel data to
    # be read from the in-memory stream right away.
    decoded.load()
    return decoded


def _load_image_from_data_url(image_url: str):
    # Only split once and assume the second part is the base64 encoded image
    _, image_base64 = image_url.split(",", 1)
    return load_image_from_base64(image_base64)


24
25
26
27
28
29
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
    """
    Fetch a PIL image from a URL starting with 'http' or from a base64
    data URL starting with 'data:image'.

    The loaded image is converted to ``image_mode`` (RGB unless overridden)
    before being returned. Any other kind of URL raises ``ValueError``.
    """
    is_http = image_url.startswith('http')
    if not is_http and not image_url.startswith('data:image'):
        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                         "with either 'data:image' or 'http'.")

    if is_http:
        payload = global_http_connection.get_bytes(
            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
        loaded = _load_image_from_bytes(payload)
    else:
        loaded = _load_image_from_data_url(image_url)

    return loaded.convert(image_mode)
42
43


44
45
46
47
48
async def async_fetch_image(image_url: str,
                            *,
                            image_mode: str = "RGB") -> Image.Image:
    """
    Asynchronously fetch a PIL image from a URL starting with 'http' or
    from a base64 data URL starting with 'data:image'.

    The loaded image is converted to ``image_mode`` (RGB unless overridden)
    before being returned. Any other kind of URL raises ``ValueError``.
    """
    is_http = image_url.startswith('http')
    if not is_http and not image_url.startswith('data:image'):
        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                         "with either 'data:image' or 'http'.")

    if is_http:
        # Only the download is awaited; decoding happens inline afterwards.
        payload = await global_http_connection.async_get_bytes(
            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
        loaded = _load_image_from_bytes(payload)
    else:
        loaded = _load_image_from_data_url(image_url)

    return loaded.convert(image_mode)
64
65


66
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
    """Fetch the image at ``image_url`` and return it under the "image" key."""
    return {"image": await async_fetch_image(image_url)}


71
72
73
74
75
76
77
78
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Serialize a pillow image and return the result as a base64 string.

    The image is converted to ``image_mode`` (RGB by default) and saved in
    the given ``format`` (JPEG by default) before being base64 encoded.
    """
    converted = image.convert(image_mode)
    with BytesIO() as sink:
        converted.save(sink, format)
        encoded = base64.b64encode(sink.getvalue())
    return encoded.decode('utf-8')


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Decode a base64 payload and load the resulting bytes as an image."""
    raw_bytes = base64.b64decode(image)
    return _load_image_from_bytes(raw_bytes)
91
92


93
94
95
def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """
    Rescale the dimensions of an image by a constant factor.

    Args:
        image: The image to rescale.
        size_factor: Multiplier applied to both the width and the height.
            Must be positive.
        transpose: When non-negative, the value is passed to
            ``Image.Transpose`` and that transposition is applied after
            resizing. The default of -1 skips the transpose step.

    Returns:
        The resized (and optionally transposed) image.

    Raises:
        ValueError: If ``size_factor`` is not positive.
    """
    if size_factor <= 0:
        raise ValueError(
            f"'size_factor' must be positive, got {size_factor}.")

    # int() truncates, so a small factor applied to a small image could
    # yield a zero-pixel side, which resize() rejects; keep at least 1 px.
    new_width = max(1, int(image.width * size_factor))
    new_height = max(1, int(image.height * size_factor))

    image = image.resize((new_width, new_height))
    if transpose >= 0:
        image = image.transpose(Image.Transpose(transpose))
    return image