Unverified Commit 64b755a8 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Fix test_random_autocontrast flakyness (#3699)



* fix test

* more robust test

* flake8
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 94e19193
......@@ -370,12 +370,22 @@ class TransformsTester(unittest.TestCase):
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean"):
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
self.assertTrue(
(tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
)
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
self.assertTrue(
err < tol,
......
......@@ -20,7 +20,7 @@ class Tester(TransformsTester):
def setUp(self):
self.device = "cpu"
def _test_functional_op(self, func, fn_kwargs):
def _test_functional_op(self, func, fn_kwargs, test_exact_match=True, **match_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
......@@ -28,7 +28,10 @@ class Tester(TransformsTester):
tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
torch.manual_seed(12)
......@@ -80,9 +83,9 @@ class Tester(TransformsTester):
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_{}.pt".format(method)))
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_op(func, fn_kwargs)
self._test_class_op(method, meth_kwargs)
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
self._test_functional_op(func, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
self._test_class_op(method, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
def test_random_horizontal_flip(self):
self._test_op('hflip', 'RandomHorizontalFlip')
......@@ -112,7 +115,10 @@ class Tester(TransformsTester):
)
def test_random_autocontrast(self):
self._test_op('autocontrast', 'RandomAutocontrast')
# We check the max abs difference because on some (very rare) pixels, the actual value may be different
# between PIL and tensors due to floating approximations.
self._test_op('autocontrast', 'RandomAutocontrast', test_exact_match=False, agg_method='max',
tol=(1 + 1e-5), allowed_percentage_diff=.05)
def test_random_equalize(self):
self._test_op('equalize', 'RandomEqualize')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment