Unverified commit 65769ab7, authored by Philip Meier, committed by GitHub

fix prototype transforms tests with set agg_method (#6934)

* fix prototype transforms tests with set agg_method

* use individual tolerances

* refactor PIL reference test

* increase tolerance for elastic_mask

* fix autocontrast tolerances

* increase tolerance for RandomAutocontrast
parent d72e9064
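
For context on the machinery these commits touch: with a set `agg_method`, image comparisons check a summary statistic of the absolute difference against `atol` rather than every element. A minimal standalone sketch of the distinction (toy values, not code from this diff):

import torch

# An elementwise check fails if any single element exceeds atol; an
# aggregated check compares a summary statistic of the absolute
# difference against atol instead.
actual = torch.tensor([0.0, 0.0, 0.5])
expected = torch.tensor([0.0, 0.0, 0.0])

abs_diff = torch.abs(actual - expected)

# Elementwise: the max deviation is 0.5, so atol=0.1 fails.
elementwise_ok = bool((abs_diff <= 0.1).all())       # False

# Aggregated: agg_method="mean" gives ~0.167 (still > 0.1),
# but agg_method="median" passes, since the median deviation is 0.
mean_ok = float(abs_diff.mean()) <= 0.1              # False
median_ok = float(abs_diff.median()) <= 0.1          # True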
In prototype_common_utils.py:

@@ -12,17 +12,9 @@
 import torch
 import torch.testing
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
-from torch.testing._comparison import (
-    assert_equal as _assert_equal,
-    BooleanPair,
-    ErrorMeta,
-    NonePair,
-    NumberPair,
-    TensorLikePair,
-    UnsupportedInputs,
-)
+from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision.prototype.transforms.functional import to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
@@ -54,7 +46,7 @@ __all__ = [
 ]

-class PILImagePair(TensorLikePair):
+class ImagePair(TensorLikePair):
     def __init__(
         self,
         actual,

@@ -64,44 +56,13 @@ class PILImagePair(TensorLikePair):
         allowed_percentage_diff=None,
         **other_parameters,
     ):
-        if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
-            raise UnsupportedInputs()
+        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
+            actual, expected = [to_image_tensor(input) for input in [actual, expected]]
+            # This parameter is ignored to enable checking PIL images against tensor images that are not on the CPU
+            other_parameters["check_device"] = False

         super().__init__(actual, expected, **other_parameters)
         self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
         self.allowed_percentage_diff = allowed_percentage_diff
-    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
-        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
-            for input in [actual, expected]
-        ]
-        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
-        # image to a tensor adds a singleton leading dimension.
-        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
-        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
-        # shape check that will fail if we don't broadcast before.
-        try:
-            actual, expected = torch.broadcast_tensors(actual, expected)
-        except RuntimeError:
-            raise ErrorMeta(
-                AssertionError,
-                f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
-                id=id,
-            ) from None
-
-        return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
-
-    def _equalize_attributes(self, actual, expected):
-        if actual.dtype != expected.dtype:
-            dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_dtype_image_tensor(actual, dtype)
-            expected = convert_dtype_image_tensor(expected, dtype)
-
-        return super()._equalize_attributes(actual, expected)
-
     def compare(self) -> None:
         actual, expected = self.actual, self.expected
@@ -111,16 +72,24 @@ class PILImagePair(TensorLikePair):
         abs_diff = torch.abs(actual - expected)

         if self.allowed_percentage_diff is not None:
-            percentage_diff = (abs_diff != 0).to(torch.float).mean()
+            percentage_diff = float(abs_diff.ne(0).to(torch.float64).mean())
             if percentage_diff > self.allowed_percentage_diff:
-                self._make_error_meta(AssertionError, "percentage mismatch")
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"{percentage_diff:.1%} elements differ, "
+                    f"but only {self.allowed_percentage_diff:.1%} is allowed",
+                )

         if self.agg_method is None:
             super()._compare_values(actual, expected)
         else:
-            err = self.agg_method(abs_diff.to(torch.float64))
-            if err > self.atol:
-                self._make_error_meta(AssertionError, "aggregated mismatch")
+            agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
+            if agg_abs_diff > self.atol:
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
+                    f"but only {self.atol} is allowed.",
+                )

 def assert_close(
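
Note the actual bug fixed in this hunk: `_make_error_meta` only constructs the error object, it does not raise it. The old code therefore discarded both the percentage and the aggregated mismatch, so comparisons with a set `agg_method` could never fail; adding `raise` (together with the more informative messages) is what makes these tolerances take effect.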
@@ -148,7 +117,7 @@ def assert_close(
             NonePair,
             BooleanPair,
             NumberPair,
-            PILImagePair,
+            ImagePair,
             TensorLikePair,
         ),
         allow_subclasses=allow_subclasses,

@@ -167,6 +136,32 @@ def assert_close(
 assert_equal = functools.partial(assert_close, rtol=0, atol=0)
+
+
+def parametrized_error_message(*args, **kwargs):
+    def to_str(obj):
+        if isinstance(obj, torch.Tensor) and obj.numel() > 10:
+            return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
+        else:
+            return repr(obj)
+
+    if args or kwargs:
+        postfix = "\n".join(
+            [
+                "",
+                "Failure happened for the following parameters:",
+                "",
+                *[to_str(arg) for arg in args],
+                *[f"{name}={to_str(kwarg)}" for name, kwarg in kwargs.items()],
+            ]
+        )
+    else:
+        postfix = ""
+
+    def wrapper(msg):
+        return msg + postfix
+
+    return wrapper

 class ArgsKwargs:
     def __init__(self, *args, **kwargs):
         self.args = args
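
`torch.testing.assert_close` accepts `msg` either as a string or as a callable that receives the generated mismatch report and returns the final message; `parametrized_error_message` builds such a callable. A toy illustration of that mechanism (only `torch` is assumed; `append_context` and its text are made up):

import torch

def append_context(msg):
    # Receives the report generated by assert_close and appends extra context.
    return msg + "\n\nFailure happened for the following parameters:\n\nkernel_size=3"

try:
    torch.testing.assert_close(
        torch.tensor([1.0, 2.0]),
        torch.tensor([1.0, 2.5]),
        msg=append_context,
    )
except AssertionError as e:
    print(e)  # the mismatch report, followed by the appended parameter context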
@@ -656,6 +651,13 @@ class InfoBase:
         ]

     def get_closeness_kwargs(self, test_id, *, dtype, device):
+        if not (isinstance(test_id, tuple) and len(test_id) == 2):
+            msg = "`test_id` should be a `Tuple[Optional[str], str]` denoting the test class and function name"
+            if callable(test_id):
+                msg += ". Did you forget to add the `test_id` fixture to parameters of the test?"
+            else:
+                msg += f", but got {test_id} instead."
+            raise pytest.UsageError(msg)
         if isinstance(device, torch.device):
             device = device.type
         return self.closeness_kwargs.get((test_id, dtype, device), dict())
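
Closeness kwargs are looked up by the triple `(test_id, dtype, device)`, with `test_id` being `(test_class_name, test_function_name)` as the new validation enforces. An illustrative sketch of such a mapping (names and tolerance values are made up):

import torch

# Illustrative closeness_kwargs mapping as consumed by get_closeness_kwargs:
# keys are ((test_class, test_function), dtype, device_type) triples.
closeness_kwargs = {
    (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1, rtol=0),
    (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(rtol=1e-6, atol=1e-5),
}

test_id = ("TestKernels", "test_against_reference")
kwargs = closeness_kwargs.get((test_id, torch.uint8, "cpu"), dict())
print(kwargs)  # {'atol': 1, 'rtol': 0}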
(The diff of one changed file is collapsed in the original view and not shown.)
In the prototype/legacy transforms consistency tests (second file in the diff):

@@ -244,16 +244,19 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs(p=1, threshold=0.99),
         ],
     ),
-    ConsistencyConfig(
-        prototype_transforms.RandomAutocontrast,
-        legacy_transforms.RandomAutocontrast,
-        [
-            ArgsKwargs(p=0),
-            ArgsKwargs(p=1),
-        ],
-        # Use default tolerances of `torch.testing.assert_close`
-        closeness_kwargs=dict(rtol=None, atol=None),
-    ),
+    *[
+        ConsistencyConfig(
+            prototype_transforms.RandomAutocontrast,
+            legacy_transforms.RandomAutocontrast,
+            [
+                ArgsKwargs(p=0),
+                ArgsKwargs(p=1),
+            ],
+            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
+            closeness_kwargs=ckw,
+        )
+        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
+    ],
     ConsistencyConfig(
         prototype_transforms.RandomAdjustSharpness,
         legacy_transforms.RandomAdjustSharpness,
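
The starred comprehension registers two configs, one per dtype: uint8 images are compared with `atol=1, rtol=0` (presumably to absorb off-by-one rounding on integer images, per the "increase tolerance for RandomAutocontrast" commit), while float32 keeps the `torch.testing.assert_close` defaults. A runnable toy demo of the `*[...]` unpacking pattern itself (the tuples stand in for `ConsistencyConfig` objects):

import torch

# Expand one template into several entries of a literal list via "*[...]".
configs = [
    "fixed-entry",
    *[
        ("RandomAutocontrast", dt, ckw)
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
]
print(configs)
# ['fixed-entry',
#  ('RandomAutocontrast', torch.uint8, {'atol': 1, 'rtol': 0}),
#  ('RandomAutocontrast', torch.float32, {'rtol': None, 'atol': None})]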
@@ -1007,7 +1010,7 @@ class TestRefSegTransforms:
             dp = (conv_fn(feature_image), feature_mask)
             dp_ref = (
-                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
+                to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
                 to_image_pil(feature_mask),
             )

@@ -1021,12 +1024,16 @@ class TestRefSegTransforms:
         for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
             self.set_seed()
-            output = t(dp)
+            actual = actual_image, actual_mask = t(dp)

             self.set_seed()
-            expected_output = t_ref(*dp_ref)
+            expected_image, expected_mask = t_ref(*dp_ref)
+            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
+                expected_image = legacy_F.pil_to_tensor(expected_image)
+                expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
+            expected = (expected_image, expected_mask)

-            assert_equal(output, expected_output)
+            assert_equal(actual, expected)

     @pytest.mark.parametrize(
         ("t_ref", "t", "data_kwargs"),
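
In the reference test above, PIL outputs are converted back to tensors before `assert_equal`. The mask needs the extra `.squeeze(0)` because `pil_to_tensor` always returns a channel-first 1 x H x W tensor, while the segmentation masks being compared are 2D. A standalone illustration:

import PIL.Image
import torch
from torchvision.transforms.functional import pil_to_tensor

mask_pil = PIL.Image.new("L", (4, 3))  # 4x3 single-channel PIL image
mask = pil_to_tensor(mask_pil)         # C x H x W with a singleton channel
print(mask.shape)                      # torch.Size([1, 3, 4])
print(mask.squeeze(0).shape)           # torch.Size([3, 4]), matching a 2D mask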
In the functional kernel tests (third file in the diff):

@@ -11,7 +11,7 @@
 import pytest
 import torch
 from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
-from prototype_common_utils import assert_close, make_bounding_boxes, make_image
+from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
 from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
 from prototype_transforms_kernel_infos import KERNEL_INFOS
 from torch.utils._pytree import tree_map

@@ -22,6 +22,10 @@ from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box
 from torchvision.transforms.functional import _get_perspective_coeffs

+KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
+DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}

 @cache
 def script(fn):
     try:
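
The two new lookup maps replace linear `next(info for info in ...)` scans further down in the file. A toy, self-contained sketch of the pattern (`FakeKernelInfo` and the kernels are made up for illustration; the real objects are `KernelInfo` instances):

from dataclasses import dataclass

@dataclass
class FakeKernelInfo:  # stand-in for the real KernelInfo
    kernel: object

def resize_kernel(x): return x
def crop_kernel(x): return x

KERNEL_INFOS = [FakeKernelInfo(resize_kernel), FakeKernelInfo(crop_kernel)]

# Before: a linear scan on every lookup.
info = next(info for info in KERNEL_INFOS if info.kernel is resize_kernel)

# After: build the map once, then index in O(1).
KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
assert KERNEL_INFOS_MAP[resize_kernel] is info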
@@ -127,6 +131,7 @@ class TestKernels:
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )

     def _unbatch(self, batch, *, data_dims):

@@ -183,6 +188,7 @@ class TestKernels:
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )

     @sample_inputs

@@ -212,6 +218,7 @@ class TestKernels:
             output_cpu,
             check_device=False,
             **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )

     @sample_inputs

@@ -237,8 +244,35 @@ class TestKernels:
         assert_close(
             actual,
             expected,
-            check_dtype=False,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
+            msg=parametrized_error_message(*other_args, **kwargs),
         )
+
+    @make_info_args_kwargs_parametrization(
+        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
+        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+    )
+    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
+        (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+        if input.dtype != torch.uint8:
+            pytest.skip(f"Input dtype is {input.dtype}.")
+
+        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
+
+        actual = info.kernel(
+            F.convert_dtype_image_tensor(input, dtype=torch.float32),
+            *adapted_other_args,
+            **adapted_kwargs,
+        )
+
+        expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)
+
+        assert_close(
+            actual,
+            expected,
+            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
+            msg=parametrized_error_message(*other_args, **kwargs),
+        )
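
The new `test_float32_vs_uint8` encodes a commutativity property: running a kernel on a float32-converted image should match converting the kernel's uint8 output to float32, within the registered tolerances. A self-contained sketch of the property with a toy inversion kernel (none of this is the real `KernelInfo` machinery):

import torch

def to_float32(image_u8):
    # uint8 -> float32 in [0, 1], mirroring what a dtype-converting helper does
    return image_u8.to(torch.float32) / 255.0

def invert_kernel_u8(image_u8):
    return 255 - image_u8

def invert_kernel_f32(image_f32):
    return 1.0 - image_f32

image = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)

actual = invert_kernel_f32(to_float32(image))   # kernel on the converted input
expected = to_float32(invert_kernel_u8(image))  # convert the uint8 result instead

torch.testing.assert_close(actual, expected)    # holds exactly for inversion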
@@ -421,12 +455,12 @@ def test_alias(alias, target):
 @pytest.mark.parametrize(
     ("info", "args_kwargs"),
     make_info_args_kwargs_params(
-        next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
+        KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
         args_kwargs_fn=lambda info: info.sample_inputs_fn(),
     ),
 )
 @pytest.mark.parametrize("device", cpu_and_gpu())
-def test_dtype_and_device_convert_image_dtype(info, args_kwargs, device):
+def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
     (input, *other_args), kwargs = args_kwargs.load(device)
     dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)