from __future__ import annotations

from typing import Any, List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.transforms.functional import InterpolationMode

from ._datapoint import _FillTypeJIT, Datapoint

class Image(Datapoint):
    """:class:`torch.Tensor` subclass representing an image.

    Args:
        data: Any data accepted by ``cls._to_tensor`` (e.g. a tensor or
            array-like), or a :class:`PIL.Image.Image`, which is converted
            via ``pil_to_tensor`` first.
        dtype: Desired data type. If omitted, inferred from ``data``.
        device: Desired device. If omitted, inferred from ``data``.
        requires_grad: Whether autograd should record operations on the
            returned tensor. If omitted, inferred from ``data``.

    Raises:
        ValueError: If the resulting tensor has fewer than 2 dimensions.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Image:
        # Reinterpret a plain tensor as an Image without copying data.
        image = tensor.as_subclass(cls)
        return image

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Image:
        if isinstance(data, PIL.Image.Image):
            # Imported locally to avoid a circular import at module load time.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            # An image needs at least height and width dimensions.
            raise ValueError(f"Expected a tensor with at least 2 dimensions, but got {tensor.ndim} instead.")
        elif tensor.ndim == 2:
            # Promote a 2D (H, W) tensor to a single-channel (1, H, W) image.
            tensor = tensor.unsqueeze(0)

        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
        """Wrap ``tensor`` as an :class:`Image`; ``other`` is accepted for API symmetry."""
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
        # (height, width) of the image, taken from the two trailing dimensions.
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    @property
    def num_channels(self) -> int:
        # Channel dimension is third from the end: (..., C, H, W).
        return self.shape[-3]

    # The methods below dispatch to the image-tensor kernels reachable through
    # ``self._F``. Each one unwraps ``self`` to a plain tensor via
    # ``as_subclass(torch.Tensor)`` before calling the kernel, then re-wraps the
    # result with ``Image.wrap_like`` so callers get an Image back.

    def horizontal_flip(self) -> Image:
        output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def vertical_flip(self) -> Image:
        output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resize_image_tensor(
            self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
        )
        return Image.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Image:
        output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
        return Image.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Image:
        output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
        return Image.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resized_crop_image_tensor(
            self.as_subclass(torch.Tensor),
            top,
            left,
            height,
            width,
            size=list(size),
            interpolation=interpolation,
            antialias=antialias,
        )
        return Image.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Image:
        output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
        return Image.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.rotate_image_tensor(
            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
        return Image.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.affine_image_tensor(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Image.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.perspective_image_tensor(
            self.as_subclass(torch.Tensor),
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
        return Image.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.elastic_image_tensor(
            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
        )
        return Image.wrap_like(self, output)

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
        output = self._F.rgb_to_grayscale_image_tensor(
            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
        )
        return Image.wrap_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Image:
        output = self._F.adjust_brightness_image_tensor(
            self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Image:
        output = self._F.adjust_saturation_image_tensor(
            self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
        )
        return Image.wrap_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Image:
        output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
        return Image.wrap_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Image:
        output = self._F.adjust_sharpness_image_tensor(
            self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Image:
        output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
        return Image.wrap_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
        output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
        return Image.wrap_like(self, output)

    def posterize(self, bits: int) -> Image:
        output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
        return Image.wrap_like(self, output)

    def solarize(self, threshold: float) -> Image:
        output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
        return Image.wrap_like(self, output)

    def autocontrast(self) -> Image:
        output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def equalize(self) -> Image:
        output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def invert(self) -> Image:
        output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
        output = self._F.gaussian_blur_image_tensor(
            self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
        )
        return Image.wrap_like(self, output)

    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
        output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
        return Image.wrap_like(self, output)

# Type aliases describing accepted "image" inputs in eager mode vs. under
# torch.jit scripting. The *JIT variants collapse to plain ``torch.Tensor`` —
# presumably because TorchScript cannot handle tensor subclasses or PIL
# images; confirm against the scripted kernel signatures.
_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
_ImageTypeJIT = torch.Tensor
_TensorImageType = Union[torch.Tensor, Image]
_TensorImageTypeJIT = torch.Tensor