from __future__ import annotations

from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size

from torchvision.datapoints._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass


# Type variable used for methods that return an instance of the same Datapoint
# subclass they were called on (e.g. `__deepcopy__`). Bound via forward reference
# since the class is defined below.
D = TypeVar("D", bound="Datapoint")
class Datapoint(torch.Tensor):
    """[Beta] Base class for all datapoints.

    You probably don't want to use this class unless you're defining your own
    custom Datapoints. See
    :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for details.
    """
    @staticmethod
    def _to_tensor(
25
26
        data: Any,
        dtype: Optional[torch.dtype] = None,
27
        device: Optional[Union[torch.device, str, int]] = None,
28
        requires_grad: Optional[bool] = None,
29
    ) -> torch.Tensor:
30
31
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
32
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
Philip Meier's avatar
Philip Meier committed
33

34
    @classmethod
    def _wrap_output(
        cls,
        output: torch.Tensor,
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """Re-wrap an operator result as ``cls``.

        Plain tensors are converted via ``as_subclass``; tuples and lists are
        rebuilt with every element wrapped recursively. ``args`` and ``kwargs``
        are only threaded through for the recursive calls.
        """
        # Same as torch._tensor._convert
        if isinstance(output, torch.Tensor) and not isinstance(output, cls):
            output = output.as_subclass(cls)

        if isinstance(output, (tuple, list)):
            # Also handles things like namedtuples
            # NOTE(review): a plain collections.namedtuple constructor takes *args,
            # not a single iterable — presumably this relies on torch's structseq
            # return types accepting a sequence; confirm before reusing elsewhere.
            output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
        return output
    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works,
        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        Why do we override this? Because the base implementation in torch.Tensor would preserve the Datapoint type
        of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
        "Datapoints FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).

        Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
        """
        # Defer to another implementation if any participating type is not a
        # superclass of cls (mirrors the standard __torch_function__ contract).
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Like in the base Tensor.__torch_function__ implementation, it's easier to always use
        # DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

        must_return_subclass = _must_return_subclass()
        if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
            # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
            # in test_to_datapoint_reference().
            # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
            # the computation by walking the MRO upwards. For example,
            # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
            # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
            # be wrapped into an `Image`.
            return cls._wrap_output(output, args, kwargs)

        if not must_return_subclass and isinstance(output, cls):
            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
            # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
            return output.as_subclass(torch.Tensor)

        return output
    def _make_repr(self, **kwargs: Any) -> str:
        """Return the tensor repr with ``key=value`` metadata pairs spliced in.

        This is a poor man's implementation of the proposal in
        https://github.com/pytorch/pytorch/issues/76532. If that ever gets
        implemented, remove this in favor of the solution on the `torch.Tensor`
        class.
        """
        metadata_repr = ", ".join(f"{name}={value}" for name, value in kwargs.items())
        # Drop the trailing ")" of the base repr, append the metadata, close again.
        tensor_repr = super().__repr__()
        return f"{tensor_repr[:-1]}, {metadata_repr})"
    # Add properties for common attributes like shape, dtype, device, ndim etc.
    # This way we return the result without passing into __torch_function__:
    # DisableTorchFunctionSubclass makes the super() accessors behave like plain
    # Tensor attribute access, skipping the subclass dispatch overhead.
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self) -> _device:  # type: ignore[override]
        # Fix: property getters are only ever invoked with `self`; the previous
        # `*args, **kwargs` parameters were dead code and have been removed.
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype
    def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
        """Deep-copy the datapoint without tying the copy into the autograd graph.

        We need to detach first, since a plain `Tensor.clone` will be part of the
        computation graph, which does *not* happen for `deepcopy(Tensor)`. A
        side-effect from detaching is that the `Tensor.requires_grad` attribute is
        cleared, so we need to refill it before we return.

        Note: We don't explicitly handle deep-copying of the metadata here. The only
        metadata we currently have is `BoundingBoxes.format` and
        `BoundingBoxes.canvas_size`, which are immutable and thus implicitly
        deep-copied by `BoundingBoxes.clone()`.
        """
        detached_clone = self.detach().clone()
        return detached_clone.requires_grad_(self.requires_grad)  # type: ignore[return-value]