Unverified Commit 87681314 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Clean up jpeg tests (#7820)

parent bf6a8dc2
...@@ -422,77 +422,6 @@ def test_encode_jpeg_errors(): ...@@ -422,77 +422,6 @@ def test_encode_jpeg_errors():
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def _collect_if(cond):
# TODO: remove this once test_encode_jpeg_reference and test_write_jpeg_reference
# are removed
def _inner(test_func):
if cond:
return test_func
else:
return pytest.mark.dont_collect(test_func)
return _inner
@_collect_if(cond=False)
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg_reference(img_path):
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it
# starts encoding the torchvision version from an image that comes from
# decode_jpeg, which can yield different results from pil.decode (see
# test_decode... which uses a high tolerance).
# Instead, we should start encoding from the exact same decoded image, for a
# valid comparison. This is done in test_encode_jpeg, but unfortunately
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, "jpeg_write")
expected_file = os.path.join(write_folder, f"{filename}_pil.jpg")
img = decode_jpeg(read_file(img_path))
with open(expected_file, "rb") as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)
@_collect_if(cond=False)
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg_reference(img_path, tmpdir):
# FIXME: Remove this eventually, see test_encode_jpeg_reference
data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(tmpdir, f"{filename}_torch.jpg")
pil_jpeg = os.path.join(basedir, "jpeg_write", f"{filename}_pil.jpg")
write_jpeg(img, torch_jpeg, quality=75)
with open(torch_jpeg, "rb") as f:
torch_bytes = f.read()
with open(pil_jpeg, "rb") as f:
pil_bytes = f.read()
assert_equal(torch_bytes, pil_bytes)
# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"img_path", "img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
...@@ -511,8 +440,6 @@ def test_encode_jpeg(img_path): ...@@ -511,8 +440,6 @@ def test_encode_jpeg(img_path):
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"img_path", "img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment