Unverified Commit f244e27e authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Dispatcher -> Functional (#7829)

parent 6ab8a96f
......@@ -49,7 +49,7 @@ my_dp
from torchvision.transforms.v2 import functional as F
@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
......@@ -64,9 +64,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# .. note::
#
# In our call to ``register_kernel`` above we used a string
# ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
# ``functional="hflip"`` to refer to the functional we want to hook into. We
# could also have used the functional *itself*, i.e.
# ``@register_kernel(dispatcher=F.hflip, ...)``.
# ``@register_kernel(functional=F.hflip, ...)``.
#
# The functionals that you can be hooked into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
......
......@@ -163,25 +163,25 @@ def check_kernel(
_check_kernel_batched_vs_unbatched(kernel, input, *args, **kwargs, **_to_tolerances(check_batched_vs_unbatched))
def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
"""Checks if the dispatcher can be scripted and the scripted version can be called without error."""
def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
"""Checks if the functional can be scripted and the scripted version can be called without error."""
if not isinstance(input, datapoints.Image):
return
dispatcher_scripted = _script(dispatcher)
functional_scripted = _script(functional)
with ignore_jit_no_profile_information_warning():
dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs):
def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
unknown_input = object()
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)
functional(unknown_input, *args, **kwargs)
with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
output = dispatcher(input, *args, **kwargs)
output = functional(input, *args, **kwargs)
spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
spy.assert_any_call(f"{functional.__module__}.{functional.__name__}")
assert isinstance(output, type(input))
......@@ -189,41 +189,41 @@ def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwar
assert output.format == input.format
if check_scripted_smoke:
_check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
_check_functional_scripted_smoke(functional, input, *args, **kwargs)
def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
"""Checks if the signature of the dispatcher matches the kernel signature."""
dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:]
def check_functional_kernel_signature_match(functional, *, kernel, input_type):
"""Checks if the signature of the functional matches the kernel signature."""
functional_params = list(inspect.signature(functional).parameters.values())[1:]
kernel_params = list(inspect.signature(kernel).parameters.values())[1:]
if issubclass(input_type, datapoints.Datapoint):
# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# We filter out metadata that is implicitly passed to the functional through the input datapoint, but has to be
# explicitly passed to the kernel.
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
functional_params = iter(functional_params)
for functional_param, kernel_param in zip(functional_params, kernel_params):
try:
# In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
# dispatcher parameters that have no kernel equivalent while keeping the order intact.
while dispatcher_param.name != kernel_param.name:
dispatcher_param = next(dispatcher_params)
# In general, the functional parameters are a superset of the kernel parameters. Thus, we filter out
# functional parameters that have no kernel equivalent while keeping the order intact.
while functional_param.name != kernel_param.name:
functional_param = next(functional_params)
except StopIteration:
raise AssertionError(
f"Parameter `{kernel_param.name}` of kernel `{kernel.__name__}` "
f"has no corresponding parameter on the dispatcher `{dispatcher.__name__}`."
f"has no corresponding parameter on the functional `{functional.__name__}`."
) from None
if issubclass(input_type, PIL.Image.Image):
# PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check
# them in the first place.
dispatcher_param._annotation = kernel_param._annotation = inspect.Parameter.empty
functional_param._annotation = kernel_param._annotation = inspect.Parameter.empty
assert dispatcher_param == kernel_param
assert functional_param == kernel_param
def _check_transform_v1_compatibility(transform, input):
......@@ -482,8 +482,8 @@ class TestResize:
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, size, make_input):
check_dispatcher(
def test_functional(self, size, make_input):
check_functional(
F.resize,
make_input(self.INPUT_SIZE),
size=size,
......@@ -502,8 +502,8 @@ class TestResize:
(F.resize_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize("device", cpu_and_cuda())
......@@ -608,7 +608,7 @@ class TestResize:
interpolation=interpolation,
)
def test_dispatcher_pil_antialias_warning(self):
def test_functional_pil_antialias_warning(self):
with pytest.warns(UserWarning, match="Anti-alias option is always applied for PIL Image input"):
F.resize(make_image_pil(self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], antialias=False)
......@@ -763,8 +763,8 @@ class TestHorizontalFlip:
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, make_input):
check_dispatcher(F.horizontal_flip, make_input())
def test_functional(self, make_input):
check_functional(F.horizontal_flip, make_input())
@pytest.mark.parametrize(
("kernel", "input_type"),
......@@ -777,8 +777,8 @@ class TestHorizontalFlip:
(F.horizontal_flip_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
......@@ -939,8 +939,8 @@ class TestAffine:
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, make_input):
check_dispatcher(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
def test_functional(self, make_input):
check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
@pytest.mark.parametrize(
("kernel", "input_type"),
......@@ -953,8 +953,8 @@ class TestAffine:
(F.affine_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
......@@ -1228,8 +1228,8 @@ class TestVerticalFlip:
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, make_input):
check_dispatcher(F.vertical_flip, make_input())
def test_functional(self, make_input):
check_functional(F.vertical_flip, make_input())
@pytest.mark.parametrize(
("kernel", "input_type"),
......@@ -1242,8 +1242,8 @@ class TestVerticalFlip:
(F.vertical_flip_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
......@@ -1378,8 +1378,8 @@ class TestRotate:
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, make_input):
check_dispatcher(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
def test_functional(self, make_input):
check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
@pytest.mark.parametrize(
("kernel", "input_type"),
......@@ -1392,8 +1392,8 @@ class TestRotate:
(F.rotate_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
......@@ -1643,8 +1643,8 @@ class TestToDtype:
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_dispatcher(self, make_input, input_dtype, output_dtype, device, scale):
check_dispatcher(
def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
check_functional(
F.to_dtype,
make_input(dtype=input_dtype, device=device),
dtype=output_dtype,
......@@ -1810,8 +1810,8 @@ class TestAdjustBrightness:
check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
def test_dispatcher(self, make_input):
check_dispatcher(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
def test_functional(self, make_input):
check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
@pytest.mark.parametrize(
("kernel", "input_type"),
......@@ -1822,8 +1822,8 @@ class TestAdjustBrightness:
(F.adjust_brightness_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
def test_image_correctness(self, brightness_factor):
......@@ -2042,7 +2042,7 @@ class TestShapeGetters:
assert kernel(input) == F.get_num_frames(input) == num_frames
@pytest.mark.parametrize(
("dispatcher", "make_input"),
("functional", "make_input"),
[
(F.get_dimensions, make_bounding_box),
(F.get_dimensions, make_detection_mask),
......@@ -2057,22 +2057,22 @@ class TestShapeGetters:
(F.get_num_frames, make_segmentation_mask),
],
)
def test_unsupported_types(self, dispatcher, make_input):
def test_unsupported_types(self, functional, make_input):
input = make_input()
with pytest.raises(TypeError, match=re.escape(str(type(input)))):
dispatcher(input)
functional(input)
class TestRegisterKernel:
@pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
def test_register_kernel(self, dispatcher):
@pytest.mark.parametrize("functional", (F.resize, "resize"))
def test_register_kernel(self, functional):
class CustomDatapoint(datapoints.Datapoint):
pass
kernel_was_called = False
@F.register_kernel(dispatcher, CustomDatapoint)
@F.register_kernel(functional, CustomDatapoint)
def new_resize(dp, *args, **kwargs):
nonlocal kernel_was_called
kernel_was_called = True
......@@ -2090,10 +2090,10 @@ class TestRegisterKernel:
t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
def test_errors(self):
with pytest.raises(ValueError, match="Could not find dispatcher with name"):
with pytest.raises(ValueError, match="Could not find functional with name"):
F.register_kernel("bad_name", datapoints.Image)
with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
with pytest.raises(ValueError, match="Kernels can only be registered on functionals"):
F.register_kernel(datapoints.Image, F.resize)
with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
......@@ -2115,7 +2115,7 @@ class TestRegisterKernel:
class TestGetKernel:
# We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
# We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
# would also be fine
KERNELS = {
torch.Tensor: F.resize_image_tensor,
......@@ -2139,7 +2139,7 @@ class TestGetKernel:
def test_exact_match(self):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
# here, register the kernels without wrapper, and check the exact matching afterwards.
def resize_with_pure_kernels():
pass
......@@ -2151,7 +2151,7 @@ class TestGetKernel:
def test_builtin_datapoint_subclass(self):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
# here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
# to the kernel of the corresponding superclass
def resize_with_pure_kernels():
......@@ -2217,8 +2217,8 @@ class TestPermuteChannels:
check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
def test_dispatcher(self, make_input):
check_dispatcher(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)
def test_functional(self, make_input):
check_functional(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)
@pytest.mark.parametrize(
("kernel", "input_type"),
......@@ -2229,8 +2229,8 @@ class TestPermuteChannels:
(F.permute_channels_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)
def reference_image_correctness(self, image, permutation):
channel_images = image.split(1, dim=-3)
......
......@@ -91,13 +91,13 @@ class RandomErasing(_RandomApplyTransform):
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
return super()._call_kernel(functional, inpt, *args, **kwargs)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
......
......@@ -358,13 +358,13 @@ class FiveCrop(Transform):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
return super()._call_kernel(functional, inpt, *args, **kwargs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.five_crop, inpt, self.size)
......@@ -405,13 +405,13 @@ class TenCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
return super()._call_kernel(functional, inpt, *args, **kwargs)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
......
......@@ -30,8 +30,8 @@ class Transform(nn.Module):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict()
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
return kernel(inpt, *args, **kwargs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
......
......@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
new_format: Optional[BoundingBoxFormat] = None,
inplace: bool = False,
) -> torch.Tensor:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
......
......@@ -12,7 +12,7 @@ def is_simple_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
# {dispatcher: {input_type: type_specific_kernel}}
# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
......@@ -27,10 +27,10 @@ def _kernel_datapoint_wrapper(kernel):
return wrapper
def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
registry = _KERNEL_REGISTRY.setdefault(functional, {})
if input_type in registry:
raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")
raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")
def decorator(kernel):
registry[input_type] = (
......@@ -43,14 +43,14 @@ def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True)
return decorator
def _name_to_dispatcher(name):
def _name_to_functional(name):
import torchvision.transforms.v2.functional # noqa
try:
return getattr(torchvision.transforms.v2.functional, name)
except AttributeError:
raise ValueError(
f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
) from None
......@@ -59,21 +59,21 @@ _BUILTIN_DATAPOINT_TYPES = {
}
def register_kernel(dispatcher, datapoint_cls):
"""Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
def register_kernel(functional, datapoint_cls):
"""Decorate a kernel to register it for a functional and a (custom) datapoint type.
See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
details.
"""
if isinstance(dispatcher, str):
dispatcher = _name_to_dispatcher(name=dispatcher)
if isinstance(functional, str):
functional = _name_to_functional(name=functional)
elif not (
callable(dispatcher)
and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
callable(functional)
and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
):
raise ValueError(
f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
f"but got {dispatcher}."
f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
f"but got {functional}."
)
if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
......@@ -85,13 +85,13 @@ def register_kernel(dispatcher, datapoint_cls):
if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)
def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
registry = _KERNEL_REGISTRY.get(dispatcher)
def _get_kernel(functional, input_type, *, allow_passthrough=False):
registry = _KERNEL_REGISTRY.get(functional)
if not registry:
raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
raise ValueError(f"No kernel registered for functional {functional.__name__}.")
# In case we have an exact type match, we take a shortcut.
if input_type in registry:
......@@ -113,17 +113,17 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
return lambda inpt, *args, **kwargs: inpt
raise TypeError(
f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
f"but got {input_type} instead."
)
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
# We could get rid of this by letting _register_kernel_internal take arbitrary functionals rather than wrap_kernel: bool
def _register_five_ten_crop_kernel_internal(functional, input_type):
registry = _KERNEL_REGISTRY.setdefault(functional, {})
if input_type in registry:
raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")
def wrap(kernel):
@functools.wraps(kernel)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment