Unverified Commit e0e6f7e2 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

allow dispatch to PIL image subclasses (#7835)

parent c1592f96
......@@ -3,6 +3,7 @@ import decimal
import inspect
import math
import re
from pathlib import Path
from unittest import mock
import numpy as np
......@@ -2126,14 +2127,8 @@ class TestGetKernel:
datapoints.Video: F.resize_video,
}
def test_unsupported_types(self):
class MyTensor(torch.Tensor):
pass
class MyPILImage(PIL.Image.Image):
pass
for input_type in [str, int, object, MyTensor, MyPILImage]:
@pytest.mark.parametrize("input_type", [str, int, object])
def test_unsupported_types(self, input_type):
with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, input_type)
......@@ -2197,6 +2192,24 @@ class TestGetKernel:
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
def test_pil_image_subclass(self):
opened_image = PIL.Image.open(Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")
loaded_image = opened_image.convert("RGB")
# check the assumptions
assert isinstance(opened_image, PIL.Image.Image)
assert type(opened_image) is not PIL.Image.Image
assert type(loaded_image) is PIL.Image.Image
size = [17, 11]
for image in [opened_image, loaded_image]:
kernel = _get_kernel(F.resize, type(image))
output = kernel(image, size=size)
assert F.get_size(output) == size
class TestPermuteChannels:
_DEFAULT_PERMUTATION = [2, 0, 1]
......
......@@ -100,21 +100,14 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
if not registry:
raise ValueError(f"No kernel registered for functional {functional.__name__}.")
# In case we have an exact type match, we take a shortcut.
if input_type in registry:
return registry[input_type]
# In case of datapoints, we check if we have a kernel for a superclass registered
if issubclass(input_type, datapoints.Datapoint):
# Since we have already checked for an exact match above, we can start the traversal at the superclass.
for cls in input_type.__mro__[1:]:
if cls is datapoints.Datapoint:
for cls in input_type.__mro__:
if cls in registry:
return registry[cls]
elif cls is datapoints.Datapoint:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway.
break
elif cls in registry:
return registry[cls]
if allow_passthrough:
return lambda inpt, *args, **kwargs: inpt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment