_image.py 9.74 KB
Newer Older
1
2
from __future__ import annotations

3
from typing import Any, List, Optional, Tuple, Union
Philip Meier's avatar
Philip Meier committed
4

5
import PIL.Image
Philip Meier's avatar
Philip Meier committed
6
import torch
7
from torchvision.transforms.functional import InterpolationMode
Philip Meier's avatar
Philip Meier committed
8

Philip Meier's avatar
Philip Meier committed
9
from ._datapoint import _FillTypeJIT, Datapoint
10
11


12
class Image(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for images.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the image. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the image. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the image is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the image. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.

    Raises:
        ValueError: If ``data`` converts to a tensor with fewer than 2 dimensions.

    A 2-dimensional input is treated as a single-channel image: a leading channel dimension is inserted, so the
    stored tensor always has at least 3 dimensions (``num_channels`` reads ``shape[-3]``).
    """

    # Every transform method below follows the same pattern:
    #   - unwrap to a plain tensor with ``as_subclass(torch.Tensor)``;
    #   - call the matching ``*_image_tensor`` kernel on ``self._F`` (not defined in this class, presumably
    #     provided by ``Datapoint`` -- confirm);
    #   - re-wrap the result with ``Image.wrap_like``.

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Image:
        """Reinterpret ``tensor`` as ``cls`` (sharing its data) without any validation or reshaping."""
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Image:
        if isinstance(data, PIL.Image.Image):
            # Local import, presumably to avoid an import cycle between datapoints and transforms.v2 -- confirm.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            raise ValueError(
                f"Image expects an input with at least 2 dimensions, "
                f"but got a tensor with {tensor.ndim} dimension(s)."
            )
        if tensor.ndim == 2:
            # A bare 2D input is treated as a single-channel image: prepend the channel dimension.
            tensor = tensor.unsqueeze(0)

        return cls._wrap(tensor)

    @classmethod
    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
        """Wrap ``tensor`` as ``cls``.

        ``other`` is unused: ``_wrap`` attaches no extra metadata to an Image, so there is nothing to copy from it.
        """
        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        """Delegate to ``_make_repr``; ``tensor_contents`` is accepted only to match ``torch.Tensor.__repr__``."""
        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
        """The last two dimensions of the shape, i.e. ``(height, width)``."""
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    @property
    def num_channels(self) -> int:
        """The size of the third-to-last dimension."""
        return self.shape[-3]

    def horizontal_flip(self) -> Image:
        """Apply ``horizontal_flip_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def vertical_flip(self) -> Image:
        """Apply ``vertical_flip_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        """Apply ``resize_image_tensor`` with the given options and wrap the result as an :class:`Image`."""
        output = self._F.resize_image_tensor(
            self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
        )
        return Image.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Image:
        """Apply ``crop_image_tensor`` for the given box and wrap the result as an :class:`Image`."""
        output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
        return Image.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Image:
        """Apply ``center_crop_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
        return Image.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Image:
        """Apply ``resized_crop_image_tensor`` and wrap the result as an :class:`Image`.

        ``size`` is copied into a new list before being forwarded to the kernel.
        """
        output = self._F.resized_crop_image_tensor(
            self.as_subclass(torch.Tensor),
            top,
            left,
            height,
            width,
            size=list(size),
            interpolation=interpolation,
            antialias=antialias,
        )
        return Image.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Image:
        """Apply ``pad_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
        return Image.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Image:
        """Apply ``rotate_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.rotate_image_tensor(
            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
        return Image.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Image:
        """Apply ``affine_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.affine_image_tensor(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Image.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Image:
        """Apply ``perspective_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.perspective_image_tensor(
            self.as_subclass(torch.Tensor),
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
        return Image.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Image:
        """Apply ``elastic_image_tensor`` with ``displacement`` and wrap the result as an :class:`Image`."""
        output = self._F.elastic_image_tensor(
            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
        )
        return Image.wrap_like(self, output)

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
        """Apply ``rgb_to_grayscale_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.rgb_to_grayscale_image_tensor(
            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
        )
        return Image.wrap_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Image:
        """Apply ``adjust_brightness_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.adjust_brightness_image_tensor(
            self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Image:
        """Apply ``adjust_saturation_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.adjust_saturation_image_tensor(
            self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
        )
        return Image.wrap_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Image:
        """Apply ``adjust_contrast_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
        return Image.wrap_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Image:
        """Apply ``adjust_sharpness_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.adjust_sharpness_image_tensor(
            self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
        )
        return Image.wrap_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Image:
        """Apply ``adjust_hue_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
        return Image.wrap_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
        """Apply ``adjust_gamma_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
        return Image.wrap_like(self, output)

    def posterize(self, bits: int) -> Image:
        """Apply ``posterize_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
        return Image.wrap_like(self, output)

    def solarize(self, threshold: float) -> Image:
        """Apply ``solarize_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
        return Image.wrap_like(self, output)

    def autocontrast(self) -> Image:
        """Apply ``autocontrast_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def equalize(self) -> Image:
        """Apply ``equalize_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def invert(self) -> Image:
        """Apply ``invert_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
        return Image.wrap_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
        """Apply ``gaussian_blur_image_tensor`` and wrap the result as an :class:`Image`."""
        output = self._F.gaussian_blur_image_tensor(
            self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
        )
        return Image.wrap_like(self, output)

    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
        """Apply ``normalize_image_tensor`` (forwarding ``inplace``) and wrap the result as an :class:`Image`."""
        output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
        return Image.wrap_like(self, output)

256

Philip Meier's avatar
Philip Meier committed
257
258
259
260
# Union aliases: any accepted image input (tensor, PIL image, or Image datapoint),
# and the tensor-only subset that excludes PIL images.
_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
_TensorImageType = Union[torch.Tensor, Image]

# "JIT" counterparts: both collapse to a plain tensor -- presumably because the unions
# above cannot be used in TorchScript-scripted signatures (confirm).
_ImageTypeJIT = torch.Tensor
_TensorImageTypeJIT = torch.Tensor