Unverified Commit 8e244797 authored by Milos's avatar Milos Committed by GitHub
Browse files

Refactors test_image.py so tests don't write files to assets folder (#3018)

* Fix writing to files by using get_tmp_dir()

* Add ImageReadMode to imports

* Fix failing test due to incorrect image path
parent 4ab46e5f
import os
import io
import glob
import io
import os
import unittest
import numpy as np
import torch
from PIL import Image
from common_utils import get_tmp_dir
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode)
import numpy as np
from common_utils import get_tmp_dir
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
......@@ -22,14 +21,10 @@ ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
def get_images(directory, img_ext):
assert os.path.isdir(directory)
for root, _, files in os.walk(directory):
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}:
continue
for fl in files:
_, ext = os.path.splitext(fl)
if ext == img_ext:
yield os.path.join(root, fl)
image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
for path in image_paths:
if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']:
yield path
def pil_read_image(img_path):
......@@ -75,7 +70,7 @@ class ImageTester(unittest.TestCase):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100, ), dtype=torch.float16))
decode_jpeg(torch.empty((100,), dtype=torch.float16))
with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8))
......@@ -140,6 +135,7 @@ class ImageTester(unittest.TestCase):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def test_write_jpeg(self):
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)
......@@ -147,7 +143,7 @@ class ImageTester(unittest.TestCase):
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
basedir, '{0}_torch.jpg'.format(filename))
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
......@@ -159,7 +155,6 @@ class ImageTester(unittest.TestCase):
with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()
os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self):
......@@ -216,17 +211,16 @@ class ImageTester(unittest.TestCase):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
def test_write_png(self):
with get_tmp_dir() as d:
for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename))
torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
os.remove(torch_png)
saved_image = saved_image.permute(2, 0, 1)
self.assertTrue(img_pil.equal(saved_image))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment