Commit aadbed6f authored by F-G Fernandez's avatar F-G Fernandez Committed by Francisco Massa
Browse files

test: Updated asserts in test_utils (#1499)

* test: Updated asserts in test_utils

Updated all raw asserts to corresponding unittest.TestCase.assert. See #1483

* style: Fixed lint check
parent 5eee0117
......@@ -16,13 +16,13 @@ class Tester(unittest.TestCase):
t_clone = t.clone()
utils.make_grid(t, normalize=False)
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=False)
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=True)
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
def test_normalize_in_make_grid(self):
t = torch.rand(5, 3, 10, 10) * 255
......@@ -38,22 +38,22 @@ class Tester(unittest.TestCase):
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
assert torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1'
assert torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0'
self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_save_image(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The image is not present after save'
self.assertTrue(os.path.exists(f.name), 'The image is not present after save')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_save_image_single_pixel(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save'
self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_save_image_file_object(self):
......@@ -64,7 +64,8 @@ class Tester(unittest.TestCase):
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Image not stored in file object'
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Image not stored in file object')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_save_image_single_pixel_file_object(self):
......@@ -75,7 +76,8 @@ class Tester(unittest.TestCase):
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object'
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Pixel Image not stored in file object')
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment