image.py 4.42 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import base64
4
from functools import lru_cache
5
6
from io import BytesIO
from pathlib import Path
7
from typing import TYPE_CHECKING, Any, Dict, Optional
8
9
10
11

import torch
from PIL import Image

12
from vllm.inputs.registry import InputContext
13
from vllm.logger import init_logger
14
from vllm.transformers_utils.processor import get_image_processor
15
from vllm.utils import is_list_of
16

17
from .base import MediaIO, MultiModalPlugin
18
from .inputs import ImageItem, ModalityData, MultiModalKwargs
19

20
21
22
if TYPE_CHECKING:
    from vllm.config import ModelConfig

23
24
logger = init_logger(__name__)

# Memoized wrapper around ``get_image_processor``. The default input mapper
# asks for a processor every time it maps PIL images, so repeated lookups with
# the same arguments reuse one processor instance. All arguments (including
# keyword arguments) form the cache key and therefore must be hashable.
# ``maxsize=128`` is written out explicitly; it equals ``lru_cache``'s default.
cached_get_image_processor = lru_cache(maxsize=128)(get_image_processor)


class ImagePlugin(MultiModalPlugin):
    """Plugin for image data.

    Maps PIL images through the model's HuggingFace image processor and
    passes precomputed image tensors through as embeddings.
    """

    def get_data_key(self) -> str:
        """Return ``"image"``, the data key handled by this plugin."""
        return "image"

    def _get_hf_image_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Return the (memoized) HuggingFace image processor for a model.

        Args:
            model_config: Supplies the model identifier (``.model``) and the
                ``trust_remote_code`` flag.
            mm_processor_kwargs: Extra keyword arguments forwarded to the
                processor constructor. They become part of the
                ``lru_cache`` key, so their values must be hashable.

        Returns:
            Whatever ``get_image_processor`` produces; callers treat ``None``
            as "no processor available".
        """
        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **mm_processor_kwargs)

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[ImageItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Map image data to the keyword arguments passed to the model.

        Args:
            ctx: Input context; only its ``model_config`` is used.
            data: A PIL image, a ``torch.Tensor``, or a list of either (as
                judged by ``is_list_of``).
            **mm_processor_kwargs: Forwarded to the image processor's
                constructor only, not to ``preprocess``.

        Returns:
            For PIL input, the ``.data`` of the processor's
            ``preprocess(..., return_tensors="pt")`` output. For tensor
            input, ``{"image_embeds": data}``.

        Raises:
            RuntimeError: If no HuggingFace image processor is available.
            TypeError: If ``data`` is neither PIL image(s) nor tensor(s).
            Exception: Any error from ``preprocess`` is logged, then
                re-raised unchanged.
        """
        model_config = ctx.model_config

        # PIL image(s): run them through the HuggingFace image processor.
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(
                model_config,
                mm_processor_kwargs,
            )
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")

            try:
                # NOTE: It may make sense to forward the mm_processor_kwargs
                # here too. For now, to keep it simple, we only allow it be
                # used for the initialization call though, just in case the
                # signatures of the preprocessor initializer don't match
                # preprocess()
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                # Log context about the processor, then let the original
                # exception propagate to the caller.
                logger.error(
                    "Failed to process image (%s) with the default mapper. "
                    "This is most likely an edge-case with this model's image "
                    "processor in transformers (type: %s), and not vLLM.",
                    data,
                    type(image_processor).__name__)
                raise

            return MultiModalKwargs(batch_data)

        # Image embedding(s): already model-ready, forwarded unchanged.
        if isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
            return MultiModalKwargs({"image_embeds": data})

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Return a fixed default token budget of 3000 for image inputs.

        ``ctx`` is not consulted.
        NOTE(review): 3000 appears to be a generic placeholder; presumably
        models needing a different budget register their own value — confirm.
        """
        return 3000


def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Scale both sides of *image* by *size_factor*, optionally transposing.

    Args:
        image: Source image; the result is produced via ``resize``, which
            returns a new image.
        size_factor: Multiplier applied to width and height. Each scaled
            side is truncated toward zero with ``int``.
        transpose: Value of ``Image.Transpose`` to apply after resizing.
            Any negative value, including the default ``-1``, skips it.

    Returns:
        The resized, and possibly transposed, image.
    """
    target_size = tuple(
        int(side * size_factor) for side in (image.width, image.height))
    resized = image.resize(target_size)

    if transpose < 0:
        return resized
    return resized.transpose(Image.Transpose(transpose))


class ImageMediaIO(MediaIO[Image.Image]):
    """Load images from bytes, base64 or files, and encode them to base64.

    Every image produced or encoded by this class is first converted to
    ``image_mode`` (``"RGB"`` by default).
    """

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        # Target PIL mode that loaded and encoded images are converted to.
        self.image_mode = image_mode

    def _decode_and_convert(self, source: Any) -> Image.Image:
        """Open *source*, decode it fully and return a converted copy.

        ``source`` is anything ``Image.open`` accepts (a path or a binary
        file-like object). The opened image is closed before returning, so
        the underlying file handle is released deterministically rather than
        when the image is garbage-collected. ``convert`` returns an
        independent image, so the result remains usable after the close.
        """
        with Image.open(source) as image:
            # Decode eagerly so corrupt data fails here, inside the loader.
            image.load()
            return image.convert(self.image_mode)

    def load_bytes(self, data: bytes) -> Image.Image:
        """Decode an encoded image held in memory."""
        return self._decode_and_convert(BytesIO(data))

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        """Decode a base64-encoded image.

        ``media_type`` is accepted for interface compatibility but not used;
        the decoded bytes are passed straight to ``load_bytes``.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> Image.Image:
        """Load an image from disk, closing the file before returning."""
        return self._decode_and_convert(filepath)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Encode *media* as base64 text after converting it to image_mode.

        Args:
            media: Image to encode; it is not modified (``convert`` returns
                a new image).
            image_format: Pillow format name passed to ``Image.save``.
                NOTE(review): JPEG cannot store modes such as RGBA, so a
                non-RGB ``image_mode`` may require a different format.

        Returns:
            The encoded image bytes as an ASCII base64 string.
        """
        converted = media.convert(self.image_mode)

        with BytesIO() as buffer:
            converted.save(buffer, image_format)
            data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')