Unverified Commit 3f556e20 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in common_utils.py (#3873)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 0bb9b914
...@@ -20,6 +20,8 @@ from _utils_internal import get_relative_path ...@@ -20,6 +20,8 @@ from _utils_internal import get_relative_path
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from _assert_utils import assert_equal
IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9 IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367" PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG) PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
...@@ -139,7 +141,8 @@ class TestCase(unittest.TestCase): ...@@ -139,7 +141,8 @@ class TestCase(unittest.TestCase):
raise RuntimeError("The output for {}, is larger than 50kb".format(filename)) raise RuntimeError("The output for {}, is larger than 50kb".format(filename))
else: else:
expected = torch.load(expected_file) expected = torch.load(expected_file)
self.assertEqual(output, expected, prec=prec) rtol = atol = prec or self.precision
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
def assertEqual(self, x, y, prec=None, message='', allow_inf=False): def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
""" """
...@@ -345,7 +348,7 @@ class TransformsTester(unittest.TestCase): ...@@ -345,7 +348,7 @@ class TransformsTester(unittest.TestCase):
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))) pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None: if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor) msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg) assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None): allowed_percentage_diff=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment