import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch
from torchvision import datapoints


_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


def is_simple_tensor(inpt: Any) -> bool:
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
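
# Quick sanity check of the helper above (illustrative only, not executed at import time):
#   is_simple_tensor(torch.rand(3, 32, 32))                    -> True
#   is_simple_tensor(datapoints.Image(torch.rand(3, 32, 32)))  -> False
#   is_simple_tensor([1, 2, 3])                                 -> False (not a tensor at all)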


# {dispatcher: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
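
# For illustration (hypothetical entry, assuming `resize` is the dispatcher from
# torchvision.transforms.v2.functional and `resize_image` / `resize_mask` are its kernels):
#   _KERNEL_REGISTRY[resize] == {
#       datapoints.Image: <wrapped resize_image>,
#       datapoints.Mask: <wrapped resize_mask>,
#       ...
#   }
# Whether an individual kernel is wrapped or stored as-is is decided by the helpers below.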


def _kernel_datapoint_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return type(inpt).wrap_like(inpt, output)

    return wrapper
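
# Usage sketch for the wrapper above (names are illustrative): if `some_kernel` operates on plain
# tensors and `inpt` is e.g. a datapoints.Image, then
#     _kernel_datapoint_wrapper(some_kernel)(inpt, *args, **kwargs)
# is roughly equivalent to
#     datapoints.Image.wrap_like(inpt, some_kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs))
# i.e. the computation happens on the unwrapped tensor and the result is re-wrapped into the
# input's datapoint type.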


def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if input_type in registry:
        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_datapoint_wrapper(kernel)
            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
            else kernel
        )
        return kernel

    return decorator
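
# Internal usage sketch (dispatcher and kernel names are hypothetical):
#   @_register_kernel_internal(resize, datapoints.Image)
#   def resize_image(image, size, **kwargs): ...
# registers `resize_image` (wrapped by _kernel_datapoint_wrapper, since datapoint_wrapper defaults
# to True) as the kernel that the `resize` dispatcher uses for datapoints.Image inputs.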


def _name_to_dispatcher(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
        ) from None


_BUILTIN_DATAPOINT_TYPES = {
    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
}


def register_kernel(dispatcher, datapoint_cls):
    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.

    See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
    details.
    """
    if isinstance(dispatcher, str):
        dispatcher = _name_to_dispatcher(name=dispatcher)
    elif not (
        callable(dispatcher)
        and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
            f"but got {dispatcher}."
        )

    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
            f"but got {datapoint_cls}."
        )

    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")

    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
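
# Usage sketch for third-party datapoints (class and kernel names are hypothetical; see the gallery
# example referenced in the docstring for the complete walkthrough):
#   class MyDatapoint(datapoints.Datapoint):
#       ...
#
#   @register_kernel("resize", MyDatapoint)  # the dispatcher can also be passed as the callable itself
#   def resize_my_datapoint(inpt, size, **kwargs):
#       ...
# Unlike the builtin kernels, the registered kernel receives the MyDatapoint instance as-is
# (datapoint_wrapper=False), so re-wrapping the output is the kernel author's responsibility.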


def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(dispatcher)
    if not registry:
        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")

    # In case we have an exact type match, we take a shortcut.
    if input_type in registry:
        return registry[input_type]

    # In case of datapoints, we check if we have a kernel for a superclass registered
    if issubclass(input_type, datapoints.Datapoint):
        # Since we have already checked for an exact match above, we can start the traversal at the superclass.
        for cls in input_type.__mro__[1:]:
            if cls is datapoints.Datapoint:
                # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop
                # the MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we
                # don't allow kernels to be registered for datapoints.Datapoint anyway.
                break
            elif cls in registry:
                return registry[cls]

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )
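
# Resolution order sketch: for a user-defined subclass such as `class RotatedImage(datapoints.Image)`
# (hypothetical), an exact-match kernel wins; otherwise a kernel registered for datapoints.Image is
# found via the MRO walk above; plain-tensor kernels are never reached because the walk stops at
# datapoints.Datapoint; failing all that, the input is passed through unchanged if allow_passthrough
# is set, and a TypeError is raised otherwise.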


# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take an arbitrary kernel wrapper rather than the
# datapoint_wrapper: bool flag
def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if input_type in registry:
        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            output = kernel(inpt, *args, **kwargs)
            container_type = type(output)
            return container_type(type(inpt).wrap_like(inpt, o) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
        return kernel

    return decorator
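
# Illustrative difference from _kernel_datapoint_wrapper (assuming a five_crop-style kernel that
# returns a tuple of several tensors): the wrapper above re-wraps every element of the returned
# container individually, roughly
#     tuple(type(inpt).wrap_like(inpt, o) for o in kernel(inpt, *args, **kwargs))
# rather than wrapping a single output tensor.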