Unverified Commit 1cbcb2b1 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test/test_image.py to pytest (#3930)

parent 4c563846
......@@ -2,7 +2,6 @@ import glob
import io
import os
import sys
import unittest
from pathlib import Path
import pytest
......@@ -54,17 +53,23 @@ def normalize_dimensions(img_pil):
return img_pil
class ImageTester(unittest.TestCase):
def test_decode_jpeg(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
for img_path in get_images(IMAGE_ROOT, ".jpg"):
for pil_mode, mode in conversion:
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
])
@pytest.mark.parametrize('pil_mode, mode', [
(None, ImageReadMode.UNCHANGED),
("L", ImageReadMode.GRAY),
("RGB", ImageReadMode.RGB),
])
def test_decode_jpeg(img_path, pil_mode, mode):
with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
if pil_mode is not None:
if is_cmyk:
# libjpeg does not support the conversion
continue
pytest.xfail("Decoding a CMYK jpeg isn't supported")
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))
if is_cmyk:
......@@ -78,38 +83,54 @@ class ImageTester(unittest.TestCase):
# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
self.assertTrue(abs_mean_diff < 2)
assert abs_mean_diff < 2
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
def test_decode_jpeg_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100,), dtype=torch.float16))
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError, match="Not a JPEG file"):
decode_jpeg(torch.empty((100), dtype=torch.uint8))
def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise)
def test_decode_bad_huffman_images():
# sanity check: make sure we can decode the bad Huffman encoding
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
try:
_ = decode_jpeg(bad_huff)
except RuntimeError:
self.assertTrue(False)
decode_jpeg(bad_huff)
@pytest.mark.parametrize('img_path', [
pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
])
def test_damaged_corrupt_images(img_path):
# Truncated images should raise an exception
truncated_images = glob.glob(
os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
for image_path in truncated_images:
data = read_file(image_path)
with self.assertRaises(RuntimeError):
data = read_file(img_path)
if 'corrupt34' in img_path:
match_message = "Image is incomplete or truncated"
else:
match_message = "Unsupported marker type"
with pytest.raises(RuntimeError, match=match_message):
decode_jpeg(data)
def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
for img_path in get_images(FAKEDATA_DIR, ".png"):
for pil_mode, mode in conversion:
@pytest.mark.parametrize('img_path', [
pytest.param(png_path, id=_get_safe_image_name(png_path))
for png_path in get_images(FAKEDATA_DIR, ".png")
])
@pytest.mark.parametrize('pil_mode, mode', [
(None, ImageReadMode.UNCHANGED),
("L", ImageReadMode.GRAY),
("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB),
("RGBA", ImageReadMode.RGB_ALPHA),
])
def test_decode_png(img_path, pil_mode, mode):
with Image.open(img_path) as img:
if pil_mode is not None:
img = img.convert(pil_mode)
......@@ -119,16 +140,22 @@ class ImageTester(unittest.TestCase):
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
tol = 0 if conversion is None else 1
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
tol = 0 if pil_mode is None else 1
assert img_lpng.allclose(img_pil, atol=tol)
with self.assertRaises(RuntimeError):
def test_decode_png_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_png(torch.empty((), dtype=torch.uint8))
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError, match="Content is not png"):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_encode_png(self):
for img_path in get_images(IMAGE_DIR, '.png'):
@pytest.mark.parametrize('img_path', [
pytest.param(png_path, id=_get_safe_image_name(png_path))
for png_path in get_images(IMAGE_DIR, ".png")
])
def test_encode_png(img_path):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
......@@ -140,27 +167,29 @@ class ImageTester(unittest.TestCase):
assert_equal(img_pil, rec_img)
with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"):
def test_encode_png_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_png(torch.empty((3, 100, 100), dtype=torch.float32))
with self.assertRaisesRegex(
RuntimeError, "Compression level should be between 0 and 9"):
with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
compression_level=-1)
with self.assertRaisesRegex(
RuntimeError, "Compression level should be between 0 and 9"):
with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
compression_level=10)
with self.assertRaisesRegex(
RuntimeError, "The number of channels should be 1 or 3, got: 5"):
with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
def test_write_png(self):
@pytest.mark.parametrize('img_path', [
pytest.param(png_path, id=_get_safe_image_name(png_path))
for png_path in get_images(IMAGE_DIR, ".png")
])
def test_write_png(img_path):
with get_tmp_dir() as d:
for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
......@@ -173,7 +202,8 @@ class ImageTester(unittest.TestCase):
assert_equal(img_pil, saved_image)
def test_read_file(self):
def test_read_file():
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
......@@ -182,14 +212,14 @@ class ImageTester(unittest.TestCase):
data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8)
assert_equal(data, expected)
os.unlink(fpath)
assert_equal(data, expected)
with self.assertRaisesRegex(
RuntimeError, "No such file or directory: 'tst'"):
with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
read_file('tst')
def test_read_file_non_ascii(self):
def test_read_file_non_ascii():
with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
......@@ -198,10 +228,11 @@ class ImageTester(unittest.TestCase):
data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8)
assert_equal(data, expected)
os.unlink(fpath)
assert_equal(data, expected)
def test_write_file(self):
def test_write_file():
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
......@@ -210,10 +241,11 @@ class ImageTester(unittest.TestCase):
with open(fpath, 'rb') as f:
saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath)
assert content == saved_content
def test_write_file_non_ascii(self):
def test_write_file_non_ascii():
with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
......@@ -222,8 +254,8 @@ class ImageTester(unittest.TestCase):
with open(fpath, 'rb') as f:
saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath)
assert content == saved_content
@needs_cuda
......@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase):
def test_decode_jpeg_cuda(mode, img_path, scripted):
if 'cmyk' in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported")
tester = ImageTester()
data = read_file(img_path)
img = decode_image(data, mode=mode)
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
img_nvjpeg = f(data, mode=mode, device='cuda')
# Some difference expected between jpeg implementations
tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2)
assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
@needs_cuda
......@@ -304,7 +336,11 @@ def _collect_if(cond):
@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_encode_jpeg_windows():
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg_windows(img_path):
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from
......@@ -315,7 +351,6 @@ def test_encode_jpeg_windows():
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
......@@ -334,10 +369,13 @@ def test_encode_jpeg_windows():
@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_write_jpeg_windows():
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg_windows(img_path):
# FIXME: Remove this eventually, see test_encode_jpeg_windows
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)
......@@ -408,5 +446,5 @@ def test_write_jpeg(img_path):
assert_equal(torch_bytes, pil_bytes)
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment