Unverified Commit 9ebf10af authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow register_kernel() to take dispatcher name as input (#7796)

parent f3c89cc6
...@@ -2181,3 +2181,36 @@ class TestShapeGetters: ...@@ -2181,3 +2181,36 @@ class TestShapeGetters:
with pytest.raises(TypeError, match=re.escape(str(type(input)))): with pytest.raises(TypeError, match=re.escape(str(type(input)))):
dispatcher(input) dispatcher(input)
class TestRegisterKernel:
@pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
def test_register_kernel(self, dispatcher):
class CustomDatapoint(datapoints.Datapoint):
pass
kernel_was_called = False
@F.register_kernel(dispatcher, CustomDatapoint)
def new_resize(dp, *args, **kwargs):
nonlocal kernel_was_called
kernel_was_called = True
return dp
t = transforms.Resize(size=(224, 224), antialias=True)
my_dp = CustomDatapoint(torch.rand(3, 10, 10))
out = t(my_dp)
assert out is my_dp
assert kernel_was_called
# Sanity check to make sure we didn't override the kernel of other types
t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
def test_bad_disaptcher_name(self):
class CustomDatapoint(datapoints.Datapoint):
pass
with pytest.raises(ValueError, match="Could not find dispatcher with name"):
F.register_kernel("bad_name", CustomDatapoint)
...@@ -37,7 +37,18 @@ def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=Tr ...@@ -37,7 +37,18 @@ def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=Tr
return decorator return decorator
def _name_to_dispatcher(name):
import torchvision.transforms.v2.functional # noqa
try:
return getattr(torchvision.transforms.v2.functional, name)
except AttributeError:
raise ValueError(f"Could not find dispatcher with name '{name}'.") from None
def register_kernel(dispatcher, datapoint_cls): def register_kernel(dispatcher, datapoint_cls):
if isinstance(dispatcher, str):
dispatcher = _name_to_dispatcher(name=dispatcher)
return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment