import functools
import warnings
from typing import Any, Callable, Dict, Type

import torch
from torchvision import datapoints


def is_simple_tensor(inpt: Any) -> bool:
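    """Checks whether ``inpt`` is a plain tensor, i.e. a ``torch.Tensor`` that is not a ``datapoints.Datapoint``."""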
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)


# {dispatcher: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


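# Wraps a kernel so that a datapoint input is unwrapped to a plain tensor before the kernel runs and the kernel output
# is re-wrapped into the input's datapoint type.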
def _kernel_datapoint_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return type(inpt).wrap_like(inpt, output)

    return wrapper


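# If datapoint_wrapper is True (the default) and input_type is a Datapoint subclass, the registered kernel is wrapped
# with _kernel_datapoint_wrapper. The public register_kernel below calls this with datapoint_wrapper=False, i.e.
# custom kernels are stored unwrapped.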
def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if input_type in registry:
        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_datapoint_wrapper(kernel)
            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
            else kernel
        )
        return kernel

    return decorator


def _name_to_dispatcher(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
        ) from None


def register_kernel(dispatcher, datapoint_cls):
    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.

    See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
    details.
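
    A minimal sketch of the intended usage, where ``MyDatapoint`` and ``my_resize_kernel`` are hypothetical
    placeholders:

    .. code::

        from torchvision import datapoints
        from torchvision.transforms.v2 import functional as F

        class MyDatapoint(datapoints.Datapoint):
            ...

        @F.register_kernel(F.resize, MyDatapoint)
        def my_resize_kernel(inpt, size, **kwargs):
            ...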
    """
    if isinstance(dispatcher, str):
        dispatcher = _name_to_dispatcher(name=dispatcher)
    elif not (
        callable(dispatcher)
        and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
            f"but got {dispatcher}."
        )

    if not (
        isinstance(datapoint_cls, type)
        and issubclass(datapoint_cls, datapoints.Datapoint)
        and datapoint_cls is not datapoints.Datapoint
    ):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
            f"but got {datapoint_cls}."
        )

    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)


def _get_kernel(dispatcher, input_type):
    registry = _KERNEL_REGISTRY.get(dispatcher)
    if not registry:
        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")

    # In case we have an exact type match, we take a shortcut.
    if input_type in registry:
        return registry[input_type]

    # In case of datapoints, we check whether a kernel is registered for a superclass
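    # For example, an input that is a user-defined subclass of datapoints.Image dispatches to the kernel registered
    # for datapoints.Image, if there is one.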
    if issubclass(input_type, datapoints.Datapoint):
        # Since we have already checked for an exact match above, we can start the traversal at the superclass.
        for cls in input_type.__mro__[1:]:
            if cls is datapoints.Datapoint:
                # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop
                # the MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we
                # don't allow kernels to be registered for datapoints.Datapoint anyway.
                break
            elif cls in registry:
                return registry[cls]

        # Note that in the future we are not going to return a noop here, but rather raise the error below
        return _noop

    raise TypeError(
        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
        f"and subclasses of torchvision.datapoints.Datapoint, "
        f"but got {input_type} instead."
    )


# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.


# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
# register_kernel.
def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
    """
    Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
    from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.

    For example, without explicit no-op registration the following would be valid user code:

    .. code::

        from torchvision.transforms.v2 import functional as F

        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
        def lol(...):
            ...
    """

    def decorator(dispatcher):
        for cls in datapoints_classes:
            msg = (
                f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
                f"This will likely change in the future."
            )
            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
                functools.partial(_noop, __msg__=msg if warn_passthrough else None)
            )
        return dispatcher

    return decorator


def _noop(inpt, *args, __msg__=None, **kwargs):
    if __msg__:
        warnings.warn(__msg__, UserWarning, stacklevel=2)
    return inpt


# TODO: we only need this because our current default behavior, when no kernel is found, is to pass the input through.
# When we change that to an error later, this decorator can be removed, since the error will be raised by _get_kernel.
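# Hypothetical usage sketch: decorating a dispatcher with @_register_unsupported_type(PIL.Image.Image) registers a
# kernel that raises a TypeError for PIL image inputs instead of passing them through.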
def _register_unsupported_type(*input_types):
    def kernel(inpt, *args, __dispatcher_name__, **kwargs):
        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")

    def decorator(dispatcher):
        for input_type in input_types:
            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
                functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
            )
        return dispatcher

    return decorator


# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop.
# We could get rid of this by letting _register_kernel_internal take an arbitrary kernel wrapper rather than
# datapoint_wrapper: bool.
def _register_five_ten_crop_kernel(dispatcher, input_type):
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if input_type in registry:
        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            output = kernel(inpt, *args, **kwargs)
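            # five_crop / ten_crop kernels return a container of outputs (e.g. a tuple), so each element is wrapped
            # back into the input's datapoint type individually.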
            container_type = type(output)
            return container_type(type(inpt).wrap_like(inpt, o) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
        return kernel

    return decorator