Unverified commit 74ea933c authored by Philip Meier, committed by GitHub

Cleanup prototype transforms tests (#6984)

* Minor cleanup of the prototype transforms tests

* Refactor ImagePair: replace the agg_method and allowed_percentage_diff options with a single mae flag

* Pretty-format enums in parametrized error messages
parent 4df1a85c
@@ -2,6 +2,7 @@
 import collections.abc
 import dataclasses
+import enum
 import functools
 import pathlib
 from collections import defaultdict
@@ -53,45 +54,31 @@ class ImagePair(TensorLikePair):
         actual,
         expected,
         *,
-        agg_method=None,
-        allowed_percentage_diff=None,
+        mae=False,
         **other_parameters,
     ):
         if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
            actual, expected = [to_image_tensor(input) for input in [actual, expected]]

         super().__init__(actual, expected, **other_parameters)
-        self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
-        self.allowed_percentage_diff = allowed_percentage_diff
+        self.mae = mae

     def compare(self) -> None:
         actual, expected = self.actual, self.expected

         self._compare_attributes(actual, expected)
         actual, expected = self._equalize_attributes(actual, expected)
-        actual, expected = self._promote_for_comparison(actual, expected)

-        abs_diff = torch.abs(actual - expected)
-
-        if self.allowed_percentage_diff is not None:
-            percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
-            if percentage_diff > self.allowed_percentage_diff:
+        if self.mae:
+            actual, expected = self._promote_for_comparison(actual, expected)
+            mae = float(torch.abs(actual - expected).float().mean())
+            if mae > self.atol:
                 raise self._make_error_meta(
                     AssertionError,
-                    f"{percentage_diff:.1%} elements differ, "
-                    f"but only {self.allowed_percentage_diff:.1%} is allowed",
+                    f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
                 )
-
-        if self.agg_method is None:
-            super()._compare_values(actual, expected)
         else:
-            agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
-            if agg_abs_diff > self.atol:
-                raise self._make_error_meta(
-                    AssertionError,
-                    f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
-                    f"but only {self.atol} is allowed.",
-                )
+            super()._compare_values(actual, expected)


 def assert_close(
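The mae=True path above reduces the whole comparison to one scalar check: the mean absolute error between the two images must not exceed atol. A minimal standalone sketch of the same computation, using hypothetical tensors outside the test harness:

import torch

def mae_within_tolerance(actual, expected, atol):
    # Promote to float first so the subtraction cannot wrap around for uint8 inputs,
    # then compare the mean absolute error against the absolute tolerance.
    mae = float(torch.abs(actual.float() - expected.float()).mean())
    return mae <= atol

# Two hypothetical uint8 "images" that differ by 2 in every pixel: MAE == 2.0.
a = torch.full((3, 8, 8), 10, dtype=torch.uint8)
b = torch.full((3, 8, 8), 12, dtype=torch.uint8)
assert mae_within_tolerance(a, b, atol=2)
assert not mae_within_tolerance(a, b, atol=1)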
@@ -142,6 +129,8 @@ def parametrized_error_message(*args, **kwargs):
 def to_str(obj):
     if isinstance(obj, torch.Tensor) and obj.numel() > 10:
         return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
+    elif isinstance(obj, enum.Enum):
+        return f"{type(obj).__name__}.{obj.name}"
     else:
         return repr(obj)
@@ -174,11 +163,13 @@ class ArgsKwargs:
         yield self.kwargs

     def load(self, device="cpu"):
-        args = tuple(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args)
-        kwargs = {
-            keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg for keyword, arg in self.kwargs.items()
-        }
-        return args, kwargs
+        return ArgsKwargs(
+            *(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args),
+            **{
+                keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg
+                for keyword, arg in self.kwargs.items()
+            },
+        )


 DEFAULT_SQUARE_SPATIAL_SIZE = 15
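Returning an ArgsKwargs instead of a bare (args, kwargs) tuple keeps existing call sites working because ArgsKwargs.__iter__ yields the args and then the kwargs (the "yield self.kwargs" context above is the tail of that method). A simplified sketch of why the unpacking still works, assuming roughly this shape for the helper:

class ArgsKwargs:  # simplified sketch of the real test helper
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        # Yielding exactly two items is what lets callers keep writing
        # `args, kwargs = args_kwargs.load()` with either return style.
        yield self.args
        yield self.kwargs

args, kwargs = ArgsKwargs(1, 2, flag=True)
print(args, kwargs)  # (1, 2) {'flag': True}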
@@ -52,7 +52,7 @@ class KernelInfo(InfoBase):
     # values to be tested. If not specified, `sample_inputs_fn` will be used.
     reference_inputs_fn=None,
     # If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
-    # the reference inputs. This is usually used whenever we use a PIL kernel as reference.
+    # reference inputs. This is usually used whenever we use a PIL kernel as reference.
     # Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
     # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
     # dtype.
@@ -73,8 +73,8 @@ class KernelInfo(InfoBase):
         self.float32_vs_uint8 = float32_vs_uint8


-def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, agg_method=None):
-    return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, agg_method=agg_method)
+def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
+    return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)


 def cuda_vs_cpu_pixel_difference(atol=1):
@@ -84,21 +84,21 @@ def cuda_vs_cpu_pixel_difference(atol=1):
     }


-def pil_reference_pixel_difference(atol=1, agg_method=None):
+def pil_reference_pixel_difference(atol=1, mae=False):
     return {
         (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): _pixel_difference_closeness_kwargs(
-            atol, agg_method=agg_method
+            atol, mae=mae
         )
     }


-def float32_vs_uint8_pixel_difference(atol=1, agg_method=None):
+def float32_vs_uint8_pixel_difference(atol=1, mae=False):
     return {
         (
             ("TestKernels", "test_float32_vs_uint8"),
             torch.float32,
             "cpu",
-        ): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, agg_method=agg_method)
+        ): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
     }
@@ -359,9 +359,9 @@ KERNEL_INFOS.extend(
             reference_inputs_fn=reference_inputs_resize_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs={
-                **pil_reference_pixel_difference(10, agg_method="mean"),
+                **pil_reference_pixel_difference(10, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
-                **float32_vs_uint8_pixel_difference(1, agg_method="mean"),
+                **float32_vs_uint8_pixel_difference(1, mae=True),
             },
             test_marks=[
                 xfail_jit_python_scalar_arg("size"),
@@ -613,7 +613,7 @@ KERNEL_INFOS.extend(
             reference_fn=pil_reference_wrapper(F.affine_image_pil),
             reference_inputs_fn=reference_inputs_affine_image_tensor,
             float32_vs_uint8=True,
-            closeness_kwargs=pil_reference_pixel_difference(10, agg_method="mean"),
+            closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
             test_marks=[
                 xfail_jit_python_scalar_arg("shear"),
                 xfail_jit_tuple_instead_of_list("fill"),
@@ -869,7 +869,7 @@ KERNEL_INFOS.extend(
             reference_fn=pil_reference_wrapper(F.rotate_image_pil),
             reference_inputs_fn=reference_inputs_rotate_image_tensor,
             float32_vs_uint8=True,
-            closeness_kwargs=pil_reference_pixel_difference(1, agg_method="mean"),
+            closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
             test_marks=[
                 xfail_jit_tuple_instead_of_list("fill"),
                 # TODO: check if this is a regression since it seems that should be supported if `int` is ok
@@ -1054,8 +1054,8 @@ KERNEL_INFOS.extend(
             float32_vs_uint8=True,
             closeness_kwargs={
                 **cuda_vs_cpu_pixel_difference(),
-                **pil_reference_pixel_difference(3, agg_method="mean"),
-                **float32_vs_uint8_pixel_difference(3, agg_method="mean"),
+                **pil_reference_pixel_difference(3, mae=True),
+                **float32_vs_uint8_pixel_difference(3, mae=True),
             },
         ),
         KernelInfo(
@@ -1288,7 +1288,7 @@ KERNEL_INFOS.extend(
             reference_inputs_fn=reference_inputs_perspective_image_tensor,
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
             closeness_kwargs={
-                **pil_reference_pixel_difference(2, agg_method="mean"),
+                **pil_reference_pixel_difference(2, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
                 **float32_vs_uint8_pixel_difference(),
             },
@@ -1371,7 +1371,7 @@ KERNEL_INFOS.extend(
             reference_inputs_fn=reference_inputs_elastic_image_tensor,
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
             closeness_kwargs={
-                **float32_vs_uint8_pixel_difference(6, agg_method="mean"),
+                **float32_vs_uint8_pixel_difference(6, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
             },
         ),
@@ -2028,7 +2028,7 @@ KERNEL_INFOS.extend(
             reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs={
-                **pil_reference_pixel_difference(2, agg_method="mean"),
+                **pil_reference_pixel_difference(2, mae=True),
                 **float32_vs_uint8_pixel_difference(),
             },
         ),
@@ -61,12 +61,7 @@ def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
     ]


-def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
-    if condition is None:
-
-        def condition(info):
-            return True
-
+def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
     def decorator(test_fn):
         parts = test_fn.__qualname__.split(".")
         if len(parts) == 1:
@@ -81,9 +76,6 @@ def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
         argnames = ("info", "args_kwargs")
         argvalues = []
         for info in infos:
-            if not condition(info):
-                continue
-
             argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

         return pytest.mark.parametrize(argnames, argvalues)(test_fn)
@@ -110,9 +102,8 @@ class TestKernels:
         args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
     )
     reference_inputs = make_info_args_kwargs_parametrization(
-        KERNEL_INFOS,
+        [info for info in KERNEL_INFOS if info.reference_fn is not None],
         args_kwargs_fn=lambda info: info.reference_inputs_fn(),
-        condition=lambda info: info.reference_fn is not None,
     )

     @ignore_jit_warning_no_profile
@@ -131,7 +122,7 @@ class TestKernels:
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-            msg=parametrized_error_message(*other_args, *kwargs),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )

     def _unbatch(self, batch, *, data_dims):
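The *kwargs to **kwargs change in this and the following hunks fixes a real bug: a unary * applied to a dict unpacks its keys as extra positional arguments, so the keyword values were silently dropped from the error message. A tiny illustration:

def show(*args, **kwargs):
    return args, kwargs

options = {"size": 5}
show(1, *options)   # ((1, 'size'), {})    -- dict keys leaked as positionals
show(1, **options)  # ((1,), {'size': 5})  -- intended keyword forwarding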
@@ -188,7 +179,7 @@ class TestKernels:
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
-            msg=parametrized_error_message(*other_args, *kwargs),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )

     @sample_inputs
@@ -218,7 +209,7 @@ class TestKernels:
             output_cpu,
             check_device=False,
             **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
-            msg=parametrized_error_message(*other_args, *kwargs),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )

     @sample_inputs
@@ -245,7 +236,7 @@ class TestKernels:
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-            msg=parametrized_error_message(*other_args, *kwargs),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )

     @make_info_args_kwargs_parametrization(
@@ -272,7 +263,7 @@ class TestKernels:
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
-            msg=parametrized_error_message(*other_args, *kwargs),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )
@@ -290,9 +281,8 @@ def spy_on(mocker):
 class TestDispatchers:
     image_sample_inputs = make_info_args_kwargs_parametrization(
-        DISPATCHER_INFOS,
+        [info for info in DISPATCHER_INFOS if features.Image in info.kernels],
         args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
-        condition=lambda info: features.Image in info.kernels,
     )

     @ignore_jit_warning_no_profile
@@ -341,9 +331,8 @@ class TestDispatchers:
         spy.assert_called_once()

     @make_info_args_kwargs_parametrization(
-        DISPATCHER_INFOS,
+        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
         args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
-        condition=lambda info: info.pil_kernel_info is not None,
     )
     def test_dispatch_pil(self, info, args_kwargs, spy_on):
         (image_feature, *other_args), kwargs = args_kwargs.load()