_image.py 9.53 KB
Newer Older
1
2
from __future__ import annotations

Philip Meier's avatar
Philip Meier committed
3
from typing import Any, List, Optional, Union
Philip Meier's avatar
Philip Meier committed
4

5
import PIL.Image
Philip Meier's avatar
Philip Meier committed
6
import torch
7
from torchvision.transforms.functional import InterpolationMode
Philip Meier's avatar
Philip Meier committed
8

Philip Meier's avatar
Philip Meier committed
9
from ._datapoint import _FillTypeJIT, Datapoint
10
11


12
class Image(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for images.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the image. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the image. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the image is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the image. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Image:
        # Reinterpret ``tensor`` as an ``Image`` without copying its storage.
        image = tensor.as_subclass(cls)
        return image

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Image:
        if isinstance(data, PIL.Image.Image):
            # Imported locally to avoid a circular import at module load time.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            # FIX: previously a bare ``raise ValueError`` with no message.
            raise ValueError(f"Image tensors are required to have at least 2 dimensions, but got {tensor.ndim}")
        elif tensor.ndim == 2:
            # Promote a 2D (H, W) tensor to a single-channel (1, H, W) image.
            tensor = tensor.unsqueeze(0)

        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
        # ``Image`` carries no metadata beyond the tensor itself, so wrapping
        # "like" another image is just plain wrapping; ``other`` is accepted
        # only to keep the signature uniform across datapoint classes.
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()

    # All transform entry points below follow the same pattern: unwrap to a
    # plain ``torch.Tensor`` (the ``self._F`` kernels operate on plain
    # tensors), dispatch to the tensor kernel, and re-wrap the result.

    def horizontal_flip(self) -> Image:
        output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def vertical_flip(self) -> Image:
        output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resize_image_tensor(
            self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
        )
        return Image.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Image:
        output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
        return Image.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Image:
        output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
        return Image.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        output = self._F.resized_crop_image_tensor(
            self.as_subclass(torch.Tensor),
            top,
            left,
            height,
            width,
            size=list(size),
            interpolation=interpolation,
            antialias=antialias,
        )
        return Image.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Image:
        output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
        return Image.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.rotate_image_tensor(
            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
        return Image.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.affine_image_tensor(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Image.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Image:
        output = self._F.perspective_image_tensor(
            self.as_subclass(torch.Tensor),
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
        return Image.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Image:
        output = self._F.elastic_image_tensor(
            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
        )
        return Image.wrap_like(self, output)

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
        output = self._F.rgb_to_grayscale_image_tensor(
            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
        )
        return Image.wrap_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Image:
        output = self._F.adjust_brightness_image_tensor(
            self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Image:
        output = self._F.adjust_saturation_image_tensor(
            self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
        )
        return Image.wrap_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Image:
        output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
        return Image.wrap_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Image:
        output = self._F.adjust_sharpness_image_tensor(
            self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Image:
        output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
        return Image.wrap_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
        output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
        return Image.wrap_like(self, output)

    def posterize(self, bits: int) -> Image:
        output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
        return Image.wrap_like(self, output)

    def solarize(self, threshold: float) -> Image:
        output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
        return Image.wrap_like(self, output)

    def autocontrast(self) -> Image:
        output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def equalize(self) -> Image:
        output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def invert(self) -> Image:
        output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
        output = self._F.gaussian_blur_image_tensor(
            self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
        )
        return Image.wrap_like(self, output)

    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
        output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
        return Image.wrap_like(self, output)

248

Philip Meier's avatar
Philip Meier committed
249
250
251
252
# Internal type aliases for values that transforms accept as an "image" input.
# In eager mode an image may be a plain tensor, a PIL image, or an ``Image``
# datapoint; the ``*JIT`` variants collapse to ``torch.Tensor`` because
# TorchScript only supports plain tensors.
_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
_ImageTypeJIT = torch.Tensor
# Same idea, but restricted to tensor-backed images (PIL excluded).
_TensorImageType = Union[torch.Tensor, Image]
_TensorImageTypeJIT = torch.Tensor