# _torch_function_helpers.py
import torch

# Module-level return-type flag, changed by ``set_return_type`` and read by
# ``_must_return_subclass``. False corresponds to the "tensor" setting (torch ops
# on datapoints return pure Tensors, the default); True corresponds to the
# "datapoint" setting.
_TORCHFUNCTION_SUBCLASS = False


class _ReturnTypeCM:
    """Put a saved ``_TORCHFUNCTION_SUBCLASS`` value back when a ``with`` block ends.

    ``set_return_type`` applies the new flag value *before* constructing this
    object, so entering the context is a no-op; leaving it reinstates the value
    captured in ``to_restore``.
    """

    def __init__(self, to_restore):
        self.to_restore = to_restore  # flag value reinstated on exit

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Rebinds the module-level flag. Nothing is returned (i.e. None), so an
        # exception raised inside the ``with`` body is not suppressed.
        global _TORCHFUNCTION_SUBCLASS
        saved = self.to_restore
        _TORCHFUNCTION_SUBCLASS = saved


def set_return_type(return_type: str):
    """[BETA] Set the return type of torch operations on datapoints.

    This only affects the behaviour of torch operations. It has no effect on
    ``torchvision`` transforms or functionals, which will always return as
    output the same type that was passed as input.

    .. warning::

        We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
        the end of your transform pipelines if you use
        ``set_return_type("datapoint")``. This will avoid the
        ``__torch_function__`` overhead in the model's ``forward()``.

    Can be used as a global flag for the entire program:

    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor (default behaviour)

        set_return_type("datapoint")
        img + 2  # This is an Image

    or as a context manager to restrict the scope:

    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor
        with set_return_type("datapoint"):
            img + 2  # This is an Image
        img + 2  # This is a pure Tensor

    Args:
        return_type (str): Can be "datapoint" or "tensor" (case-insensitive).
            Until this function is first called, the return type is "tensor".

    Returns:
        A context manager that, on exit, restores the return type that was
        active before this call. The new return type takes effect as soon as
        this function is called, whether or not the returned object is used
        in a ``with`` statement.

    Raises:
        ValueError: If ``return_type`` is not "datapoint" or "tensor".
    """
    global _TORCHFUNCTION_SUBCLASS

    # Validate first so that an invalid value leaves the current setting
    # untouched, and report it with a message naming the accepted values
    # instead of a bare KeyError on the lowered string.
    try:
        new_value = {"tensor": False, "datapoint": True}[return_type.lower()]
    except KeyError:
        raise ValueError(
            f"return_type must be 'datapoint' or 'tensor', got {return_type!r}"
        ) from None

    to_restore = _TORCHFUNCTION_SUBCLASS
    _TORCHFUNCTION_SUBCLASS = new_value

    return _ReturnTypeCM(to_restore)


def _must_return_subclass():
    """Report the current return-type flag set via ``set_return_type``.

    True means the "datapoint" setting is active; False means "tensor".
    """
    current = _TORCHFUNCTION_SUBCLASS
    return current


# For these ops we always want to preserve the original subclass instead of
# returning a pure Tensor, whatever ``set_return_type`` was set to.
# NOTE(review): presumably consulted by the datapoint ``__torch_function__``
# override elsewhere in the package — confirm against the caller.
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}