Unverified Commit 31ee79e6 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_utils (#3887)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent b2f188eb
......@@ -9,6 +9,7 @@ import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
from _assert_utils import assert_equal
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
......@@ -48,13 +49,13 @@ class Tester(unittest.TestCase):
t_clone = t.clone()
utils.make_grid(t, normalize=False)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=False)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=True)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
def test_normalize_in_make_grid(self):
t = torch.rand(5, 3, 10, 10) * 255
......@@ -70,8 +71,8 @@ class Tester(unittest.TestCase):
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image(self):
......@@ -96,8 +97,7 @@ class Tester(unittest.TestCase):
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Image not stored in file object')
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel_file_object(self):
......@@ -108,8 +108,7 @@ class Tester(unittest.TestCase):
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Pixel Image not stored in file object')
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
......@@ -127,11 +126,11 @@ class Tester(unittest.TestCase):
if PILLOW_VERSION >= (8, 2):
# The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
assert_equal(result, expected)
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)
def test_draw_boxes_vanilla(self):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
......@@ -145,10 +144,10 @@ class Tester(unittest.TestCase):
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
assert_equal(result, expected)
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)
def test_draw_invalid_boxes(self):
img_tp = ((1, 1, 1), (1, 2, 3))
......@@ -187,7 +186,7 @@ def test_draw_segmentation_masks(colors, alpha):
# Make sure the image didn't change where there's no mask
masked_pixels = masks[0] | masks[1]
assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all()
assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels])
if colors is None:
colors = utils._generate_color_palette(num_masks)
......@@ -203,9 +202,8 @@ def test_draw_segmentation_masks(colors, alpha):
elif alpha == 0:
assert (out[:, mask] == img[:, mask]).all()
interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha)
max_diff = (out[:, mask] - interpolated_color).abs().max()
assert max_diff <= 1
interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
def test_draw_segmentation_masks_errors():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment