You need to sign in or sign up before continuing.
Unverified Commit 154283b1 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test/test_utils.py to pytest (#3917)

parent 1b6fe682
...@@ -5,7 +5,7 @@ import sys ...@@ -5,7 +5,7 @@ import sys
import tempfile import tempfile
import torch import torch
import torchvision.utils as utils import torchvision.utils as utils
import unittest
from io import BytesIO from io import BytesIO
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
...@@ -18,122 +18,131 @@ boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], ...@@ -18,122 +18,131 @@ boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
class Tester(unittest.TestCase): def test_make_grid_not_inplace():
t = torch.rand(5, 3, 10, 10)
def test_make_grid_not_inplace(self): t_clone = t.clone()
t = torch.rand(5, 3, 10, 10)
t_clone = t.clone() utils.make_grid(t, normalize=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
utils.make_grid(t, normalize=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') utils.make_grid(t, normalize=True, scale_each=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
def test_normalize_in_make_grid():
def test_normalize_in_make_grid(self): t = torch.rand(5, 3, 10, 10) * 255
t = torch.rand(5, 3, 10, 10) * 255 norm_max = torch.tensor(1.0)
norm_max = torch.tensor(1.0) norm_min = torch.tensor(0.0)
norm_min = torch.tensor(0.0)
grid = utils.make_grid(t, normalize=True)
grid = utils.make_grid(t, normalize=True) grid_max = torch.max(grid)
grid_max = torch.max(grid) grid_min = torch.min(grid)
grid_min = torch.min(grid)
# Rounding the result to one decimal for comparison
# Rounding the result to one decimal for comparison n_digits = 1
n_digits = 1 rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1') assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image(self): def test_save_image():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64) t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name) utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The image is not present after save') assert os.path.exists(f.name), 'The image is not present after save'
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
if PILLOW_VERSION >= (8, 2):
# The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
# Check if modification is not in place
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)
def test_draw_boxes_vanilla(self):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_single_pixel():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save'
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_single_pixel_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
def test_draw_boxes():
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
if PILLOW_VERSION >= (8, 2):
# The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected) assert_equal(result, expected)
# Check if modification is not in place
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)
def test_draw_invalid_boxes(self): # Check if modification is not in place
img_tp = ((1, 1, 1), (1, 2, 3)) assert_equal(boxes, boxes_cp)
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) assert_equal(img, img_cp)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) def test_draw_boxes_vanilla():
self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes) img_cp = img.clone()
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes) boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
# Check if modification is not in place
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)
def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)
@pytest.mark.parametrize('colors', [ @pytest.mark.parametrize('colors', [
...@@ -218,5 +227,5 @@ def test_draw_segmentation_masks_errors(): ...@@ -218,5 +227,5 @@ def test_draw_segmentation_masks_errors():
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment