from __future__ import annotations

from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size


D = TypeVar("D", bound="Datapoint")
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


class Datapoint(torch.Tensor):
    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
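        # If `requires_grad` is not given explicitly, inherit it from tensor inputs and default
        # to `False` for everything else (lists, NumPy arrays, scalars, ...).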
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
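        # Subclasses override this to re-apply their type (and carry over metadata such as
        # `BoundingBoxes.format`) onto `tensor`. A metadata-free datapoint can simply
        # `return tensor.as_subclass(cls)`; see the hypothetical demo at the bottom of this file.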
        raise NotImplementedError

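    # Operations for which the datapoint type is known to remain valid: their outputs are
    # re-wrapped via `wrap_like` (or, for the inplace `requires_grad_`, passed through) in
    # `__torch_function__` below.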
    _NO_WRAPPING_EXCEPTIONS = {
        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
        # retains the type automatically
        torch.Tensor.requires_grad_: lambda cls, input, output: output,
    }

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works,
        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        The default behavior of :class:`~torch.Tensor` is to retain a custom tensor type. For the :class:`Datapoint`
        use case, this has two downsides:

        1. Since some :class:`Datapoint` subclasses require metadata to be constructed, the default wrapping, i.e.
           ``return cls(func(*args, **kwargs))``, will fail for them.
        2. For most operations, there is no way of knowing if the input type is still valid for the output.

        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`.
        """
        # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
        # need to reimplement the functionality.

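        # Mirror the default `torch.Tensor.__torch_function__` guard: unless `cls` is a subclass
        # of every participating type, defer with `NotImplemented` so that a more specific type
        # can handle the call.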
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
            # be wrapped into a `datapoints.Image`.
            if wrapper and isinstance(args[0], cls):
                return wrapper(cls, args[0], output)

            # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
            # will retain the input type. Thus, we need to unwrap here.
            if isinstance(output, cls):
                return output.as_subclass(torch.Tensor)

            return output

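    # For example, with a hypothetical subclass (a runnable sketch lives at the bottom of this
    # file):
    #
    #     img = SomeImage(torch.rand(3, 4, 4))
    #     type(img.to("cpu"))  # SomeImage    -- `Tensor.to` is a listed exception
    #     type(img.flip(-1))   # torch.Tensor -- everything else comes back as a plain tensor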
    def _make_repr(self, **kwargs: Any) -> str:
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
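        # For example, a hypothetical subclass calling `self._make_repr(format="XYXY")` renders
        # roughly as `Subclass([...], format=XYXY)` instead of the bare tensor repr.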
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"

    # Add properties for common attributes like shape, dtype, device, ndim etc
    # this way we return the result without passing into __torch_function__
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self) -> _device:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype

    def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
        # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
        # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
        # attribute is cleared, so we need to refill it before we return.
        # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
        # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
        # `BoundingBoxes.clone()`.
        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]


_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
_InputTypeJIT = torch.Tensor
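

# ----------------------------------------------------------------------------------------------
# Illustration only, not part of the public API: a minimal, hypothetical `Datapoint` subclass
# showing how the pieces above interact. Every name below (`_DemoImage` and friends) is made up
# for this sketch and assumes a regular PyTorch installation.
# ----------------------------------------------------------------------------------------------
if __name__ == "__main__":
    import copy

    class _DemoImage(Datapoint):
        def __new__(cls, data: Any) -> _DemoImage:
            # `_to_tensor` normalizes the input, `as_subclass` applies the datapoint type.
            return cls._to_tensor(data).as_subclass(cls)

        @classmethod
        def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
            # This demo subclass carries no metadata, so wrapping is a plain cast.
            return tensor.as_subclass(cls)

    img = _DemoImage([[0.0, 1.0], [2.0, 3.0]])

    # `clone` is listed in `_NO_WRAPPING_EXCEPTIONS`, so the datapoint type survives ...
    assert type(img.clone()) is _DemoImage
    # ... while ordinary operators come back as plain tensors.
    assert type(img + 1) is torch.Tensor

    # `__deepcopy__` preserves both the type and the `requires_grad` flag.
    img.requires_grad_(True)
    img_copy = copy.deepcopy(img)
    assert type(img_copy) is _DemoImage and img_copy.requires_grad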