# _torch_function_helpers.py
import torch

# Module-wide switch driven by ``set_return_type``: False corresponds to the
# "tensor" return type (the initial state), True to "datapoint". It is read
# through ``_must_return_subclass()``.
_TORCHFUNCTION_SUBCLASS = False


class _ReturnTypeCM:
    """Context manager that puts back a previous return-type flag on exit.

    ``set_return_type`` returns an instance of this class so that it can be
    used in a ``with`` statement. The instance only remembers the value to
    write back into the module-level ``_TORCHFUNCTION_SUBCLASS`` flag.
    """

    def __init__(self, to_restore):
        # Value assigned to the module-level flag when the ``with`` block ends.
        self.to_restore = to_restore

    def __enter__(self):
        # Entering performs no work; it only hands back the manager itself.
        return self

    def __exit__(self, *exc_info):
        # The exception details are ignored and nothing is returned, so an
        # exception raised inside the ``with`` block keeps propagating after
        # the flag has been restored.
        global _TORCHFUNCTION_SUBCLASS
        _TORCHFUNCTION_SUBCLASS = self.to_restore

def set_return_type(return_type: str):
    """Set the return type of torch operations on datapoints.

    This only affects the behaviour of torch operations. It has no effect on
    ``torchvision`` transforms or functionals, which will always return as
    output the same type that was passed as input.

    Can be used as a global flag for the entire program:

    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor (default behaviour)

        set_return_type("datapoint")
        img + 2  # This is an Image

    or as a context manager to restrict the scope:

    .. code:: python

        img = datapoints.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor
        with set_return_type("datapoint"):
            img + 2  # This is an Image
        img + 2  # This is a pure Tensor

    Note that the new return type takes effect as soon as this function is
    called, not when the ``with`` block is entered.

    Args:
        return_type (str): Can be "datapoint" (or its alias "datapoints") or
            "tensor". Matching is case-insensitive. Before any call, the
            program behaves as if "tensor" had been selected.

    Returns:
        A context manager which, on exit, restores the return type that was
        active just before this call.

    Raises:
        KeyError: If ``return_type`` is not one of the accepted values. The
            current setting is left untouched in that case.
    """
    global _TORCHFUNCTION_SUBCLASS

    # "datapoints" is accepted as an alias of "datapoint" because earlier
    # versions of this docstring advertised the plural spelling.
    choices = {"tensor": False, "datapoint": True, "datapoints": True}
    try:
        return_subclass = choices[return_type.lower()]
    except KeyError:
        # Still a KeyError, as before, but with a message that lists the valid options.
        raise KeyError(f"return_type must be one of {sorted(choices)}, got {return_type!r}") from None

    to_restore = _TORCHFUNCTION_SUBCLASS
    _TORCHFUNCTION_SUBCLASS = return_subclass

    return _ReturnTypeCM(to_restore)


def _must_return_subclass() -> bool:
    """Report the current value of the module-level ``_TORCHFUNCTION_SUBCLASS`` flag.

    The flag starts out as ``False`` and is changed by ``set_return_type``.
    """
    return _TORCHFUNCTION_SUBCLASS


# Ops for which the original subclass is always preserved, rather than
# returning a pure Tensor.
_FORCE_TORCHFUNCTION_SUBCLASS = {
    torch.Tensor.clone,
    torch.Tensor.to,
    torch.Tensor.detach,
    torch.Tensor.requires_grad_,
}