Unverified Commit 154283b1 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test/test_utils.py to pytest (#3917)

parent 1b6fe682
......@@ -5,7 +5,7 @@ import sys
import tempfile
import torch
import torchvision.utils as utils
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
......@@ -18,9 +18,7 @@ boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
class Tester(unittest.TestCase):
def test_make_grid_not_inplace(self):
def test_make_grid_not_inplace():
t = torch.rand(5, 3, 10, 10)
t_clone = t.clone()
......@@ -33,7 +31,8 @@ class Tester(unittest.TestCase):
utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
def test_normalize_in_make_grid(self):
def test_normalize_in_make_grid():
t = torch.rand(5, 3, 10, 10) * 255
norm_max = torch.tensor(1.0)
norm_min = torch.tensor(0.0)
......@@ -50,22 +49,25 @@ class Tester(unittest.TestCase):
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image(self):
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The image is not present after save')
assert os.path.exists(f.name), 'The image is not present after save'
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel(self):
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_single_pixel():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')
assert os.path.exists(f.name), 'The pixel image is not present after save'
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_file_object(self):
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
......@@ -75,8 +77,9 @@ class Tester(unittest.TestCase):
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel_file_object(self):
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
def test_save_image_single_pixel_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
......@@ -86,7 +89,8 @@ class Tester(unittest.TestCase):
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
def test_draw_boxes(self):
def test_draw_boxes():
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
......@@ -108,7 +112,8 @@ class Tester(unittest.TestCase):
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)
def test_draw_boxes_vanilla(self):
def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
......@@ -125,15 +130,19 @@ class Tester(unittest.TestCase):
assert_equal(boxes, boxes_cp)
assert_equal(img, img_cp)
def test_draw_invalid_boxes(self):
def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)
@pytest.mark.parametrize('colors', [
......@@ -218,5 +227,5 @@ def test_draw_segmentation_masks_errors():
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment