You need to sign in or sign up before continuing.
Unverified Commit 31ee79e6 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_utils (#3887)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent b2f188eb
...@@ -9,6 +9,7 @@ import unittest ...@@ -9,6 +9,7 @@ import unittest
from io import BytesIO from io import BytesIO
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
from _assert_utils import assert_equal
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
...@@ -48,13 +49,13 @@ class Tester(unittest.TestCase): ...@@ -48,13 +49,13 @@ class Tester(unittest.TestCase):
t_clone = t.clone() t_clone = t.clone()
utils.make_grid(t, normalize=False) utils.make_grid(t, normalize=False)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place') assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=False) utils.make_grid(t, normalize=True, scale_each=False)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place') assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=True) utils.make_grid(t, normalize=True, scale_each=True)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place') assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
def test_normalize_in_make_grid(self): def test_normalize_in_make_grid(self):
t = torch.rand(5, 3, 10, 10) * 255 t = torch.rand(5, 3, 10, 10) * 255
...@@ -70,8 +71,8 @@ class Tester(unittest.TestCase): ...@@ -70,8 +71,8 @@ class Tester(unittest.TestCase):
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits) rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1') assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0') assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows') @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image(self): def test_save_image(self):
...@@ -96,8 +97,7 @@ class Tester(unittest.TestCase): ...@@ -96,8 +97,7 @@ class Tester(unittest.TestCase):
fp = BytesIO() fp = BytesIO()
utils.save_image(t, fp, format='png') utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp) img_bytes = Image.open(fp)
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
'Image not stored in file object')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows') @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel_file_object(self): def test_save_image_single_pixel_file_object(self):
...@@ -108,8 +108,7 @@ class Tester(unittest.TestCase): ...@@ -108,8 +108,7 @@ class Tester(unittest.TestCase):
fp = BytesIO() fp = BytesIO()
utils.save_image(t, fp, format='png') utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp) img_bytes = Image.open(fp)
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
'Pixel Image not stored in file object')
def test_draw_boxes(self): def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8) img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
...@@ -127,11 +126,11 @@ class Tester(unittest.TestCase): ...@@ -127,11 +126,11 @@ class Tester(unittest.TestCase):
if PILLOW_VERSION >= (8, 2): if PILLOW_VERSION >= (8, 2):
# The reference image is only valid for new PIL versions # The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected)) assert_equal(result, expected)
# Check if modification is not in place # Check if modification is not in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item()) assert_equal(boxes, boxes_cp)
self.assertTrue(torch.all(torch.eq(img, img_cp)).item()) assert_equal(img, img_cp)
def test_draw_boxes_vanilla(self): def test_draw_boxes_vanilla(self):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
...@@ -145,10 +144,10 @@ class Tester(unittest.TestCase): ...@@ -145,10 +144,10 @@ class Tester(unittest.TestCase):
res.save(path) res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected)) assert_equal(result, expected)
# Check if modification is not in place # Check if modification is not in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item()) assert_equal(boxes, boxes_cp)
self.assertTrue(torch.all(torch.eq(img, img_cp)).item()) assert_equal(img, img_cp)
def test_draw_invalid_boxes(self): def test_draw_invalid_boxes(self):
img_tp = ((1, 1, 1), (1, 2, 3)) img_tp = ((1, 1, 1), (1, 2, 3))
...@@ -187,7 +186,7 @@ def test_draw_segmentation_masks(colors, alpha): ...@@ -187,7 +186,7 @@ def test_draw_segmentation_masks(colors, alpha):
# Make sure the image didn't change where there's no mask # Make sure the image didn't change where there's no mask
masked_pixels = masks[0] | masks[1] masked_pixels = masks[0] | masks[1]
assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all() assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels])
if colors is None: if colors is None:
colors = utils._generate_color_palette(num_masks) colors = utils._generate_color_palette(num_masks)
...@@ -203,9 +202,8 @@ def test_draw_segmentation_masks(colors, alpha): ...@@ -203,9 +202,8 @@ def test_draw_segmentation_masks(colors, alpha):
elif alpha == 0: elif alpha == 0:
assert (out[:, mask] == img[:, mask]).all() assert (out[:, mask] == img[:, mask]).all()
interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha) interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
max_diff = (out[:, mask] - interpolated_color).abs().max() torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
assert max_diff <= 1
def test_draw_segmentation_masks_errors(): def test_draw_segmentation_masks_errors():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment