Unverified Commit f244e27e authored by Nicolas Hug, committed by GitHub

Dispatcher -> Functional (#7829)

parent 6ab8a96f
@@ -49,7 +49,7 @@ my_dp
 from torchvision.transforms.v2 import functional as F

-@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
+@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
 def hflip_my_datapoint(my_dp, *args, **kwargs):
     print("Flipping!")
     out = my_dp.flip(-1)
@@ -64,9 +64,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
 # .. note::
 #
 #    In our call to ``register_kernel`` above we used a string
-#    ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
+#    ``functional="hflip"`` to refer to the functional we want to hook into. We
 #    could also have used the functional *itself*, i.e.
-#    ``@register_kernel(dispatcher=F.hflip, ...)``.
+#    ``@register_kernel(functional=F.hflip, ...)``.
 #
 #    The functionals that you can hook into are the ones in
 #    ``torchvision.transforms.v2.functional`` and they are documented in
...
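For context, the registration pattern this first hunk renames looks roughly like the following when put together. This is a minimal sketch assuming the torchvision v2 API at the time of this commit; `MyDatapoint` is the illustrative class from the gallery example and the kernel body is simplified.

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class MyDatapoint(datapoints.Datapoint):
    pass


@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    # Sketch only: a real kernel would typically re-wrap the result into the
    # datapoint type; returning the flipped tensor is enough to show dispatch.
    return my_dp.flip(-1)


my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)  # resolves to hflip_my_datapoint through the kernel registry
```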
@@ -163,25 +163,25 @@ def check_kernel(
     _check_kernel_batched_vs_unbatched(kernel, input, *args, **kwargs, **_to_tolerances(check_batched_vs_unbatched))


-def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
-    """Checks if the dispatcher can be scripted and the scripted version can be called without error."""
+def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
+    """Checks if the functional can be scripted and the scripted version can be called without error."""
     if not isinstance(input, datapoints.Image):
         return

-    dispatcher_scripted = _script(dispatcher)
+    functional_scripted = _script(functional)
     with ignore_jit_no_profile_information_warning():
-        dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
+        functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)


-def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs):
+def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
     unknown_input = object()
     with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
-        dispatcher(unknown_input, *args, **kwargs)
+        functional(unknown_input, *args, **kwargs)

     with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
-        output = dispatcher(input, *args, **kwargs)
+        output = functional(input, *args, **kwargs)

-        spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
+        spy.assert_any_call(f"{functional.__module__}.{functional.__name__}")

     assert isinstance(output, type(input))

@@ -189,41 +189,41 @@ def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwar
         assert output.format == input.format

     if check_scripted_smoke:
-        _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
+        _check_functional_scripted_smoke(functional, input, *args, **kwargs)

-def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
-    """Checks if the signature of the dispatcher matches the kernel signature."""
-    dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:]
+def check_functional_kernel_signature_match(functional, *, kernel, input_type):
+    """Checks if the signature of the functional matches the kernel signature."""
+    functional_params = list(inspect.signature(functional).parameters.values())[1:]
     kernel_params = list(inspect.signature(kernel).parameters.values())[1:]

     if issubclass(input_type, datapoints.Datapoint):
-        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
+        # We filter out metadata that is implicitly passed to the functional through the input datapoint, but has to be
         # explicitly passed to the kernel.
         explicit_metadata = {
             datapoints.BoundingBoxes: {"format", "canvas_size"},
         }
         kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

-    dispatcher_params = iter(dispatcher_params)
-    for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
+    functional_params = iter(functional_params)
+    for functional_param, kernel_param in zip(functional_params, kernel_params):
         try:
-            # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
-            # dispatcher parameters that have no kernel equivalent while keeping the order intact.
-            while dispatcher_param.name != kernel_param.name:
-                dispatcher_param = next(dispatcher_params)
+            # In general, the functional parameters are a superset of the kernel parameters. Thus, we filter out
+            # functional parameters that have no kernel equivalent while keeping the order intact.
+            while functional_param.name != kernel_param.name:
+                functional_param = next(functional_params)
         except StopIteration:
             raise AssertionError(
                 f"Parameter `{kernel_param.name}` of kernel `{kernel.__name__}` "
-                f"has no corresponding parameter on the dispatcher `{dispatcher.__name__}`."
+                f"has no corresponding parameter on the functional `{functional.__name__}`."
             ) from None

         if issubclass(input_type, PIL.Image.Image):
             # PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check
             # them in the first place.
-            dispatcher_param._annotation = kernel_param._annotation = inspect.Parameter.empty
+            functional_param._annotation = kernel_param._annotation = inspect.Parameter.empty

-        assert dispatcher_param == kernel_param
+        assert functional_param == kernel_param


 def _check_transform_v1_compatibility(transform, input):
@@ -482,8 +482,8 @@ class TestResize:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, size, make_input):
-        check_dispatcher(
+    def test_functional(self, size, make_input):
+        check_functional(
             F.resize,
             make_input(self.INPUT_SIZE),
             size=size,
@@ -502,8 +502,8 @@ class TestResize:
             (F.resize_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -608,7 +608,7 @@ class TestResize:
             interpolation=interpolation,
         )

-    def test_dispatcher_pil_antialias_warning(self):
+    def test_functional_pil_antialias_warning(self):
         with pytest.warns(UserWarning, match="Anti-alias option is always applied for PIL Image input"):
             F.resize(make_image_pil(self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], antialias=False)
@@ -763,8 +763,8 @@ class TestHorizontalFlip:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.horizontal_flip, make_input())
+    def test_functional(self, make_input):
+        check_functional(F.horizontal_flip, make_input())

     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -777,8 +777,8 @@ class TestHorizontalFlip:
             (F.horizontal_flip_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -939,8 +939,8 @@ class TestAffine:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
+    def test_functional(self, make_input):
+        check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)

     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -953,8 +953,8 @@ class TestAffine:
             (F.affine_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -1228,8 +1228,8 @@ class TestVerticalFlip:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.vertical_flip, make_input())
+    def test_functional(self, make_input):
+        check_functional(F.vertical_flip, make_input())

     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1242,8 +1242,8 @@ class TestVerticalFlip:
             (F.vertical_flip_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -1378,8 +1378,8 @@ class TestRotate:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
+    def test_functional(self, make_input):
+        check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)

     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1392,8 +1392,8 @@ class TestRotate:
             (F.rotate_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -1643,8 +1643,8 @@ class TestToDtype:
     @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("scale", (True, False))
-    def test_dispatcher(self, make_input, input_dtype, output_dtype, device, scale):
-        check_dispatcher(
+    def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
+        check_functional(
             F.to_dtype,
             make_input(dtype=input_dtype, device=device),
             dtype=output_dtype,
@@ -1810,8 +1810,8 @@ class TestAdjustBrightness:
         check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
+    def test_functional(self, make_input):
+        check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1822,8 +1822,8 @@ class TestAdjustBrightness:
             (F.adjust_brightness_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
     def test_image_correctness(self, brightness_factor):
@@ -2042,7 +2042,7 @@ class TestShapeGetters:
         assert kernel(input) == F.get_num_frames(input) == num_frames

     @pytest.mark.parametrize(
-        ("dispatcher", "make_input"),
+        ("functional", "make_input"),
         [
             (F.get_dimensions, make_bounding_box),
             (F.get_dimensions, make_detection_mask),
@@ -2057,22 +2057,22 @@ class TestShapeGetters:
             (F.get_num_frames, make_segmentation_mask),
         ],
     )
-    def test_unsupported_types(self, dispatcher, make_input):
+    def test_unsupported_types(self, functional, make_input):
         input = make_input()
         with pytest.raises(TypeError, match=re.escape(str(type(input)))):
-            dispatcher(input)
+            functional(input)

 class TestRegisterKernel:
-    @pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
-    def test_register_kernel(self, dispatcher):
+    @pytest.mark.parametrize("functional", (F.resize, "resize"))
+    def test_register_kernel(self, functional):
         class CustomDatapoint(datapoints.Datapoint):
             pass

         kernel_was_called = False

-        @F.register_kernel(dispatcher, CustomDatapoint)
+        @F.register_kernel(functional, CustomDatapoint)
         def new_resize(dp, *args, **kwargs):
             nonlocal kernel_was_called
             kernel_was_called = True
@@ -2090,10 +2090,10 @@ class TestRegisterKernel:
         t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

     def test_errors(self):
-        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
+        with pytest.raises(ValueError, match="Could not find functional with name"):
             F.register_kernel("bad_name", datapoints.Image)

-        with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
+        with pytest.raises(ValueError, match="Kernels can only be registered on functionals"):
             F.register_kernel(datapoints.Image, F.resize)

         with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
@@ -2115,7 +2115,7 @@ class TestRegisterKernel:
 class TestGetKernel:
-    # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
+    # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
     # would also be fine
     KERNELS = {
         torch.Tensor: F.resize_image_tensor,
@@ -2139,7 +2139,7 @@ class TestGetKernel:
     def test_exact_match(self):
         # We cannot use F.resize together with self.KERNELS mapping directly here, since this is only the
-        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
+        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
         # here, register the kernels without wrapper, and check the exact matching afterwards.
         def resize_with_pure_kernels():
             pass
@@ -2151,7 +2151,7 @@ class TestGetKernel:
     def test_builtin_datapoint_subclass(self):
         # We cannot use F.resize together with self.KERNELS mapping directly here, since this is only the
-        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
+        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
         # here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
         # to the kernel of the corresponding superclass
         def resize_with_pure_kernels():
@@ -2217,8 +2217,8 @@ class TestPermuteChannels:
         check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)

     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)
+    def test_functional(self, make_input):
+        check_functional(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)

     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -2229,8 +2229,8 @@ class TestPermuteChannels:
             (F.permute_channels_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)

     def reference_image_correctness(self, image, permutation):
         channel_images = image.split(1, dim=-3)
...
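As a point of reference for the renamed test helpers above, this is the behaviour `check_functional` asserts, shown on a bounding box input. A hedged sketch assuming the datapoints API of this commit; the box values are arbitrary.

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),
)

flipped = F.horizontal_flip(boxes)
# The functional returns the same datapoint type and preserves the metadata that
# check_functional inspects (see the isinstance and format assertions above).
assert isinstance(flipped, datapoints.BoundingBoxes)
assert flipped.format == boxes.format
```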
@@ -91,13 +91,13 @@ class RandomErasing(_RandomApplyTransform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         img_c, img_h, img_w = query_chw(flat_inputs)
...
@@ -358,13 +358,13 @@ class FiveCrop(Transform):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.five_crop, inpt, self.size)
@@ -405,13 +405,13 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip

-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
...
@@ -30,8 +30,8 @@ class Transform(nn.Module):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
         return kernel(inpt, *args, **kwargs)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
...
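To illustrate the `_call_kernel` hook renamed above: a custom transform can route any input type through the registry-backed lookup without branching on the input itself. A minimal sketch, assuming `Transform` is importable from `torchvision.transforms.v2`; `MyFlip` is a hypothetical name, not part of torchvision.

```python
from typing import Any, Dict

from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F


class MyFlip(v2.Transform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # _call_kernel resolves the kernel for type(inpt) via _get_kernel and the
        # registry, so this one line handles tensors, PIL images and datapoints.
        return self._call_kernel(F.horizontal_flip, inpt)
```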
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
 ) -> torch.Tensor:
-    # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
+    # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor
     # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
     # default error that would be thrown if `new_format` had no default value.
...
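The "kernel / functional hybrid" comment above refers to the two call patterns below, sketched here under the assumption of this commit's v2 API: plain tensors must pass `old_format` explicitly, while `BoundingBoxes` inputs carry it as metadata.

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

xyxy = torch.tensor([[0, 0, 10, 10]])

# Kernel-like call on a plain tensor: the caller supplies old_format explicitly.
xywh = F.convert_format_bounding_boxes(
    xyxy,
    old_format=datapoints.BoundingBoxFormat.XYXY,
    new_format=datapoints.BoundingBoxFormat.XYWH,
)

# Functional-like call on a datapoint: old_format is read from the input's metadata.
boxes = datapoints.BoundingBoxes(xyxy, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(32, 32))
converted = F.convert_format_bounding_boxes(boxes, new_format=datapoints.BoundingBoxFormat.XYWH)
```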
@@ -12,7 +12,7 @@ def is_simple_tensor(inpt: Any) -> bool:
     return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)

-# {dispatcher: {input_type: type_specific_kernel}}
+# {functional: {input_type: type_specific_kernel}}
 _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
@@ -27,10 +27,10 @@ def _kernel_datapoint_wrapper(kernel):
     return wrapper

-def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
-    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
+    registry = _KERNEL_REGISTRY.setdefault(functional, {})
     if input_type in registry:
-        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")
+        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

     def decorator(kernel):
         registry[input_type] = (
@@ -43,14 +43,14 @@ def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True)
     return decorator

-def _name_to_dispatcher(name):
+def _name_to_functional(name):
     import torchvision.transforms.v2.functional  # noqa
     try:
         return getattr(torchvision.transforms.v2.functional, name)
     except AttributeError:
         raise ValueError(
-            f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
+            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
         ) from None
@@ -59,21 +59,21 @@ _BUILTIN_DATAPOINT_TYPES = {
 }

-def register_kernel(dispatcher, datapoint_cls):
-    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
+def register_kernel(functional, datapoint_cls):
+    """Decorate a kernel to register it for a functional and a (custom) datapoint type.

     See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
     details.
     """
-    if isinstance(dispatcher, str):
-        dispatcher = _name_to_dispatcher(name=dispatcher)
+    if isinstance(functional, str):
+        functional = _name_to_functional(name=functional)
     elif not (
-        callable(dispatcher)
-        and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
+        callable(functional)
+        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
     ):
         raise ValueError(
-            f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
-            f"but got {dispatcher}."
+            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
+            f"but got {functional}."
         )

     if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
@@ -85,13 +85,13 @@ def register_kernel(dispatcher, datapoint_cls):
     if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
         raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")

-    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
+    return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)

-def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
-    registry = _KERNEL_REGISTRY.get(dispatcher)
+def _get_kernel(functional, input_type, *, allow_passthrough=False):
+    registry = _KERNEL_REGISTRY.get(functional)
     if not registry:
-        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
+        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

     # In case we have an exact type match, we take a shortcut.
     if input_type in registry:
@@ -113,17 +113,17 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
         return lambda inpt, *args, **kwargs: inpt

     raise TypeError(
-        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
+        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
         f"but got {input_type} instead."
     )

 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
-# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
-    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+# We could get rid of this by letting _register_kernel_internal take arbitrary functionals rather than wrap_kernel: bool
+def _register_five_ten_crop_kernel_internal(functional, input_type):
+    registry = _KERNEL_REGISTRY.setdefault(functional, {})
     if input_type in registry:
-        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
+        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")

     def wrap(kernel):
         @functools.wraps(kernel)
...
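Putting the renamed registry helpers together end to end: registration by name goes through `_name_to_functional`, the kernel lands in `_KERNEL_REGISTRY`, and `_get_kernel` resolves it when the functional is called. A hedged sketch using only the public surface this file backs; `CustomMask` and `resize_custom_mask` are illustrative names.

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class CustomMask(datapoints.Datapoint):
    pass


@F.register_kernel("resize", CustomMask)  # the string is resolved by _name_to_functional
def resize_custom_mask(inpt, size, **kwargs):
    # Sketch only: a real kernel would actually resize; returning the input is
    # enough to show that F.resize dispatches here for CustomMask inputs.
    return inpt


out = F.resize(CustomMask(torch.rand(3, 8, 8)), size=[4, 4])
assert type(out) is CustomMask
```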