Unverified Commit eaddb902 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Fix write and encode jpeg tests (#3908)

parent c58d5d17
import glob import glob
import io import io
import os import os
import sys
import unittest import unittest
from pathlib import Path
import pytest import pytest
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from common_utils import get_tmp_dir, needs_cuda import torchvision.transforms.functional as F
from common_utils import get_tmp_dir, needs_cuda, cpu_only
from _assert_utils import assert_equal from _assert_utils import assert_equal
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode) encode_png, write_png, write_file, ImageReadMode, read_image)
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
IS_WINDOWS = sys.platform in ('win32', 'cygwin')
def _get_safe_image_name(name):
# Used when we need to change the pytest "id" for an "image path" parameter.
# If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
# and this creates issues when the test is running in a different machine than where it was collected
# (typically, in fb internal infra)
return name.split(os.path.sep)[-1]
def get_images(directory, img_ext): def get_images(directory, img_ext):
...@@ -93,72 +105,6 @@ class ImageTester(unittest.TestCase): ...@@ -93,72 +105,6 @@ class ImageTester(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
decode_jpeg(data) decode_jpeg(data)
def test_encode_jpeg(self):
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = decode_jpeg(read_file(img_path))
with open(expected_file, 'rb') as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)
with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))
with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
with self.assertRaisesRegex(
RuntimeError, "The number of channels should be 1 or 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))
with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))
with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def test_write_jpeg(self):
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
write_jpeg(img, torch_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()
self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self): def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)] ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
...@@ -282,11 +228,7 @@ class ImageTester(unittest.TestCase): ...@@ -282,11 +228,7 @@ class ImageTester(unittest.TestCase):
@needs_cuda @needs_cuda
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize('img_path', [
# We need to change the "id" for that parameter. pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
# If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific,
# and this creates issues when the test is running in a different machine than where it was collected
# (typically, in fb internal infra)
pytest.param(jpeg_path, id=jpeg_path.split('/')[-1])
for jpeg_path in get_images(IMAGE_ROOT, ".jpg") for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
]) ])
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) @pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
...@@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors(): ...@@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors():
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')
@cpu_only
def test_encode_jpeg_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))
with pytest.raises(ValueError, match="Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
with pytest.raises(ValueError, match="Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))
with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))
with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def _collect_if(cond):
# TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows
# are removed
def _inner(test_func):
if cond:
return test_func
else:
return pytest.mark.dont_collect(test_func)
return _inner
@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_encode_jpeg_windows():
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from
# decode_jpeg, which can yield different results from pil.decode (see
# test_decode... which uses a high tolerance).
# Instead, we should start encoding from the exact same decoded image, for a
# valid comparison. This is done in test_encode_jpeg, but unfortunately
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = decode_jpeg(read_file(img_path))
with open(expected_file, 'rb') as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)
@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_write_jpeg_windows():
# FIXME: Remove this eventually, see test_encode_jpeg_windows
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
write_jpeg(img, torch_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()
assert_equal(torch_bytes, pil_bytes)
@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg(img_path):
img = read_image(img_path)
pil_img = F.to_pil_image(img)
buf = io.BytesIO()
pil_img.save(buf, format='JPEG', quality=75)
# pytorch can't read from raw bytes so we go through numpy
pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
encoded_jpeg_pil = torch.as_tensor(pil_bytes)
for src_img in [img, img.contiguous()]:
encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg(img_path):
with get_tmp_dir() as d:
d = Path(d)
img = read_image(img_path)
pil_img = F.to_pil_image(img)
torch_jpeg = str(d / 'torch.jpg')
pil_jpeg = str(d / 'pil.jpg')
write_jpeg(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()
assert_equal(torch_bytes, pil_bytes)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment