Unverified Commit e0e6f7e2 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow dispatch to PIL image subclasses (#7835)

parent c1592f96
...@@ -3,6 +3,7 @@ import decimal ...@@ -3,6 +3,7 @@ import decimal
import inspect import inspect
import math import math
import re import re
from pathlib import Path
from unittest import mock from unittest import mock
import numpy as np import numpy as np
...@@ -2126,16 +2127,10 @@ class TestGetKernel: ...@@ -2126,16 +2127,10 @@ class TestGetKernel:
datapoints.Video: F.resize_video, datapoints.Video: F.resize_video,
} }
def test_unsupported_types(self): @pytest.mark.parametrize("input_type", [str, int, object])
class MyTensor(torch.Tensor): def test_unsupported_types(self, input_type):
pass with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, input_type)
class MyPILImage(PIL.Image.Image):
pass
for input_type in [str, int, object, MyTensor, MyPILImage]:
with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, input_type)
def test_exact_match(self): def test_exact_match(self):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
...@@ -2197,6 +2192,24 @@ class TestGetKernel: ...@@ -2197,6 +2192,24 @@ class TestGetKernel:
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
def test_pil_image_subclass(self):
opened_image = PIL.Image.open(Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")
loaded_image = opened_image.convert("RGB")
# check the assumptions
assert isinstance(opened_image, PIL.Image.Image)
assert type(opened_image) is not PIL.Image.Image
assert type(loaded_image) is PIL.Image.Image
size = [17, 11]
for image in [opened_image, loaded_image]:
kernel = _get_kernel(F.resize, type(image))
output = kernel(image, size=size)
assert F.get_size(output) == size
class TestPermuteChannels: class TestPermuteChannels:
_DEFAULT_PERMUTATION = [2, 0, 1] _DEFAULT_PERMUTATION = [2, 0, 1]
......
...@@ -100,21 +100,14 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False): ...@@ -100,21 +100,14 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
if not registry: if not registry:
raise ValueError(f"No kernel registered for functional {functional.__name__}.") raise ValueError(f"No kernel registered for functional {functional.__name__}.")
# In case we have an exact type match, we take a shortcut. for cls in input_type.__mro__:
if input_type in registry: if cls in registry:
return registry[input_type] return registry[cls]
elif cls is datapoints.Datapoint:
# In case of datapoints, we check if we have a kernel for a superclass registered # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
if issubclass(input_type, datapoints.Datapoint): # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# Since we have already checked for an exact match above, we can start the traversal at the superclass. # allow kernels to be registered for datapoints.Datapoint anyway.
for cls in input_type.__mro__[1:]: break
if cls is datapoints.Datapoint:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway.
break
elif cls in registry:
return registry[cls]
if allow_passthrough: if allow_passthrough:
return lambda inpt, *args, **kwargs: inpt return lambda inpt, *args, **kwargs: inpt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment