Unverified Commit 74ea933c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Cleanup prototype transforms tests (#6984)

* minor cleanup of the prototype transforms tests

* refactor ImagePair

* pretty format enum
parent 4df1a85c
......@@ -2,6 +2,7 @@
import collections.abc
import dataclasses
import enum
import functools
import pathlib
from collections import defaultdict
......@@ -53,45 +54,31 @@ class ImagePair(TensorLikePair):
actual,
expected,
*,
agg_method=None,
allowed_percentage_diff=None,
mae=False,
**other_parameters,
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = [to_image_tensor(input) for input in [actual, expected]]
super().__init__(actual, expected, **other_parameters)
self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
self.allowed_percentage_diff = allowed_percentage_diff
self.mae = mae
def compare(self) -> None:
actual, expected = self.actual, self.expected
self._compare_attributes(actual, expected)
actual, expected = self._equalize_attributes(actual, expected)
actual, expected = self._promote_for_comparison(actual, expected)
abs_diff = torch.abs(actual - expected)
if self.allowed_percentage_diff is not None:
percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
if percentage_diff > self.allowed_percentage_diff:
if self.mae:
actual, expected = self._promote_for_comparison(actual, expected)
mae = float(torch.abs(actual - expected).float().mean())
if mae > self.atol:
raise self._make_error_meta(
AssertionError,
f"{percentage_diff:.1%} elements differ, "
f"but only {self.allowed_percentage_diff:.1%} is allowed",
f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
)
if self.agg_method is None:
super()._compare_values(actual, expected)
else:
agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
if agg_abs_diff > self.atol:
raise self._make_error_meta(
AssertionError,
f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
f"but only {self.atol} is allowed.",
)
super()._compare_values(actual, expected)
def assert_close(
......@@ -142,6 +129,8 @@ def parametrized_error_message(*args, **kwargs):
def to_str(obj):
if isinstance(obj, torch.Tensor) and obj.numel() > 10:
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
elif isinstance(obj, enum.Enum):
return f"{type(obj).__name__}.{obj.name}"
else:
return repr(obj)
......@@ -174,11 +163,13 @@ class ArgsKwargs:
yield self.kwargs
def load(self, device="cpu"):
args = tuple(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args)
kwargs = {
keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg for keyword, arg in self.kwargs.items()
}
return args, kwargs
return ArgsKwargs(
*(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args),
**{
keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg
for keyword, arg in self.kwargs.items()
},
)
DEFAULT_SQUARE_SPATIAL_SIZE = 15
......
......@@ -52,7 +52,7 @@ class KernelInfo(InfoBase):
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn=None,
# If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
# the reference inputs. This is usually used whenever we use a PIL kernel as reference.
# reference inputs. This is usually used whenever we use a PIL kernel as reference.
# Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype.
......@@ -73,8 +73,8 @@ class KernelInfo(InfoBase):
self.float32_vs_uint8 = float32_vs_uint8
def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, agg_method=None):
return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, agg_method=agg_method)
def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)
def cuda_vs_cpu_pixel_difference(atol=1):
......@@ -84,21 +84,21 @@ def cuda_vs_cpu_pixel_difference(atol=1):
}
def pil_reference_pixel_difference(atol=1, agg_method=None):
def pil_reference_pixel_difference(atol=1, mae=False):
return {
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): _pixel_difference_closeness_kwargs(
atol, agg_method=agg_method
atol, mae=mae
)
}
def float32_vs_uint8_pixel_difference(atol=1, agg_method=None):
def float32_vs_uint8_pixel_difference(atol=1, mae=False):
return {
(
("TestKernels", "test_float32_vs_uint8"),
torch.float32,
"cpu",
): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, agg_method=agg_method)
): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
}
......@@ -359,9 +359,9 @@ KERNEL_INFOS.extend(
reference_inputs_fn=reference_inputs_resize_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(10, agg_method="mean"),
**pil_reference_pixel_difference(10, mae=True),
**cuda_vs_cpu_pixel_difference(),
**float32_vs_uint8_pixel_difference(1, agg_method="mean"),
**float32_vs_uint8_pixel_difference(1, mae=True),
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
......@@ -613,7 +613,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor,
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(10, agg_method="mean"),
closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
test_marks=[
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
......@@ -869,7 +869,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.rotate_image_pil),
reference_inputs_fn=reference_inputs_rotate_image_tensor,
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(1, agg_method="mean"),
closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
......@@ -1054,8 +1054,8 @@ KERNEL_INFOS.extend(
float32_vs_uint8=True,
closeness_kwargs={
**cuda_vs_cpu_pixel_difference(),
**pil_reference_pixel_difference(3, agg_method="mean"),
**float32_vs_uint8_pixel_difference(3, agg_method="mean"),
**pil_reference_pixel_difference(3, mae=True),
**float32_vs_uint8_pixel_difference(3, mae=True),
},
),
KernelInfo(
......@@ -1288,7 +1288,7 @@ KERNEL_INFOS.extend(
reference_inputs_fn=reference_inputs_perspective_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs={
**pil_reference_pixel_difference(2, agg_method="mean"),
**pil_reference_pixel_difference(2, mae=True),
**cuda_vs_cpu_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
......@@ -1371,7 +1371,7 @@ KERNEL_INFOS.extend(
reference_inputs_fn=reference_inputs_elastic_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs={
**float32_vs_uint8_pixel_difference(6, agg_method="mean"),
**float32_vs_uint8_pixel_difference(6, mae=True),
**cuda_vs_cpu_pixel_difference(),
},
),
......@@ -2028,7 +2028,7 @@ KERNEL_INFOS.extend(
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(2, agg_method="mean"),
**pil_reference_pixel_difference(2, mae=True),
**float32_vs_uint8_pixel_difference(),
},
),
......
......@@ -61,12 +61,7 @@ def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
]
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
if condition is None:
def condition(info):
return True
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
def decorator(test_fn):
parts = test_fn.__qualname__.split(".")
if len(parts) == 1:
......@@ -81,9 +76,6 @@ def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=No
argnames = ("info", "args_kwargs")
argvalues = []
for info in infos:
if not condition(info):
continue
argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))
return pytest.mark.parametrize(argnames, argvalues)(test_fn)
......@@ -110,9 +102,8 @@ class TestKernels:
args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
)
reference_inputs = make_info_args_kwargs_parametrization(
KERNEL_INFOS,
[info for info in KERNEL_INFOS if info.reference_fn is not None],
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
condition=lambda info: info.reference_fn is not None,
)
@ignore_jit_warning_no_profile
......@@ -131,7 +122,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
msg=parametrized_error_message(*other_args, **kwargs),
)
def _unbatch(self, batch, *, data_dims):
......@@ -188,7 +179,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
msg=parametrized_error_message(*other_args, *kwargs),
msg=parametrized_error_message(*other_args, **kwargs),
)
@sample_inputs
......@@ -218,7 +209,7 @@ class TestKernels:
output_cpu,
check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
msg=parametrized_error_message(*other_args, *kwargs),
msg=parametrized_error_message(*other_args, **kwargs),
)
@sample_inputs
......@@ -245,7 +236,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
msg=parametrized_error_message(*other_args, **kwargs),
)
@make_info_args_kwargs_parametrization(
......@@ -272,7 +263,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
msg=parametrized_error_message(*other_args, *kwargs),
msg=parametrized_error_message(*other_args, **kwargs),
)
......@@ -290,9 +281,8 @@ def spy_on(mocker):
class TestDispatchers:
image_sample_inputs = make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
[info for info in DISPATCHER_INFOS if features.Image in info.kernels],
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: features.Image in info.kernels,
)
@ignore_jit_warning_no_profile
......@@ -341,9 +331,8 @@ class TestDispatchers:
spy.assert_called_once()
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
condition=lambda info: info.pil_kernel_info is not None,
)
def test_dispatch_pil(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment