from __future__ import annotations

from types import ModuleType
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode


D = TypeVar("D", bound="Datapoint")
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


class Datapoint(torch.Tensor):
    __F: Optional[ModuleType] = None

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
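
    # Usage sketch (not executed here; behavior follows directly from the code above):
    #
    #     >>> t = torch.rand(3, 4, 4, requires_grad=True)
    #     >>> Datapoint._to_tensor(t).requires_grad  # inherited from the tensor input
    #     True
    #     >>> Datapoint._to_tensor([[0.0, 1.0]]).requires_grad  # non-tensor data defaults to False
    #     False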

    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
        raise NotImplementedError
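
    # `wrap_like` is the hook subclasses implement to re-attach their type (and any
    # metadata) to a plain tensor. A minimal sketch for a hypothetical metadata-free
    # subclass `Image` (illustrative only, not defined in this module):
    #
    #     @classmethod
    #     def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
    #         return tensor.as_subclass(cls)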

    _NO_WRAPPING_EXCEPTIONS = {
        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
        # retains the type automatically
        torch.Tensor.requires_grad_: lambda cls, input, output: output,
    }
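
    # Behavior implied by the table above (assuming a hypothetical subclass `Image`
    # that implements `wrap_like`):
    #
    #     >>> img = Image(torch.rand(3, 4, 4))
    #     >>> type(img.clone()) is Image  # listed exception: output is re-wrapped
    #     True
    #     >>> type(img + 1) is Image      # any other op: output is a plain tensor
    #     False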

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works,
        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        The default behavior of :class:`~torch.Tensor`'s ``__torch_function__`` is to retain the custom tensor type. For the :class:`Datapoint`
        use case, this has two downsides:

        1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
           ``return cls(func(*args, **kwargs))``, will fail for them.
        2. For most operations, there is no way of knowing if the input type is still valid for the output.

        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`.
        """
        # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
        # need to reimplement the functionality.

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
            # be wrapped into a `datapoints.Image`.
            if wrapper and isinstance(args[0], cls):
                return wrapper(cls, args[0], output)

            # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
            # will retain the input type. Thus, we need to unwrap here.
            if isinstance(output, cls):
                return output.as_subclass(torch.Tensor)

            return output
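
    # Sketch of the inplace-unwrapping above (hypothetical subclass `Image` again):
    #
    #     >>> img = Image(torch.rand(3, 4, 4))
    #     >>> type(img.add_(1)) is torch.Tensor  # the *returned* tensor is unwrapped ...
    #     True
    #     >>> type(img) is Image                 # ... while `img` itself keeps its type
    #     True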

    def _make_repr(self, **kwargs: Any) -> str:
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"
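
    # Illustrative only (the exact tensor repr depends on the data): with
    # `kwargs = dict(format=BoundingBoxFormat.XYXY)` this turns
    # `BoundingBox([[0., 1., 2., 3.]])` into
    # `BoundingBox([[0., 1., 2., 3.]], format=BoundingBoxFormat.XYXY)`.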

    @property
    def _F(self) -> ModuleType:
        # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
        # until the first time we need reference to the functional module and it's shared across all instances of
        # the class. This approach avoids the DataLoader issue described at
        # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
        if Datapoint.__F is None:
            from ..transforms.v2 import functional

            Datapoint.__F = functional
        return Datapoint.__F

    # Add properties for common attributes like shape, dtype, device, ndim, etc.
    # This way we return the result without going through `__torch_function__`.
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype

    def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
        # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
        # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
        # attribute is cleared, so we need to refill it before we return.
        # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
        # `BoundingBox.format` and `BoundingBox.spatial_size`, which are immutable and thus implicitly deep-copied by
        # `BoundingBox.clone()`.
        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
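
    # Sketch of the intended semantics (hypothetical subclass `Image`; `requires_grad_`
    # retains the type per `_NO_WRAPPING_EXCEPTIONS`):
    #
    #     >>> import copy
    #     >>> img = Image(torch.rand(3, 4, 4)).requires_grad_(True)
    #     >>> img2 = copy.deepcopy(img)
    #     >>> type(img2) is Image and img2.requires_grad  # type and grad flag survive
    #     True
    #     >>> img2.grad_fn is None  # but the copy is a new leaf, detached from the graph
    #     True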

    def horizontal_flip(self) -> Datapoint:
        return self
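
    # In subclasses, these no-op stubs are overridden to dispatch into the lazily
    # imported functional module. A minimal sketch (hypothetical subclass `Image`;
    # the kernel name is illustrative):
    #
    #     def horizontal_flip(self) -> Image:
    #         output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
    #         return Image.wrap_like(self, output)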

    def vertical_flip(self) -> Datapoint:
        return self

    # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
    # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
        return self

    def center_crop(self, output_size: List[int]) -> Datapoint:
        return self

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Datapoint:
        return self

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Datapoint:
        return self

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Datapoint:
        return self

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Datapoint:
        return self

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Datapoint:
        return self

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
        return self

    def adjust_brightness(self, brightness_factor: float) -> Datapoint:
        return self

    def adjust_saturation(self, saturation_factor: float) -> Datapoint:
        return self

    def adjust_contrast(self, contrast_factor: float) -> Datapoint:
        return self

    def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
        return self

    def adjust_hue(self, hue_factor: float) -> Datapoint:
        return self

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
        return self

    def posterize(self, bits: int) -> Datapoint:
        return self

    def solarize(self, threshold: float) -> Datapoint:
        return self

    def autocontrast(self) -> Datapoint:
        return self

    def equalize(self) -> Datapoint:
        return self

    def invert(self) -> Datapoint:
        return self

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
        return self


_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
_InputTypeJIT = torch.Tensor
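
# Note on the aliases above: eager v2 dispatchers can accept the full `_InputType`
# union, while their TorchScript counterparts only ever see plain tensors
# (`_InputTypeJIT`). An illustrative (not defined here) dispatcher signature:
#
#     def horizontal_flip(inpt: _InputType) -> _InputType: ...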