Unverified Commit 8e244797 authored by Milos's avatar Milos Committed by GitHub
Browse files

Refactors test_image.py so tests don't write files to assets folder (#3018)

* Fix writing to files by using get_tmp_dir()

* Add ImageReadMode to imports

* Fix failing test due to incorrect image path
parent 4ab46e5f
import os
import io
import glob
import io
import os
import unittest
import numpy as np
import torch
from PIL import Image
from common_utils import get_tmp_dir
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode)
import numpy as np
from common_utils import get_tmp_dir
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
......@@ -22,14 +21,10 @@ ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
def get_images(directory, img_ext):
assert os.path.isdir(directory)
for root, _, files in os.walk(directory):
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}:
continue
for fl in files:
_, ext = os.path.splitext(fl)
if ext == img_ext:
yield os.path.join(root, fl)
image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
for path in image_paths:
if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']:
yield path
def pil_read_image(img_path):
......@@ -75,7 +70,7 @@ class ImageTester(unittest.TestCase):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100, ), dtype=torch.float16))
decode_jpeg(torch.empty((100,), dtype=torch.float16))
with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8))
......@@ -119,12 +114,12 @@ class ImageTester(unittest.TestCase):
with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
with self.assertRaisesRegex(
......@@ -140,27 +135,27 @@ class ImageTester(unittest.TestCase):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def test_write_jpeg(self):
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
basedir, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
write_jpeg(img, torch_jpeg, quality=75)
write_jpeg(img, torch_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()
with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()
with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()
os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes)
self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
......@@ -216,20 +211,19 @@ class ImageTester(unittest.TestCase):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
def test_write_png(self):
for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
os.remove(torch_png)
saved_image = saved_image.permute(2, 0, 1)
self.assertTrue(img_pil.equal(saved_image))
with get_tmp_dir() as d:
for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1)
self.assertTrue(img_pil.equal(saved_image))
def test_read_file(self):
with get_tmp_dir() as d:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment