image.py 4.93 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
import base64
from io import BytesIO
from pathlib import Path
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10

import torch
from PIL import Image

11
from vllm.inputs.registry import InputContext
12
from vllm.logger import init_logger
13
from vllm.transformers_utils.processor import cached_get_image_processor
14
from vllm.utils import is_list_of
15

16
from .base import MediaIO, MultiModalPlugin
17
from .inputs import ImageItem, ModalityData, MultiModalKwargs
18

19
20
21
if TYPE_CHECKING:
    from vllm.config import ModelConfig

22
23
24
logger = init_logger(__name__)


25
class ImagePlugin(MultiModalPlugin):
    """Multimodal plugin for the ``"image"`` modality.

    Inputs may be PIL images (single or list), which are run through the
    model's HuggingFace image processor, or ``torch.Tensor`` embeddings
    (single or list), which are passed through unchanged.
    """

    def get_data_key(self) -> str:
        return "image"

    def _get_hf_image_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Return the HuggingFace image processor for ``model_config``.

        The model name and its ``trust_remote_code`` flag are forwarded to
        ``cached_get_image_processor``, together with any extra
        ``mm_processor_kwargs`` (treated as empty when ``None``).
        """
        processor_kwargs = (mm_processor_kwargs
                            if mm_processor_kwargs is not None else {})
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **processor_kwargs,
        )

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[ImageItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Map raw image input to model keyword arguments.

        PIL image(s) are preprocessed by the HuggingFace image processor and
        the resulting ``.data`` mapping is wrapped in ``MultiModalKwargs``.
        Tensor(s) are treated as precomputed embeddings and returned under
        the ``"image_embeds"`` key.

        Raises:
            RuntimeError: If no HuggingFace image processor is available.
            TypeError: If ``data`` is neither PIL image(s) nor tensor(s).
        """
        model_config = ctx.model_config

        is_pil_input = (isinstance(data, Image.Image)
                        or is_list_of(data, Image.Image))
        if not is_pil_input:
            is_embedding_input = (isinstance(data, torch.Tensor)
                                  or is_list_of(data, torch.Tensor))
            if not is_embedding_input:
                raise TypeError(f"Invalid image type: {type(data)}")
            # Precomputed embeddings bypass the HF processor entirely.
            return MultiModalKwargs({"image_embeds": data})

        processor = self._get_hf_image_processor(model_config,
                                                 mm_processor_kwargs)
        if processor is None:
            raise RuntimeError("No HuggingFace processor is available "
                               "to process the image object")

        try:
            # NOTE: mm_processor_kwargs are only used to build the processor,
            # not forwarded to preprocess(), in case the two signatures
            # differ.
            processed = processor.preprocess(data, return_tensors="pt").data
        except Exception:
            logger.error(
                "Failed to process image (%s) with the default mapper. "
                "This is most likely an edge-case with this model's image "
                "processor in transformers (type: %s), and not vLLM.",
                data,
                type(processor).__name__)
            raise

        return MultiModalKwargs(processed)

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Return a fixed default of 3000; ``ctx`` is not consulted."""
        default_budget = 3000
        return default_budget
89
90
91
92
93
94
95
96
97
98
99
100


def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor.

    Args:
        image: Source image; a resized copy is returned, the input is not
            modified.
        size_factor: Positive multiplier applied to both width and height.
            Each scaled dimension is truncated to an integer and clamped to
            at least 1 pixel, so a very small factor yields a 1px-wide/high
            image instead of an invalid zero-sized target.
        transpose: If non-negative, converted to ``Image.Transpose`` and
            applied after resizing. ``-1`` (the default) skips the transpose.

    Returns:
        The resized (and optionally transposed) image.

    Raises:
        ValueError: If ``size_factor`` is not positive (including NaN), or
            ``transpose`` is not a valid ``Image.Transpose`` value.
    """
    # ``not x > 0`` also rejects NaN. Zero or negative factors already
    # failed inside Pillow with a ValueError; this reports it clearly.
    if not size_factor > 0:
        raise ValueError(
            f"size_factor must be positive, got {size_factor!r}")

    # Truncate like int() always did, but never request a zero-sized
    # image: Pillow's resize rejects dimensions smaller than 1 pixel.
    new_width = max(1, int(image.width * size_factor))
    new_height = max(1, int(image.height * size_factor))
    image = image.resize((new_width, new_height))
    if transpose >= 0:
        image = image.transpose(Image.Transpose(transpose))
    return image
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136


class ImageMediaIO(MediaIO[Image.Image]):
    """Load and encode PIL images for the multimodal media pipeline.

    Loaded images are fully decoded and converted to ``image_mode``
    (``"RGB"`` by default). ``encode_base64`` converts to the same mode
    before saving.
    """

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        # PIL mode that every loaded image is converted to, and that images
        # are converted to before encoding.
        self.image_mode = image_mode

    def load_bytes(self, data: bytes) -> Image.Image:
        """Decode an encoded image held in memory and convert its mode."""
        # The context manager closes the source image. ``convert`` returns a
        # new, fully loaded image that does not depend on it.
        with Image.open(BytesIO(data)) as image:
            image.load()
            return image.convert(self.image_mode)

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        """Decode a base64 string and load it as an image.

        ``media_type`` is not used; the format is detected from the bytes.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> Image.Image:
        """Load an image file from disk and convert its mode."""
        # Pillow only auto-closes the file after load() for single-frame
        # images. Multi-frame formats (GIF, TIFF, animated WebP) keep the
        # file handle open, so close it explicitly.
        with Image.open(filepath) as image:
            image.load()
            return image.convert(self.image_mode)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Convert ``media`` to ``image_mode``, save it as ``image_format``,
        and return the encoded bytes as a base64 string."""
        with BytesIO() as buffer:
            media.convert(self.image_mode).save(buffer, image_format)
            data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load and encode precomputed image embeddings as serialized tensors.

    The payload format is the output of ``torch.save``. Payloads are always
    read back with ``torch.load(..., weights_only=True)``, which restricts
    unpickling to tensors and primitive containers rather than arbitrary
    Python objects.
    """

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a ``torch.save`` payload held in memory."""
        buffer = BytesIO(data)
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Decode a base64 ``torch.save`` payload.

        ``media_type`` is not used; the payload format is fixed.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Deserialize a ``torch.save`` file from disk."""
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize ``media`` with ``torch.save`` and base64-encode it.

        The output is the inverse of ``load_base64``: it can be decoded by
        this class, unlike a raw ``numpy()`` buffer.
        """
        # detach()/cpu(): the payload must be loadable without grad state or
        # a GPU. clone(): torch.save writes a view's entire base storage, so
        # copy to a compact tensor first.
        tensor = media.detach().cpu().clone()
        with BytesIO() as buffer:
            torch.save(tensor, buffer)
            data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')