image.py 4.38 KB
Newer Older
1
import base64
2
from functools import lru_cache
3
4
from io import BytesIO
from pathlib import Path
5
from typing import TYPE_CHECKING, Any, Dict, Optional
6
7
8
9

import torch
from PIL import Image

10
from vllm.inputs.registry import InputContext
11
from vllm.logger import init_logger
12
from vllm.transformers_utils.processor import get_image_processor
13
from vllm.utils import is_list_of
14

15
from .base import MediaIO, MultiModalPlugin
16
from .inputs import ImageItem, ModalityData, MultiModalKwargs
17

18
19
20
if TYPE_CHECKING:
    from vllm.config import ModelConfig

21
22
logger = init_logger(__name__)

# Memoized wrapper around get_image_processor: repeated lookups with the same
# (model, trust_remote_code, **kwargs) reuse the previously built processor.
# Same bound as a bare lru_cache(fn) (128 entries). Every argument must be
# hashable, because lru_cache hashes the full call signature.
cached_get_image_processor = lru_cache(maxsize=128)(get_image_processor)
24
25


26
class ImagePlugin(MultiModalPlugin):
    """Plugin for image data."""

    def get_data_key(self) -> str:
        """Return the key under which image inputs are registered."""
        return "image"

    def _get_hf_image_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Return the HuggingFace image processor for ``model_config``.

        Args:
            model_config: Model configuration. Its ``model`` and
                ``trust_remote_code`` fields are forwarded to
                ``get_image_processor``.
            mm_processor_kwargs: Extra keyword arguments forwarded to the
                processor factory; ``None`` means no extra arguments.

        Returns:
            Whatever ``get_image_processor`` returns. This may be ``None``,
            which callers are expected to check.
        """
        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        # The cached factory is an lru_cache, which hashes every argument.
        # An unhashable kwarg value (e.g. a list or dict) would raise
        # "TypeError: unhashable type" before the processor is even built,
        # so bypass the cache in that case instead of failing.
        try:
            hash(tuple(mm_processor_kwargs.items()))
        except TypeError:
            return get_image_processor(
                model_config.model,
                trust_remote_code=model_config.trust_remote_code,
                **mm_processor_kwargs)

        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **mm_processor_kwargs)

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[ImageItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Map raw image input to model keyword arguments.

        A PIL image, or a list of PIL images, is run through the model's
        HuggingFace image processor. A tensor, or a list of tensors, is
        passed through unchanged under the ``"image_embeds"`` key.

        Args:
            ctx: Input context providing ``model_config``.
            data: The image item(s) to map.
            **mm_processor_kwargs: Forwarded to the image processor's
                initializer (not to ``preprocess``).

        Raises:
            RuntimeError: If no HuggingFace image processor is available.
            TypeError: If ``data`` is neither PIL images nor tensors.
        """
        model_config = ctx.model_config

        # PIL image(s): preprocess with the HF image processor.
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(
                model_config,
                mm_processor_kwargs,
            )

            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                # NOTE: It may make sense to forward the mm_processor_kwargs
                # here too. For now, to keep it simple, we only allow it be
                # used for the initialization call though, just in case the
                # signatures of the preprocessor initializer don't match
                # preprocess()
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                logger.error(
                    "Failed to process image (%s) with the default mapper. "
                    "This is most likely an edge-case with this model's image "
                    "processor in transformers (type: %s), and not vLLM.",
                    data,
                    type(image_processor).__name__)

                raise

            return MultiModalKwargs(batch_data)

        # Precomputed image embedding(s): pass through as-is.
        if isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
            return MultiModalKwargs({"image_embeds": data})

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Return the hard-coded per-image token budget (3000).

        NOTE(review): presumably the fallback used when a model does not
        register its own limit. Confirm against ``MultiModalPlugin``.
        """
        return 3000
90
91
92
93
94
95
96
97
98
99
100
101


def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Resize ``image`` by a constant factor and optionally transpose it.

    Args:
        image: Source image.
        size_factor: Multiplier applied to both width and height. The
            scaled dimensions are truncated toward zero with ``int()``.
        transpose: An ``Image.Transpose`` value applied after resizing.
            Any negative value (the default ``-1``) skips transposition.

    Returns:
        The resized, and possibly transposed, image.
    """
    target_size = (int(image.width * size_factor),
                   int(image.height * size_factor))
    resized = image.resize(target_size)
    if transpose < 0:
        return resized
    return resized.transpose(Image.Transpose(transpose))
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137


class ImageMediaIO(MediaIO[Image.Image]):
    """Load images from bytes, base64 or files, and encode them to base64.

    Every loaded image is fully decoded and converted to ``image_mode``.
    """

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        # PIL mode that every loaded or encoded image is converted to.
        self.image_mode = image_mode

    def _open_and_convert(self, source: Any) -> Image.Image:
        """Decode ``source`` with PIL and return it converted to ``image_mode``.

        ``source`` is anything ``Image.open`` accepts (here a ``BytesIO`` or a
        ``Path``). The opened image is closed before returning. Pillow may
        otherwise keep the file handle open after ``load()``, e.g. for
        multi-frame formats. ``convert`` returns a new image, so the result
        stays valid after the source is closed.
        """
        with Image.open(source) as image:
            # Force a full decode while the underlying file is still open.
            image.load()
            return image.convert(self.image_mode)

    def load_bytes(self, data: bytes) -> Image.Image:
        """Decode an image from raw encoded bytes."""
        return self._open_and_convert(BytesIO(data))

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        """Decode an image from a base64 string. ``media_type`` is unused."""
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> Image.Image:
        """Decode an image from a file on disk, closing the file afterwards."""
        return self._open_and_convert(filepath)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Encode ``media`` as a base64 string in ``image_format``.

        The image is first converted to ``image_mode``. The result is the
        UTF-8 text of the base64-encoded file bytes.
        """
        image = media.convert(self.image_mode)

        with BytesIO() as buffer:
            image.save(buffer, image_format)
            data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')