from __future__ import annotations

from types import ModuleType
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode


D = TypeVar("D", bound="Datapoint")
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


class Datapoint(torch.Tensor):
    __F: Optional[ModuleType] = None

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
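        # Accepts anything `torch.as_tensor` accepts. For example, `Datapoint._to_tensor([1.0, 2.0])`
        # yields a float tensor with `requires_grad=False`, while a tensor input that already has
        # `requires_grad=True` keeps its flag unless `requires_grad` is passed explicitly.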
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
        raise NotImplementedError
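
    # A minimal sketch of what a concrete subclass might implement (hypothetical; subclasses that
    # carry metadata, e.g. a bounding box format, would also need to copy it over from `other`):
    #
    #   @classmethod
    #   def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
    #       return tensor.as_subclass(cls)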

    _NO_WRAPPING_EXCEPTIONS = {
        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
        # retains the type automatically
        torch.Tensor.requires_grad_: lambda cls, input, output: output,
    }

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works,
        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        The default behavior of :class:`~torch.Tensor`'s ``__torch_function__`` is to retain the custom tensor type. For the :class:`Datapoint`
        use case, this has two downsides:

        1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
           ``return cls(func(*args, **kwargs))``, will fail for them.
        2. For most operations, there is no way of knowing if the input type is still valid for the output.

        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`.
        """
        # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
        # need to reimplement the functionality.

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
            # be wrapped into a `datapoints.Image`.
            if wrapper and isinstance(args[0], cls):
                return wrapper(cls, args[0], output)

            # In-place `func`'s, canonically identified by a trailing underscore in their name, like `.add_(...)`,
            # retain the input type. Thus, we need to unwrap here.
            if isinstance(output, cls):
                return output.as_subclass(torch.Tensor)

            return output
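
    # Putting the rules above together (illustrative, assuming the hypothetical `Image` subclass
    # sketched next to `wrap_like` above):
    #
    #   img = torch.rand(3, 8, 8).as_subclass(Image)
    #   type(img.clone())  # Image        -- `clone` is listed in `_NO_WRAPPING_EXCEPTIONS`
    #   type(img + 1)      # torch.Tensor -- automatic wrapping is disabled for everything else
    #   type(img.add_(1))  # torch.Tensor -- inplace outputs are unwrapped too; `img` itself stays an `Image`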

    def _make_repr(self, **kwargs: Any) -> str:
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"
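
    # For instance, a hypothetical subclass calling `self._make_repr(spam="ham")` would turn a repr
    # like `Image([...])` into `Image([...], spam=ham)`.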

    @property
    def _F(self) -> ModuleType:
        # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
        # until the first time we need a reference to the functional module, and it is shared across all instances of
        # the class. This approach avoids the DataLoader issue described at
        # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
        if Datapoint.__F is None:
            from ..transforms.v2 import functional

            Datapoint.__F = functional
        return Datapoint.__F
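
    # Concrete datapoints are expected to dispatch through this property, along the lines of this
    # sketch (using `horizontal_flip_image_tensor` from the functional module as an example):
    #
    #   def horizontal_flip(self) -> Image:
    #       output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
    #       return Image.wrap_like(self, output)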

    # Add properties for common attributes like shape, dtype, device, ndim, etc.
    # This way we return the result without going through __torch_function__.
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype
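
    # The methods below are no-op fallbacks: on a plain `Datapoint` they return `self` unchanged.
    # Concrete subclasses override them to dispatch to the corresponding kernels via `self._F`.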

    def horizontal_flip(self) -> Datapoint:
        return self

    def vertical_flip(self) -> Datapoint:
        return self

    # TODO: We have to ignore the mypy `override` error here, since `torch.Tensor` has a deprecated built-in op with
    # the same name: `Tensor.resize`.
    # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
        return self

    def center_crop(self, output_size: List[int]) -> Datapoint:
        return self

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Datapoint:
        return self

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Datapoint:
        return self

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Datapoint:
        return self

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
        return self

    def adjust_brightness(self, brightness_factor: float) -> Datapoint:
        return self

    def adjust_saturation(self, saturation_factor: float) -> Datapoint:
        return self

    def adjust_contrast(self, contrast_factor: float) -> Datapoint:
        return self

    def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
        return self

    def adjust_hue(self, hue_factor: float) -> Datapoint:
        return self

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
        return self

    def posterize(self, bits: int) -> Datapoint:
        return self

    def solarize(self, threshold: float) -> Datapoint:
        return self

    def autocontrast(self) -> Datapoint:
        return self

    def equalize(self) -> Datapoint:
        return self

    def invert(self) -> Datapoint:
        return self

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
        return self


_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
_InputTypeJIT = torch.Tensor