utils.py 4.32 KB
Newer Older
1
2
3
import base64
from io import BytesIO
from typing import Optional, Union
4
from urllib.parse import urlparse
5
6

import aiohttp
7
import requests
8
9
10
from PIL import Image

from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
11
from vllm.multimodal.base import MultiModalDataDict
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from vllm.version import __version__ as VLLM_VERSION


def _validate_remote_url(url: str, *, name: str):
    parsed_url = urlparse(url)
    if parsed_url.scheme not in ["http", "https"]:
        raise ValueError(f"Invalid '{name}': A valid '{name}' "
                         "must have scheme 'http' or 'https'.")


def _get_request_headers():
    """Return the HTTP headers sent with outgoing image requests.

    The only header set is a ``User-Agent`` built from the vLLM version.
    """
    user_agent = f"vLLM/{VLLM_VERSION}"
    return {"User-Agent": user_agent}


def _load_image_from_bytes(b: bytes):
    """Open the raw bytes ``b`` as a PIL image and return it loaded."""
    stream = BytesIO(b)
    decoded = Image.open(stream)
    # Call load() right away so the returned image has already been read
    # from the in-memory stream.
    decoded.load()
    return decoded


def _load_image_from_data_url(image_url: str):
    # Only split once and assume the second part is the base64 encoded image
    _, image_base64 = image_url.split(",", 1)
    return load_image_from_base64(image_base64)


38
39
40
41
42
43
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
    """
    Load a PIL image from a HTTP or base64 data URL.

    By default, the image is converted into RGB format.

    Args:
        image_url: Either a remote URL starting with ``http`` or a data URL
            starting with ``data:image``.
        image_mode: Mode passed to ``Image.convert`` on the loaded image.

    Raises:
        ValueError: If ``image_url`` has neither supported prefix, or if a
            remote URL does not use the ``http``/``https`` scheme.
    """
    if image_url.startswith('http'):
        # The prefix check alone also admits schemes such as "httpx", so the
        # parsed scheme is validated before any request is made.
        _validate_remote_url(image_url, name="image_url")

        headers = _get_request_headers()

        # Bound the request with the same limit the async client uses;
        # without a timeout, requests may wait on a stalled server forever.
        with requests.get(url=image_url,
                          headers=headers,
                          timeout=VLLM_IMAGE_FETCH_TIMEOUT) as response:
            response.raise_for_status()
            image_raw = response.content
        image = _load_image_from_bytes(image_raw)

    elif image_url.startswith('data:image'):
        image = _load_image_from_data_url(image_url)
    else:
        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                         "with either 'data:image' or 'http'.")

    return image.convert(image_mode)
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76


class ImageFetchAiohttp:
    """Asynchronous image fetching backed by a lazily created aiohttp session."""

    # Session stored on the class; created on the first call to
    # get_aiohttp_client() and reused afterwards.
    aiohttp_client: Optional[aiohttp.ClientSession] = None

    @classmethod
    def get_aiohttp_client(cls) -> aiohttp.ClientSession:
        """Return the stored client session, creating it on first use."""
        session = cls.aiohttp_client
        if session is not None:
            return session

        session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT),
            connector=aiohttp.TCPConnector(),
        )
        cls.aiohttp_client = session
        return session

    @classmethod
    async def fetch_image(
        cls,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP or base64 data URL.

        The loaded image is converted to ``image_mode`` (RGB by default)
        before it is returned.

        Raises:
            ValueError: If ``image_url`` starts with neither ``http`` nor
                ``data:image``, or if a remote URL does not use the
                ``http``/``https`` scheme.
        """
        is_remote = image_url.startswith('http')
        if not is_remote and not image_url.startswith('data:image'):
            raise ValueError(
                "Invalid 'image_url': A valid 'image_url' must start "
                "with either 'data:image' or 'http'.")

        if is_remote:
            _validate_remote_url(image_url, name="image_url")

            session = cls.get_aiohttp_client()
            request_headers = _get_request_headers()

            async with session.get(url=image_url,
                                   headers=request_headers) as response:
                response.raise_for_status()
                payload = await response.read()
            loaded = _load_image_from_bytes(payload)
        else:
            loaded = _load_image_from_data_url(image_url)

        return loaded.convert(image_mode)
108
109


110
111
112
113
114
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
    """Fetch ``image_url`` asynchronously and wrap it under the "image" key."""
    return {"image": await ImageFetchAiohttp.fetch_image(image_url)}


115
116
117
118
119
120
121
122
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Serialize a pillow image and return it as a base64 string.

    The image is first converted to ``image_mode`` (RGB by default), then
    saved in ``format`` (JPEG by default) to an in-memory buffer whose
    contents are base64 encoded.
    """
    converted = image.convert(image_mode)

    buffer = BytesIO()
    converted.save(buffer, format)

    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Decode a base64 payload and load the result as a PIL image."""
    raw_bytes = base64.b64decode(image)
    return _load_image_from_bytes(raw_bytes)
135
136


137
138
139
140
141
def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
    """Return ``image`` resized so both dimensions are scaled by ``size_factor``.

    Each scaled dimension is truncated to an integer with ``int()``.
    """
    scaled_size = tuple(
        int(side * size_factor) for side in (image.width, image.height))
    return image.resize(scaled_size)