import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch
from torchvision import datapoints

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


def is_simple_tensor(inpt: Any) -> bool:
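    """Check whether ``inpt`` is a plain ``torch.Tensor``, i.e. not a ``Datapoint`` subclass."""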
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)


# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
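
# For illustration only (nothing below is executed, and the functional/kernel
# names are made up): after a few registrations the registry might look like
#
#   _KERNEL_REGISTRY = {
#       resize: {
#           torch.Tensor: resize_image_tensor,
#           datapoints.Image: _kernel_datapoint_wrapper(resize_image_tensor),
#           datapoints.BoundingBoxes: _kernel_datapoint_wrapper(resize_bounding_boxes),
#       },
#   }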


def _kernel_datapoint_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # We always pass datapoints as pure tensors to the kernels to avoid going through the
        # Tensor.__torch_function__ logic, which is costly.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return type(inpt).wrap_like(inpt, output)

    return wrapper
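
# A rough sketch of what the wrapper buys us (illustrative names, not executed):
#
#   image = datapoints.Image(torch.rand(3, 32, 32))
#   wrapped = _kernel_datapoint_wrapper(some_image_kernel)
#   out = wrapped(image)  # some_image_kernel only ever sees a plain torch.Tensor
#   isinstance(out, datapoints.Image)  # True, re-wrapped via wrap_like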


def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_datapoint_wrapper(kernel)
            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
            else kernel
        )
        return kernel

    return decorator
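
# Builtin kernels are registered through this decorator, e.g. (hypothetical
# kernel, shown only to illustrate the flow):
#
#   @_register_kernel_internal(resize, datapoints.Image)
#   def resize_image(image, size):
#       ...
#
# Since datapoints.Image is a Datapoint subclass and datapoint_wrapper is True,
# the stored kernel is wrapped by _kernel_datapoint_wrapper above.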


def _name_to_functional(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
        ) from None


_BUILTIN_DATAPOINT_TYPES = {
    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
}


def register_kernel(functional, datapoint_cls):
    """Decorate a kernel to register it for a functional and a (custom) datapoint type.

    See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
    details.
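
    A minimal sketch, assuming ``register_kernel`` is exposed as
    ``torchvision.transforms.v2.functional.register_kernel`` (``MyDatapoint``
    and the kernel name below are made up for illustration)::

        import torch
        from torchvision import datapoints
        from torchvision.transforms.v2 import functional as F

        class MyDatapoint(datapoints.Datapoint):
            pass

        @F.register_kernel(F.horizontal_flip, MyDatapoint)
        def horizontal_flip_my_datapoint(inpt, *args, **kwargs):
            # No automatic unwrapping/re-wrapping happens for custom datapoints
            # (datapoint_wrapper=False below), so the kernel handles it itself.
            out = F.horizontal_flip(inpt.as_subclass(torch.Tensor), *args, **kwargs)
            return MyDatapoint.wrap_like(inpt, out)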
    """
    if isinstance(functional, str):
        functional = _name_to_functional(name=functional)
    elif not (
        callable(functional)
        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
            f"but got {functional}."
        )

    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
            f"but got {datapoint_cls}."
        )

    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")

    return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)


def _get_kernel(functional, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

    # In case we have an exact type match, we take a shortcut.
    if input_type in registry:
        return registry[input_type]

    # In case of datapoints, we check if we have a kernel for a superclass registered.
    if issubclass(input_type, datapoints.Datapoint):
        # Since we have already checked for an exact match above, we can start the traversal at the superclass.
        for cls in input_type.__mro__[1:]:
            if cls is datapoints.Datapoint:
                # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop
                # the MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we
                # don't allow kernels to be registered for datapoints.Datapoint anyway.
                break
            elif cls in registry:
                return registry[cls]

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )
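
# Dispatch sketch (hypothetical names, not executed): a user-defined subclass of
# a builtin datapoint falls back to its parent's kernel via the MRO walk above:
#
#   class SubImage(datapoints.Image):
#       pass
#
#   _get_kernel(resize, SubImage) is _KERNEL_REGISTRY[resize][datapoints.Image]  # True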


# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take an arbitrary kernel wrapper rather than the
# datapoint_wrapper: bool flag.
def _register_five_ten_crop_kernel_internal(functional, input_type):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            output = kernel(inpt, *args, **kwargs)
            container_type = type(output)
            return container_type(type(inpt).wrap_like(inpt, o) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
        return kernel

    return decorator
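

# A sketch of the effect of the wrapper above (illustrative names, not
# executed): five_crop-style kernels return a container of tensors, and every
# element gets re-wrapped individually:
#
#   image = datapoints.Image(torch.rand(3, 64, 64))
#   crops = wrapped_five_crop_kernel(image, size=(32, 32))
#   all(isinstance(crop, datapoints.Image) for crop in crops)  # True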