image.py 3.74 KB
Newer Older
1
2
from functools import lru_cache
from typing import Dict, Type, Union
3
4
5
6

import torch
from PIL import Image

7
8
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
9
from vllm.logger import init_logger
10
from vllm.transformers_utils.image_processor import get_image_processor
11
12
13
14
15

from .base import MultiModalData, MultiModalPlugin

logger = init_logger(__name__)

# Memoized wrapper around ``get_image_processor``: repeated calls with the
# same (hashable) arguments reuse the previously created processor instead of
# constructing a new one. Uses the standard LRU size of 128 entries.
cached_get_image_processor = lru_cache(maxsize=128)(get_image_processor)


class ImagePixelData(MultiModalData):
    """
    Pixel data for a single image, in one of two forms:

    - :class:`PIL.Image.Image`: a decoded image object. Mapping it to model
      inputs requires a HuggingFace image processor for the model.
    - :class:`torch.Tensor`: raw pixel values handed to the model as-is,
      without additional pre-processing.
    """

    def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
        # Load PIL pixel data eagerly so that this object can be created
        # inside the ``Image`` context manager and still be used afterwards.
        if isinstance(image, Image.Image):
            image.load()

        self.image = image

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        img = self.image

        # Tensors can be large, so only their shape and dtype are shown.
        if not isinstance(img, Image.Image):
            return (f"{cls_name}(image=torch.Tensor(shape="
                    f"{img.shape}, dtype={img.dtype}))")

        return f"{cls_name}(image={img})"

class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
    """Plugin that maps :class:`ImagePixelData` to model input tensors."""

    def get_data_type(self) -> Type[ImagePixelData]:
        return ImagePixelData

    def _get_hf_image_processor(self, model_config: ModelConfig):
        """
        Return the HuggingFace image processor configured for the model.

        Returns ``None`` when the model has no multimodal config, or when
        that config does not name an image processor.
        """
        vlm_config = model_config.multimodal_config
        if vlm_config is None or vlm_config.image_processor is None:
            return None

        # Goes through the module-level LRU-cached loader, so repeated calls
        # with the same arguments reuse one processor instance.
        return cached_get_image_processor(
            vlm_config.image_processor,
            trust_remote_code=model_config.trust_remote_code,
            revision=vlm_config.image_processor_revision,
        )

    def _default_input_mapper(self, ctx: InputContext,
                              data: ImagePixelData) -> Dict[str, torch.Tensor]:
        """
        Convert ``data.image`` into keyword arguments for the model.

        - A PIL image is run through the model's HuggingFace image processor
          (``preprocess(..., return_tensors="pt")``). The result is passed
          through ``.to(model_config.dtype)`` and its underlying dict is
          returned.
        - A tensor is converted with ``.to(model_config.dtype)`` and returned
          under the ``"pixel_values"`` key.

        Raises:
            RuntimeError: A PIL image was given but no HuggingFace image
                processor is configured for the model.
            TypeError: ``data.image`` is neither a PIL image nor a tensor.
        """
        model_config = ctx.model_config
        image = data.image

        if isinstance(image, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                # The two literals are concatenated, so the trailing space is
                # required to avoid "availableto" in the message.
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                return image_processor.preprocess(image, return_tensors="pt") \
                    .to(model_config.dtype).data
            except Exception:
                # Record which image failed, then let the original error
                # (with its traceback) propagate to the caller.
                logger.error("Failed to process image (%s)", image)
                raise
        elif isinstance(image, torch.Tensor):
            pixel_values = image.to(model_config.dtype)
            return {"pixel_values": pixel_values}

        raise TypeError(f"Invalid image type: {type(image)}")


class ImageFeatureData(MultiModalData):
    """
    Feature vector of an image, passed directly to the model.

    The tensor is expected to be the output of the vision tower.
    """

    def __init__(self, image_features: torch.Tensor) -> None:
        self.image_features = image_features

    def __repr__(self) -> str:
        feats = self.image_features
        shape, dtype = feats.shape, feats.dtype

        # Show only metadata; the feature tensor itself may be large.
        return (f"{type(self).__name__}(image_features=torch.Tensor("
                f"shape={shape}, dtype={dtype}))")

class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
    """Plugin that maps :class:`ImageFeatureData` to model input tensors."""

    def get_data_type(self) -> Type[ImageFeatureData]:
        return ImageFeatureData

    def _default_input_mapper(
            self, ctx: InputContext,
            data: ImageFeatureData) -> Dict[str, torch.Tensor]:
        """Convert the features to the model dtype and return them under
        the ``"image_features"`` key."""
        target_dtype = ctx.model_config.dtype
        return {"image_features": data.image_features.to(target_dtype)}