Unverified Commit c7bcfada authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add torchscript test for io image stuff (#8313)

parent eb815aef
...@@ -79,7 +79,9 @@ def normalize_dimensions(img_pil): ...@@ -79,7 +79,9 @@ def normalize_dimensions(img_pil):
("RGB", ImageReadMode.RGB), ("RGB", ImageReadMode.RGB),
], ],
) )
def test_decode_jpeg(img_path, pil_mode, mode): @pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):
with Image.open(img_path) as img: with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK" is_cmyk = img.mode == "CMYK"
...@@ -92,7 +94,9 @@ def test_decode_jpeg(img_path, pil_mode, mode): ...@@ -92,7 +94,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
img_pil = normalize_dimensions(img_pil) img_pil = normalize_dimensions(img_pil)
data = read_file(img_path) data = read_file(img_path)
img_ljpeg = decode_image(data, mode=mode) if scripted:
decode_fun = torch.jit.script(decode_fun)
img_ljpeg = decode_fun(data, mode=mode)
# Permit a small variation on pixel values to account for implementation # Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG. # differences between Pillow and LibJPEG.
...@@ -188,7 +192,12 @@ def test_damaged_corrupt_images(img_path): ...@@ -188,7 +192,12 @@ def test_damaged_corrupt_images(img_path):
("RGBA", ImageReadMode.RGB_ALPHA), ("RGBA", ImageReadMode.RGB_ALPHA),
], ],
) )
def test_decode_png(img_path, pil_mode, mode): @pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
if scripted:
decode_fun = torch.jit.script(decode_fun)
with Image.open(img_path) as img: with Image.open(img_path) as img:
if pil_mode is not None: if pil_mode is not None:
...@@ -202,7 +211,7 @@ def test_decode_png(img_path, pil_mode, mode): ...@@ -202,7 +211,7 @@ def test_decode_png(img_path, pil_mode, mode):
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"): with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
data = read_file(img_path) data = read_file(img_path)
img_lpng = decode_image(data, mode=mode) img_lpng = decode_fun(data, mode=mode)
img_lpng = _read_png_16(img_path, mode=mode) img_lpng = _read_png_16(img_path, mode=mode)
assert img_lpng.dtype == torch.int32 assert img_lpng.dtype == torch.int32
...@@ -210,7 +219,7 @@ def test_decode_png(img_path, pil_mode, mode): ...@@ -210,7 +219,7 @@ def test_decode_png(img_path, pil_mode, mode):
img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8) img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
else: else:
data = read_file(img_path) data = read_file(img_path)
img_lpng = decode_image(data, mode=mode) img_lpng = decode_fun(data, mode=mode)
tol = 0 if pil_mode is None else 1 tol = 0 if pil_mode is None else 1
...@@ -239,11 +248,13 @@ def test_decode_png_errors(): ...@@ -239,11 +248,13 @@ def test_decode_png_errors():
"img_path", "img_path",
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
) )
def test_encode_png(img_path): @pytest.mark.parametrize("scripted", (True, False))
def test_encode_png(img_path, scripted):
pil_image = Image.open(img_path) pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
png_buf = encode_png(img_pil, compression_level=6) encode = torch.jit.script(encode_png) if scripted else encode_png
png_buf = encode(img_pil, compression_level=6)
rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist()))) rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
rec_img = torch.from_numpy(np.array(rec_img)) rec_img = torch.from_numpy(np.array(rec_img))
...@@ -270,27 +281,39 @@ def test_encode_png_errors(): ...@@ -270,27 +281,39 @@ def test_encode_png_errors():
"img_path", "img_path",
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
) )
def test_write_png(img_path, tmpdir): @pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
pil_image = Image.open(img_path) pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(tmpdir, f"{filename}_torch.png") torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
write_png(img_pil, torch_png, compression_level=6) write = torch.jit.script(write_png) if scripted else write_png
write(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1) saved_image = saved_image.permute(2, 0, 1)
assert_equal(img_pil, saved_image) assert_equal(img_pil, saved_image)
def test_read_file(tmpdir): def test_read_image():
# Just testing torchcsript, the functionality is somewhat tested already in other tests.
path = next(get_images(IMAGE_ROOT, ".jpg"))
out = read_image(path)
out_scripted = torch.jit.script(read_image)(path)
torch.testing.assert_close(out, out_scripted, atol=0, rtol=0)
@pytest.mark.parametrize("scripted", (True, False))
def test_read_file(tmpdir, scripted):
fname, content = "test1.bin", b"TorchVision\211\n" fname, content = "test1.bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname) fpath = os.path.join(tmpdir, fname)
with open(fpath, "wb") as f: with open(fpath, "wb") as f:
f.write(content) f.write(content)
data = read_file(fpath) fun = torch.jit.script(read_file) if scripted else read_file
data = fun(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8) expected = torch.tensor(list(content), dtype=torch.uint8)
os.unlink(fpath) os.unlink(fpath)
assert_equal(data, expected) assert_equal(data, expected)
...@@ -311,11 +334,13 @@ def test_read_file_non_ascii(tmpdir): ...@@ -311,11 +334,13 @@ def test_read_file_non_ascii(tmpdir):
assert_equal(data, expected) assert_equal(data, expected)
def test_write_file(tmpdir): @pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
fname, content = "test1.bin", b"TorchVision\211\n" fname, content = "test1.bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname) fpath = os.path.join(tmpdir, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8) content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor) write = torch.jit.script(write_file) if scripted else write_file
write(fpath, content_tensor)
with open(fpath, "rb") as f: with open(fpath, "rb") as f:
saved_content = f.read() saved_content = f.read()
...@@ -464,7 +489,8 @@ def test_encode_jpeg_errors(): ...@@ -464,7 +489,8 @@ def test_encode_jpeg_errors():
"img_path", "img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
) )
def test_encode_jpeg(img_path): @pytest.mark.parametrize("scripted", (True, False))
def test_encode_jpeg(img_path, scripted):
img = read_image(img_path) img = read_image(img_path)
pil_img = F.to_pil_image(img) pil_img = F.to_pil_image(img)
...@@ -473,8 +499,9 @@ def test_encode_jpeg(img_path): ...@@ -473,8 +499,9 @@ def test_encode_jpeg(img_path):
encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8) encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)
encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
for src_img in [img, img.contiguous()]: for src_img in [img, img.contiguous()]:
encoded_jpeg_torch = encode_jpeg(src_img, quality=75) encoded_jpeg_torch = encode(src_img, quality=75)
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
...@@ -483,7 +510,8 @@ def test_encode_jpeg(img_path): ...@@ -483,7 +510,8 @@ def test_encode_jpeg(img_path):
"img_path", "img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
) )
def test_write_jpeg(img_path, tmpdir): @pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
tmpdir = Path(tmpdir) tmpdir = Path(tmpdir)
img = read_image(img_path) img = read_image(img_path)
pil_img = F.to_pil_image(img) pil_img = F.to_pil_image(img)
...@@ -491,7 +519,8 @@ def test_write_jpeg(img_path, tmpdir): ...@@ -491,7 +519,8 @@ def test_write_jpeg(img_path, tmpdir):
torch_jpeg = str(tmpdir / "torch.jpg") torch_jpeg = str(tmpdir / "torch.jpg")
pil_jpeg = str(tmpdir / "pil.jpg") pil_jpeg = str(tmpdir / "pil.jpg")
write_jpeg(img, torch_jpeg, quality=75) write = torch.jit.script(write_jpeg) if scripted else write_jpeg
write(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75) pil_img.save(pil_jpeg, quality=75)
with open(torch_jpeg, "rb") as f: with open(torch_jpeg, "rb") as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment