Unverified Commit 47382673 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Simplify the assertExpected method (#2965)

* Simplify the ACCEPT=True logic in assertExpected().

* Separate the expected filename estimation from assertExpected
parent 32e57007
...@@ -88,24 +88,7 @@ def is_iterable(obj): ...@@ -88,24 +88,7 @@ def is_iterable(obj):
class TestCase(unittest.TestCase): class TestCase(unittest.TestCase):
precision = 1e-5 precision = 1e-5
def assertExpected(self, output, subname=None, prec=None, strip_suffix=None): def _get_expected_file(self, subname=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.
If you call this multiple times in a single function, you must
give a unique subname each time.
strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
def remove_prefix_suffix(text, prefix, suffix): def remove_prefix_suffix(text, prefix, suffix):
if text.startswith(prefix): if text.startswith(prefix):
text = text[len(prefix):] text = text[len(prefix):]
...@@ -128,33 +111,41 @@ class TestCase(unittest.TestCase): ...@@ -128,33 +111,41 @@ class TestCase(unittest.TestCase):
subname_output = " ({})".format(subname) subname_output = " ({})".format(subname)
expected_file += "_expect.pkl" expected_file += "_expect.pkl"
def accept_output(update_type): if not ACCEPT and not os.path.exists(expected_file):
print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, output)) raise RuntimeError(
("No expect file exists for {}{}; to accept the current output, run:\n"
"python {} {} --accept").format(munged_id, subname_output, __main__.__file__, munged_id))
return expected_file
def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.
If you call this multiple times in a single function, you must
give a unique subname each time.
strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
expected_file = self._get_expected_file(subname, strip_suffix)
if ACCEPT:
print("Accepting updated output for {}:\n\n{}".format(os.path.basename(expected_file), output))
torch.save(output, expected_file) torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file) binary_size = os.path.getsize(expected_file)
self.assertTrue(binary_size <= MAX_PICKLE_SIZE) self.assertTrue(binary_size <= MAX_PICKLE_SIZE)
try:
expected = torch.load(expected_file)
except IOError as e:
if e.errno != errno.ENOENT:
raise
elif ACCEPT:
accept_output("output")
return
else:
raise RuntimeError(
("I got this output for {}{}:\n\n{}\n\n"
"No expect file exists; to accept the current output, run:\n"
"python {} {} --accept").format(munged_id, subname_output, output, __main__.__file__, munged_id))
if ACCEPT:
try:
self.assertEqual(output, expected, prec=prec)
except Exception:
accept_output("updated output")
else: else:
expected = torch.load(expected_file)
self.assertEqual(output, expected, prec=prec) self.assertEqual(output, expected, prec=prec)
def assertEqual(self, x, y, prec=None, message='', allow_inf=False): def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment