Unverified Commit 1cbcb2b1 authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test/test_image.py to pytest (#3930)

parent 4c563846
...@@ -2,7 +2,6 @@ import glob ...@@ -2,7 +2,6 @@ import glob
import io import io
import os import os
import sys import sys
import unittest
from pathlib import Path from pathlib import Path
import pytest import pytest
...@@ -54,176 +53,209 @@ def normalize_dimensions(img_pil): ...@@ -54,176 +53,209 @@ def normalize_dimensions(img_pil):
return img_pil return img_pil
class ImageTester(unittest.TestCase): @pytest.mark.parametrize('img_path', [
def test_decode_jpeg(self): pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)] for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
for img_path in get_images(IMAGE_ROOT, ".jpg"): ])
for pil_mode, mode in conversion: @pytest.mark.parametrize('pil_mode, mode', [
with Image.open(img_path) as img: (None, ImageReadMode.UNCHANGED),
is_cmyk = img.mode == "CMYK" ("L", ImageReadMode.GRAY),
if pil_mode is not None: ("RGB", ImageReadMode.RGB),
if is_cmyk: ])
# libjpeg does not support the conversion def test_decode_jpeg(img_path, pil_mode, mode):
continue
img = img.convert(pil_mode) with Image.open(img_path) as img:
img_pil = torch.from_numpy(np.array(img)) is_cmyk = img.mode == "CMYK"
if is_cmyk: if pil_mode is not None:
# flip the colors to match libjpeg if is_cmyk:
img_pil = 255 - img_pil # libjpeg does not support the conversion
pytest.xfail("Decoding a CMYK jpeg isn't supported")
img_pil = normalize_dimensions(img_pil) img = img.convert(pil_mode)
data = read_file(img_path) img_pil = torch.from_numpy(np.array(img))
img_ljpeg = decode_image(data, mode=mode) if is_cmyk:
# flip the colors to match libjpeg
# Permit a small variation on pixel values to account for implementation img_pil = 255 - img_pil
# differences between Pillow and LibJPEG.
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() img_pil = normalize_dimensions(img_pil)
self.assertTrue(abs_mean_diff < 2) data = read_file(img_path)
img_ljpeg = decode_image(data, mode=mode)
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) # Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
decode_jpeg(torch.empty((100,), dtype=torch.float16)) assert abs_mean_diff < 2
with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8)) def test_decode_jpeg_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
def test_damaged_images(self): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
# Test image with bad Huffman encoding (should not raise)
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
try: decode_jpeg(torch.empty((100,), dtype=torch.float16))
_ = decode_jpeg(bad_huff)
except RuntimeError: with pytest.raises(RuntimeError, match="Not a JPEG file"):
self.assertTrue(False) decode_jpeg(torch.empty((100), dtype=torch.uint8))
# Truncated images should raise an exception
truncated_images = glob.glob( def test_decode_bad_huffman_images():
os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) # sanity check: make sure we can decode the bad Huffman encoding
for image_path in truncated_images: bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
data = read_file(image_path) decode_jpeg(bad_huff)
with self.assertRaises(RuntimeError):
decode_jpeg(data)
@pytest.mark.parametrize('img_path', [
def test_decode_png(self): pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)] ])
for img_path in get_images(FAKEDATA_DIR, ".png"): def test_damaged_corrupt_images(img_path):
for pil_mode, mode in conversion: # Truncated images should raise an exception
with Image.open(img_path) as img: data = read_file(img_path)
if pil_mode is not None: if 'corrupt34' in img_path:
img = img.convert(pil_mode) match_message = "Image is incomplete or truncated"
img_pil = torch.from_numpy(np.array(img)) else:
match_message = "Unsupported marker type"
img_pil = normalize_dimensions(img_pil) with pytest.raises(RuntimeError, match=match_message):
data = read_file(img_path) decode_jpeg(data)
img_lpng = decode_image(data, mode=mode)
tol = 0 if conversion is None else 1 @pytest.mark.parametrize('img_path', [
self.assertTrue(img_lpng.allclose(img_pil, atol=tol)) pytest.param(png_path, id=_get_safe_image_name(png_path))
for png_path in get_images(FAKEDATA_DIR, ".png")
with self.assertRaises(RuntimeError): ])
decode_png(torch.empty((), dtype=torch.uint8)) @pytest.mark.parametrize('pil_mode, mode', [
with self.assertRaises(RuntimeError): (None, ImageReadMode.UNCHANGED),
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) ("L", ImageReadMode.GRAY),
("LA", ImageReadMode.GRAY_ALPHA),
def test_encode_png(self): ("RGB", ImageReadMode.RGB),
for img_path in get_images(IMAGE_DIR, '.png'): ("RGBA", ImageReadMode.RGB_ALPHA),
pil_image = Image.open(img_path) ])
img_pil = torch.from_numpy(np.array(pil_image)) def test_decode_png(img_path, pil_mode, mode):
img_pil = img_pil.permute(2, 0, 1)
png_buf = encode_png(img_pil, compression_level=6) with Image.open(img_path) as img:
if pil_mode is not None:
rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist()))) img = img.convert(pil_mode)
rec_img = torch.from_numpy(np.array(rec_img)) img_pil = torch.from_numpy(np.array(img))
rec_img = rec_img.permute(2, 0, 1)
img_pil = normalize_dimensions(img_pil)
assert_equal(img_pil, rec_img) data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"): tol = 0 if pil_mode is None else 1
encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) assert img_lpng.allclose(img_pil, atol=tol)
with self.assertRaisesRegex(
RuntimeError, "Compression level should be between 0 and 9"): def test_decode_png_errors():
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
compression_level=-1) decode_png(torch.empty((), dtype=torch.uint8))
with pytest.raises(RuntimeError, match="Content is not png"):
with self.assertRaisesRegex( decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
RuntimeError, "Compression level should be between 0 and 9"):
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
compression_level=10) @pytest.mark.parametrize('img_path', [
pytest.param(png_path, id=_get_safe_image_name(png_path))
with self.assertRaisesRegex( for png_path in get_images(IMAGE_DIR, ".png")
RuntimeError, "The number of channels should be 1 or 3, got: 5"): ])
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) def test_encode_png(img_path):
pil_image = Image.open(img_path)
def test_write_png(self): img_pil = torch.from_numpy(np.array(pil_image))
with get_tmp_dir() as d: img_pil = img_pil.permute(2, 0, 1)
for img_path in get_images(IMAGE_DIR, '.png'): png_buf = encode_png(img_pil, compression_level=6)
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
img_pil = img_pil.permute(2, 0, 1) rec_img = torch.from_numpy(np.array(rec_img))
rec_img = rec_img.permute(2, 0, 1)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(d, '{0}_torch.png'.format(filename)) assert_equal(img_pil, rec_img)
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1) def test_encode_png_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
assert_equal(img_pil, saved_image) encode_png(torch.empty((3, 100, 100), dtype=torch.float32))
def test_read_file(self): with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
with get_tmp_dir() as d: encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
fname, content = 'test1.bin', b'TorchVision\211\n' compression_level=-1)
fpath = os.path.join(d, fname)
with open(fpath, 'wb') as f: with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
f.write(content) encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
compression_level=10)
data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8) with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
assert_equal(data, expected) encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
os.unlink(fpath)
with self.assertRaisesRegex( @pytest.mark.parametrize('img_path', [
RuntimeError, "No such file or directory: 'tst'"): pytest.param(png_path, id=_get_safe_image_name(png_path))
read_file('tst') for png_path in get_images(IMAGE_DIR, ".png")
])
def test_read_file_non_ascii(self): def test_write_png(img_path):
with get_tmp_dir() as d: with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' pil_image = Image.open(img_path)
fpath = os.path.join(d, fname) img_pil = torch.from_numpy(np.array(pil_image))
with open(fpath, 'wb') as f: img_pil = img_pil.permute(2, 0, 1)
f.write(content)
filename, _ = os.path.splitext(os.path.basename(img_path))
data = read_file(fpath) torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
expected = torch.tensor(list(content), dtype=torch.uint8) write_png(img_pil, torch_png, compression_level=6)
assert_equal(data, expected) saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
os.unlink(fpath) saved_image = saved_image.permute(2, 0, 1)
def test_write_file(self): assert_equal(img_pil, saved_image)
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname) def test_read_file():
content_tensor = torch.tensor(list(content), dtype=torch.uint8) with get_tmp_dir() as d:
write_file(fpath, content_tensor) fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
with open(fpath, 'rb') as f: with open(fpath, 'wb') as f:
saved_content = f.read() f.write(content)
self.assertEqual(content, saved_content)
os.unlink(fpath) data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8)
def test_write_file_non_ascii(self): os.unlink(fpath)
with get_tmp_dir() as d: assert_equal(data, expected)
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname) with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
content_tensor = torch.tensor(list(content), dtype=torch.uint8) read_file('tst')
write_file(fpath, content_tensor)
with open(fpath, 'rb') as f: def test_read_file_non_ascii():
saved_content = f.read() with get_tmp_dir() as d:
self.assertEqual(content, saved_content) fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
os.unlink(fpath) fpath = os.path.join(d, fname)
with open(fpath, 'wb') as f:
f.write(content)
data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8)
os.unlink(fpath)
assert_equal(data, expected)
def test_write_file():
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)
with open(fpath, 'rb') as f:
saved_content = f.read()
os.unlink(fpath)
assert content == saved_content
def test_write_file_non_ascii():
with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)
with open(fpath, 'rb') as f:
saved_content = f.read()
os.unlink(fpath)
assert content == saved_content
@needs_cuda @needs_cuda
...@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase): ...@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase):
def test_decode_jpeg_cuda(mode, img_path, scripted): def test_decode_jpeg_cuda(mode, img_path, scripted):
if 'cmyk' in img_path: if 'cmyk' in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported") pytest.xfail("Decoding a CMYK jpeg isn't supported")
tester = ImageTester()
data = read_file(img_path) data = read_file(img_path)
img = decode_image(data, mode=mode) img = decode_image(data, mode=mode)
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
img_nvjpeg = f(data, mode=mode, device='cuda') img_nvjpeg = f(data, mode=mode, device='cuda')
# Some difference expected between jpeg implementations # Some difference expected between jpeg implementations
tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2) assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
@needs_cuda @needs_cuda
...@@ -304,7 +336,11 @@ def _collect_if(cond): ...@@ -304,7 +336,11 @@ def _collect_if(cond):
@cpu_only @cpu_only
@_collect_if(cond=IS_WINDOWS) @_collect_if(cond=IS_WINDOWS)
def test_encode_jpeg_windows(): @pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg_windows(img_path):
# This test is *wrong*. # This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from # starts encoding the torchvision version from an image that comes from
...@@ -315,48 +351,50 @@ def test_encode_jpeg_windows(): ...@@ -315,48 +351,50 @@ def test_encode_jpeg_windows():
# these more correct tests fail on windows (probably because of a difference # these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL. # in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this. # FIXME: make the correct tests pass on windows and remove this.
for img_path in get_images(ENCODE_JPEG, ".jpg"): dirname = os.path.dirname(img_path)
dirname = os.path.dirname(img_path) filename, _ = os.path.splitext(os.path.basename(img_path))
filename, _ = os.path.splitext(os.path.basename(img_path)) write_folder = os.path.join(dirname, 'jpeg_write')
write_folder = os.path.join(dirname, 'jpeg_write') expected_file = os.path.join(
expected_file = os.path.join( write_folder, '{0}_pil.jpg'.format(filename))
write_folder, '{0}_pil.jpg'.format(filename)) img = decode_jpeg(read_file(img_path))
img = decode_jpeg(read_file(img_path))
with open(expected_file, 'rb') as f:
with open(expected_file, 'rb') as f: pil_bytes = f.read()
pil_bytes = f.read() pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) for src_img in [img, img.contiguous()]:
for src_img in [img, img.contiguous()]: # PIL sets jpeg quality to 75 by default
# PIL sets jpeg quality to 75 by default jpeg_bytes = encode_jpeg(src_img, quality=75)
jpeg_bytes = encode_jpeg(src_img, quality=75) assert_equal(jpeg_bytes, pil_bytes)
assert_equal(jpeg_bytes, pil_bytes)
@cpu_only @cpu_only
@_collect_if(cond=IS_WINDOWS) @_collect_if(cond=IS_WINDOWS)
def test_write_jpeg_windows(): @pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg_windows(img_path):
# FIXME: Remove this eventually, see test_encode_jpeg_windows # FIXME: Remove this eventually, see test_encode_jpeg_windows
with get_tmp_dir() as d: with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"): data = read_file(img_path)
data = read_file(img_path) img = decode_jpeg(data)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path) basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join( torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename)) d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join( pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
write_jpeg(img, torch_jpeg, quality=75) write_jpeg(img, torch_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f: with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read() torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f: with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read() pil_bytes = f.read()
assert_equal(torch_bytes, pil_bytes) assert_equal(torch_bytes, pil_bytes)
@cpu_only @cpu_only
...@@ -408,5 +446,5 @@ def test_write_jpeg(img_path): ...@@ -408,5 +446,5 @@ def test_write_jpeg(img_path):
assert_equal(torch_bytes, pil_bytes) assert_equal(torch_bytes, pil_bytes)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment