import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch
from torchvision import datapoints

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


def is_pure_tensor(inpt: Any) -> bool:
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)


# {functional: {input_type: type_specific_kernel}}
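# For illustration only (names are not exhaustive): after the built-in kernels are
# registered, the mapping roughly looks like
#   {resize: {torch.Tensor: resize_image, datapoints.Image: <wrapped resize_image>, ...}, ...}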
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


def _kernel_datapoint_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # If you're wondering whether we could / should get rid of this wrapper,
        # the answer is no: we want to pass pure Tensors to avoid the overhead
        # of the __torch_function__ machinery. Note that this is always valid,
        # regardless of whether we override __torch_function__ in our base class
        # or not.
        # Also, even if we didn't call `as_subclass` here, we would still need
        # this wrapper to call wrap(), because the Datapoint type would be
        # lost after the first operation due to our own __torch_function__
        # logic.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return datapoints.wrap(output, like=inpt)

    return wrapper
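
# Illustrative sketch of the wrapper's effect (`some_image_kernel` is hypothetical):
# the kernel only ever sees a plain torch.Tensor, and the Datapoint type of the input
# is restored on the output.
#
#   wrapped = _kernel_datapoint_wrapper(some_image_kernel)
#   out = wrapped(datapoints.Image(torch.rand(3, 8, 8)))  # kernel receives a pure Tensor
#   isinstance(out, datapoints.Image)  # True, re-wrapped via datapoints.wrap()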


def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_datapoint_wrapper(kernel)
            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
            else kernel
        )
        return kernel

    return decorator


def _name_to_functional(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
        ) from None


_BUILTIN_DATAPOINT_TYPES = {
    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
}


def register_kernel(functional, datapoint_cls):
    """[BETA] Decorate a kernel to register it for a functional and a (custom) datapoint type.

    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for usage
    details.
    """
    if isinstance(functional, str):
        functional = _name_to_functional(name=functional)
    elif not (
        callable(functional)
        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
            f"but got {functional}."
        )

    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
            f"but got {datapoint_cls}."
        )

    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")

    return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)
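
# Minimal usage sketch (`MyDatapoint` and the kernel body are hypothetical). Since this
# path uses datapoint_wrapper=False, the custom kernel receives the datapoint itself and
# is responsible for wrapping its own output:
#
#   class MyDatapoint(datapoints.Datapoint):
#       ...
#
#   @register_kernel(functional="resize", datapoint_cls=MyDatapoint)
#   def resize_my_datapoint(inpt, size, **kwargs):
#       out = ...  # compute the resized data as a tensor
#       return datapoints.wrap(out, like=inpt)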


def _get_kernel(functional, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
        elif cls is datapoints.Datapoint:
            # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop the
            # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
            # allow kernels to be registered for datapoints.Datapoint anyway.
            break

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )
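
# Dispatch sketch (`MyImage` and `some_functional` are hypothetical): a user-defined
# subclass of a built-in datapoint resolves to its parent's kernel through the MRO walk
# above, but never falls back to the pure-Tensor kernel.
#
#   class MyImage(datapoints.Image):
#       ...
#
#   _get_kernel(some_functional, MyImage)  # returns the kernel registered for datapoints.Image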


# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take an arbitrary wrapper rather than datapoint_wrapper: bool
def _register_five_ten_crop_kernel_internal(functional, input_type):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            output = kernel(inpt, *args, **kwargs)
            container_type = type(output)
            return container_type(datapoints.wrap(o, like=inpt) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
        return kernel

    return decorator
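
# Illustrative effect of the specialized wrapper: five_crop / ten_crop kernels return a
# tuple of outputs, so each element is wrapped back into the input's Datapoint type while
# the container type is preserved, e.g. (assuming the usual F alias for the functional namespace):
#
#   crops = F.five_crop(datapoints.Image(torch.rand(3, 32, 32)), size=[8, 8])
#   all(isinstance(crop, datapoints.Image) for crop in crops)  # True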