Unverified Commit 1cbcb2b1 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test/test_image.py to pytest (#3930)

parent 4c563846
...@@ -2,7 +2,6 @@ import glob ...@@ -2,7 +2,6 @@ import glob
import io import io
import os import os
import sys import sys
import unittest
from pathlib import Path from pathlib import Path
import pytest import pytest
...@@ -54,17 +53,23 @@ def normalize_dimensions(img_pil): ...@@ -54,17 +53,23 @@ def normalize_dimensions(img_pil):
return img_pil return img_pil
class ImageTester(unittest.TestCase): @pytest.mark.parametrize('img_path', [
def test_decode_jpeg(self): pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)] for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
for img_path in get_images(IMAGE_ROOT, ".jpg"): ])
for pil_mode, mode in conversion: @pytest.mark.parametrize('pil_mode, mode', [
(None, ImageReadMode.UNCHANGED),
("L", ImageReadMode.GRAY),
("RGB", ImageReadMode.RGB),
])
def test_decode_jpeg(img_path, pil_mode, mode):
with Image.open(img_path) as img: with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK" is_cmyk = img.mode == "CMYK"
if pil_mode is not None: if pil_mode is not None:
if is_cmyk: if is_cmyk:
# libjpeg does not support the conversion # libjpeg does not support the conversion
continue pytest.xfail("Decoding a CMYK jpeg isn't supported")
img = img.convert(pil_mode) img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img)) img_pil = torch.from_numpy(np.array(img))
if is_cmyk: if is_cmyk:
...@@ -78,38 +83,54 @@ class ImageTester(unittest.TestCase): ...@@ -78,38 +83,54 @@ class ImageTester(unittest.TestCase):
# Permit a small variation on pixel values to account for implementation # Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG. # differences between Pillow and LibJPEG.
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
self.assertTrue(abs_mean_diff < 2) assert abs_mean_diff < 2
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
def test_decode_jpeg_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100,), dtype=torch.float16)) decode_jpeg(torch.empty((100,), dtype=torch.float16))
with self.assertRaises(RuntimeError): with pytest.raises(RuntimeError, match="Not a JPEG file"):
decode_jpeg(torch.empty((100), dtype=torch.uint8)) decode_jpeg(torch.empty((100), dtype=torch.uint8))
def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise) def test_decode_bad_huffman_images():
# sanity check: make sure we can decode the bad Huffman encoding
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
try: decode_jpeg(bad_huff)
_ = decode_jpeg(bad_huff)
except RuntimeError:
self.assertTrue(False)
@pytest.mark.parametrize('img_path', [
pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
])
def test_damaged_corrupt_images(img_path):
# Truncated images should raise an exception # Truncated images should raise an exception
truncated_images = glob.glob( data = read_file(img_path)
os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) if 'corrupt34' in img_path:
for image_path in truncated_images: match_message = "Image is incomplete or truncated"
data = read_file(image_path) else:
with self.assertRaises(RuntimeError): match_message = "Unsupported marker type"
with pytest.raises(RuntimeError, match=match_message):
decode_jpeg(data) decode_jpeg(data)
def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), @pytest.mark.parametrize('img_path', [
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)] pytest.param(png_path, id=_get_safe_image_name(png_path))
for img_path in get_images(FAKEDATA_DIR, ".png"): for png_path in get_images(FAKEDATA_DIR, ".png")
for pil_mode, mode in conversion: ])
@pytest.mark.parametrize('pil_mode, mode', [
(None, ImageReadMode.UNCHANGED),
("L", ImageReadMode.GRAY),
("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB),
("RGBA", ImageReadMode.RGB_ALPHA),
])
def test_decode_png(img_path, pil_mode, mode):
with Image.open(img_path) as img: with Image.open(img_path) as img:
if pil_mode is not None: if pil_mode is not None:
img = img.convert(pil_mode) img = img.convert(pil_mode)
...@@ -119,16 +140,22 @@ class ImageTester(unittest.TestCase): ...@@ -119,16 +140,22 @@ class ImageTester(unittest.TestCase):
data = read_file(img_path) data = read_file(img_path)
img_lpng = decode_image(data, mode=mode) img_lpng = decode_image(data, mode=mode)
tol = 0 if conversion is None else 1 tol = 0 if pil_mode is None else 1
self.assertTrue(img_lpng.allclose(img_pil, atol=tol)) assert img_lpng.allclose(img_pil, atol=tol)
with self.assertRaises(RuntimeError): def test_decode_png_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_png(torch.empty((), dtype=torch.uint8)) decode_png(torch.empty((), dtype=torch.uint8))
with self.assertRaises(RuntimeError): with pytest.raises(RuntimeError, match="Content is not png"):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_encode_png(self):
for img_path in get_images(IMAGE_DIR, '.png'): @pytest.mark.parametrize('img_path', [
pytest.param(png_path, id=_get_safe_image_name(png_path))
for png_path in get_images(IMAGE_DIR, ".png")
])
def test_encode_png(img_path):
pil_image = Image.open(img_path) pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
...@@ -140,27 +167,29 @@ class ImageTester(unittest.TestCase): ...@@ -140,27 +167,29 @@ class ImageTester(unittest.TestCase):
assert_equal(img_pil, rec_img) assert_equal(img_pil, rec_img)
with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"): def test_encode_png_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) encode_png(torch.empty((3, 100, 100), dtype=torch.float32))
with self.assertRaisesRegex( with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
RuntimeError, "Compression level should be between 0 and 9"):
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
compression_level=-1) compression_level=-1)
with self.assertRaisesRegex( with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
RuntimeError, "Compression level should be between 0 and 9"):
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
compression_level=10) compression_level=10)
with self.assertRaisesRegex( with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
RuntimeError, "The number of channels should be 1 or 3, got: 5"):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
def test_write_png(self):
@pytest.mark.parametrize('img_path', [
pytest.param(png_path, id=_get_safe_image_name(png_path))
for png_path in get_images(IMAGE_DIR, ".png")
])
def test_write_png(img_path):
with get_tmp_dir() as d: with get_tmp_dir() as d:
for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path) pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
...@@ -173,7 +202,8 @@ class ImageTester(unittest.TestCase): ...@@ -173,7 +202,8 @@ class ImageTester(unittest.TestCase):
assert_equal(img_pil, saved_image) assert_equal(img_pil, saved_image)
def test_read_file(self):
def test_read_file():
with get_tmp_dir() as d: with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n' fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname) fpath = os.path.join(d, fname)
...@@ -182,14 +212,14 @@ class ImageTester(unittest.TestCase): ...@@ -182,14 +212,14 @@ class ImageTester(unittest.TestCase):
data = read_file(fpath) data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8) expected = torch.tensor(list(content), dtype=torch.uint8)
assert_equal(data, expected)
os.unlink(fpath) os.unlink(fpath)
assert_equal(data, expected)
with self.assertRaisesRegex( with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
RuntimeError, "No such file or directory: 'tst'"):
read_file('tst') read_file('tst')
def test_read_file_non_ascii(self):
def test_read_file_non_ascii():
with get_tmp_dir() as d: with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname) fpath = os.path.join(d, fname)
...@@ -198,10 +228,11 @@ class ImageTester(unittest.TestCase): ...@@ -198,10 +228,11 @@ class ImageTester(unittest.TestCase):
data = read_file(fpath) data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8) expected = torch.tensor(list(content), dtype=torch.uint8)
assert_equal(data, expected)
os.unlink(fpath) os.unlink(fpath)
assert_equal(data, expected)
def test_write_file(self): def test_write_file():
with get_tmp_dir() as d: with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n' fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname) fpath = os.path.join(d, fname)
...@@ -210,10 +241,11 @@ class ImageTester(unittest.TestCase): ...@@ -210,10 +241,11 @@ class ImageTester(unittest.TestCase):
with open(fpath, 'rb') as f: with open(fpath, 'rb') as f:
saved_content = f.read() saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath) os.unlink(fpath)
assert content == saved_content
def test_write_file_non_ascii(self): def test_write_file_non_ascii():
with get_tmp_dir() as d: with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname) fpath = os.path.join(d, fname)
...@@ -222,8 +254,8 @@ class ImageTester(unittest.TestCase): ...@@ -222,8 +254,8 @@ class ImageTester(unittest.TestCase):
with open(fpath, 'rb') as f: with open(fpath, 'rb') as f:
saved_content = f.read() saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath) os.unlink(fpath)
assert content == saved_content
@needs_cuda @needs_cuda
...@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase): ...@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase):
def test_decode_jpeg_cuda(mode, img_path, scripted): def test_decode_jpeg_cuda(mode, img_path, scripted):
if 'cmyk' in img_path: if 'cmyk' in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported") pytest.xfail("Decoding a CMYK jpeg isn't supported")
tester = ImageTester()
data = read_file(img_path) data = read_file(img_path)
img = decode_image(data, mode=mode) img = decode_image(data, mode=mode)
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
img_nvjpeg = f(data, mode=mode, device='cuda') img_nvjpeg = f(data, mode=mode, device='cuda')
# Some difference expected between jpeg implementations # Some difference expected between jpeg implementations
tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2) assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
@needs_cuda @needs_cuda
...@@ -304,7 +336,11 @@ def _collect_if(cond): ...@@ -304,7 +336,11 @@ def _collect_if(cond):
@cpu_only @cpu_only
@_collect_if(cond=IS_WINDOWS) @_collect_if(cond=IS_WINDOWS)
def test_encode_jpeg_windows(): @pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg_windows(img_path):
# This test is *wrong*. # This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from # starts encoding the torchvision version from an image that comes from
...@@ -315,7 +351,6 @@ def test_encode_jpeg_windows(): ...@@ -315,7 +351,6 @@ def test_encode_jpeg_windows():
# these more correct tests fail on windows (probably because of a difference # these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL. # in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this. # FIXME: make the correct tests pass on windows and remove this.
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path) dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write') write_folder = os.path.join(dirname, 'jpeg_write')
...@@ -334,10 +369,13 @@ def test_encode_jpeg_windows(): ...@@ -334,10 +369,13 @@ def test_encode_jpeg_windows():
@cpu_only @cpu_only
@_collect_if(cond=IS_WINDOWS) @_collect_if(cond=IS_WINDOWS)
def test_write_jpeg_windows(): @pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg_windows(img_path):
# FIXME: Remove this eventually, see test_encode_jpeg_windows # FIXME: Remove this eventually, see test_encode_jpeg_windows
with get_tmp_dir() as d: with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path) data = read_file(img_path)
img = decode_jpeg(data) img = decode_jpeg(data)
...@@ -408,5 +446,5 @@ def test_write_jpeg(img_path): ...@@ -408,5 +446,5 @@ def test_write_jpeg(img_path):
assert_equal(torch_bytes, pil_bytes) assert_equal(torch_bytes, pil_bytes)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment