Unverified Commit 4a99bae8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add dispatch tests for prototype transform dispatchers (#6631)

parent 0e006a9f
import collections.abc
import dataclasses
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Type
from typing import Callable, Dict, List, Optional, Sequence, Type
import pytest
import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip
from prototype_common_utils import BoundingBoxLoader
from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip
from torchvision.prototype import features
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}
KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
@dataclasses.dataclass
class PILKernelInfo:
kernel: Callable
kernel_name: str = dataclasses.field(default=None)
def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
......@@ -28,21 +40,35 @@ def skip_integer_size_jit(name="size"):
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None)
pil_kernel_info: Optional[PILKernelInfo] = None
method_name: str = dataclasses.field(default=None)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
def __post_init__(self):
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
self.method_name = self.method_name or self.dispatcher.__name__
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)
def sample_inputs(self, *types):
for type in types or self.kernels.keys():
if type not in self.kernels:
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernels.keys():
if feature_type not in self.kernels:
raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}")
sample_inputs = self.kernel_infos[feature_type].sample_inputs_fn()
if not filter_metadata:
yield from sample_inputs
else:
for args_kwargs in sample_inputs:
for attribute in feature_type.__annotations__.keys():
if attribute in args_kwargs.kwargs:
del args_kwargs.kwargs[attribute]
yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()
yield args_kwargs
def maybe_skip(self, *, test_name, args_kwargs, device):
skips = self._skips_map.get(test_name)
......@@ -54,6 +80,31 @@ class DispatcherInfo:
pytest.skip(skip.reason)
def fill_sequence_needs_broadcast(args_kwargs, device):
(image_loader, *_), kwargs = args_kwargs
try:
fill = kwargs["fill"]
except KeyError:
return False
if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
return False
return image_loader.num_channels > 1
skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip(
"test_dispatch_pil",
condition=fill_sequence_needs_broadcast,
reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.",
)
skip_dispatch_feature = Skip(
"test_dispatch_feature",
reason="Dispatcher doesn't support arbitrary feature dispatch.",
)
DISPATCHER_INFOS = [
DispatcherInfo(
F.horizontal_flip,
......@@ -62,6 +113,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.horizontal_flip_bounding_box,
features.Mask: F.horizontal_flip_mask,
},
pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
),
DispatcherInfo(
F.resize,
......@@ -70,6 +122,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.resize_bounding_box,
features.Mask: F.resize_mask,
},
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
skips=[
skip_integer_size_jit(),
],
......@@ -81,7 +134,11 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.affine_bounding_box,
features.Mask: F.affine_mask,
},
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
],
),
DispatcherInfo(
F.vertical_flip,
......@@ -90,6 +147,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.vertical_flip_bounding_box,
features.Mask: F.vertical_flip_mask,
},
pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
),
DispatcherInfo(
F.rotate,
......@@ -98,6 +156,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.rotate_bounding_box,
features.Mask: F.rotate_mask,
},
pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
),
DispatcherInfo(
F.crop,
......@@ -106,6 +165,17 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.crop_bounding_box,
features.Mask: F.crop_mask,
},
pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
skips=[
Skip(
"test_dispatch_feature",
condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader),
reason=(
"F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
"since that is sufficient for the kernel."
),
)
],
),
DispatcherInfo(
F.resized_crop,
......@@ -114,6 +184,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.resized_crop_bounding_box,
features.Mask: F.resized_crop_mask,
},
pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
),
DispatcherInfo(
F.pad,
......@@ -122,6 +193,10 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.pad_bounding_box,
features.Mask: F.pad_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
),
DispatcherInfo(
F.perspective,
......@@ -130,6 +205,10 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.perspective_bounding_box,
features.Mask: F.perspective_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
),
DispatcherInfo(
F.elastic,
......@@ -138,6 +217,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.elastic_bounding_box,
features.Mask: F.elastic_mask,
},
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
),
DispatcherInfo(
F.center_crop,
......@@ -146,6 +226,7 @@ DISPATCHER_INFOS = [
features.BoundingBox: F.center_crop_bounding_box,
features.Mask: F.center_crop_mask,
},
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
skips=[
skip_integer_size_jit("output_size"),
],
......@@ -155,6 +236,7 @@ DISPATCHER_INFOS = [
kernels={
features.Image: F.gaussian_blur_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
......@@ -165,80 +247,97 @@ DISPATCHER_INFOS = [
kernels={
features.Image: F.equalize_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
),
DispatcherInfo(
F.invert,
kernels={
features.Image: F.invert_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
),
DispatcherInfo(
F.posterize,
kernels={
features.Image: F.posterize_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
),
DispatcherInfo(
F.solarize,
kernels={
features.Image: F.solarize_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
),
DispatcherInfo(
F.autocontrast,
kernels={
features.Image: F.autocontrast_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
),
DispatcherInfo(
F.adjust_sharpness,
kernels={
features.Image: F.adjust_sharpness_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
),
DispatcherInfo(
F.erase,
kernels={
features.Image: F.erase_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.erase_image_pil),
skips=[
skip_dispatch_feature,
],
),
DispatcherInfo(
F.adjust_brightness,
kernels={
features.Image: F.adjust_brightness_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"),
),
DispatcherInfo(
F.adjust_contrast,
kernels={
features.Image: F.adjust_contrast_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
),
DispatcherInfo(
F.adjust_gamma,
kernels={
features.Image: F.adjust_gamma_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
),
DispatcherInfo(
F.adjust_hue,
kernels={
features.Image: F.adjust_hue_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
),
DispatcherInfo(
F.adjust_saturation,
kernels={
features.Image: F.adjust_saturation_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
),
DispatcherInfo(
F.five_crop,
kernels={
features.Image: F.five_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
skips=[
skip_integer_size_jit(),
skip_dispatch_feature,
],
),
DispatcherInfo(
......@@ -246,8 +345,10 @@ DISPATCHER_INFOS = [
kernels={
features.Image: F.ten_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
skips=[
skip_integer_size_jit(),
skip_dispatch_feature,
],
),
DispatcherInfo(
......@@ -255,5 +356,8 @@ DISPATCHER_INFOS = [
kernels={
features.Image: F.normalize_image_tensor,
},
skips=[
skip_dispatch_feature,
],
),
]
......@@ -33,7 +33,7 @@ class KernelInfo:
sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]]
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name: Optional[str] = None
kernel_name: str = dataclasses.field(default=None)
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
# inside the function. It should return a tensor or to be more precise an object that can be compared to a
......
......@@ -174,6 +174,18 @@ class TestKernels:
assert_close(actual, expected, check_dtype=False, **info.closeness_kwargs)
@pytest.fixture
def spy_on(mocker):
def make_spy(fn, *, module=None, name=None):
# TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
module = module or fn.__module__
name = name or fn.__name__
spy = mocker.patch(f"{module}.{name}", wraps=fn)
return spy
return make_spy
class TestDispatchers:
@pytest.mark.parametrize(
("info", "args_kwargs"),
......@@ -211,6 +223,69 @@ class TestDispatchers:
def test_scriptable(self, dispatcher):
script(dispatcher)
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
if features.Image in info.kernels
],
)
def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = torch.Tensor(image_feature)
kernel_info = info.kernel_infos[features.Image]
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.kernel_name)
info.dispatcher(image_simple_tensor, *other_args, **kwargs)
spy.assert_called_once()
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs(features.Image))
if features.Image in info.kernels and info.pil_kernel_info is not None
],
)
def test_dispatch_pil(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load()
if image_feature.ndim > 3:
pytest.skip("Input is batched")
image_pil = F.to_image_pil(image_feature)
pil_kernel_info = info.pil_kernel_info
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.kernel_name)
info.dispatcher(image_pil, *other_args, **kwargs)
spy.assert_called_once()
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}")
for info in DISPATCHER_INFOS
for idx, args_kwargs in enumerate(info.sample_inputs())
],
)
def test_dispatch_feature(self, info, args_kwargs, spy_on):
(feature, *other_args), kwargs = args_kwargs.load()
method = getattr(feature, info.method_name)
feature_type = type(feature)
spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{info.method_name}")
info.dispatcher(feature, *other_args, **kwargs)
spy.assert_called_once()
@pytest.mark.parametrize(
("alias", "target"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment