Unverified Commit 8e244797 authored by Milos's avatar Milos Committed by GitHub
Browse files

Refactors test_image.py so tests don't write files to assets folder (#3018)

* Fix writing to files by using get_tmp_dir()

* Add ImageReadMode to imports

* Fix failing test due to incorrect image path
parent 4ab46e5f
import os
import io
import glob import glob
import io
import os
import unittest import unittest
import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from common_utils import get_tmp_dir
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode) encode_png, write_png, write_file, ImageReadMode)
import numpy as np
from common_utils import get_tmp_dir
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
...@@ -22,14 +21,10 @@ ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") ...@@ -22,14 +21,10 @@ ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
def get_images(directory, img_ext): def get_images(directory, img_ext):
assert os.path.isdir(directory) assert os.path.isdir(directory)
for root, _, files in os.walk(directory): image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}: for path in image_paths:
continue if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']:
yield path
for fl in files:
_, ext = os.path.splitext(fl)
if ext == img_ext:
yield os.path.join(root, fl)
def pil_read_image(img_path): def pil_read_image(img_path):
...@@ -75,7 +70,7 @@ class ImageTester(unittest.TestCase): ...@@ -75,7 +70,7 @@ class ImageTester(unittest.TestCase):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100, ), dtype=torch.float16)) decode_jpeg(torch.empty((100,), dtype=torch.float16))
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8)) decode_jpeg(torch.empty((100), dtype=torch.uint8))
...@@ -140,6 +135,7 @@ class ImageTester(unittest.TestCase): ...@@ -140,6 +135,7 @@ class ImageTester(unittest.TestCase):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def test_write_jpeg(self): def test_write_jpeg(self):
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"): for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path) data = read_file(img_path)
img = decode_jpeg(data) img = decode_jpeg(data)
...@@ -147,7 +143,7 @@ class ImageTester(unittest.TestCase): ...@@ -147,7 +143,7 @@ class ImageTester(unittest.TestCase):
basedir = os.path.dirname(img_path) basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join( torch_jpeg = os.path.join(
basedir, '{0}_torch.jpg'.format(filename)) d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join( pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
...@@ -159,7 +155,6 @@ class ImageTester(unittest.TestCase): ...@@ -159,7 +155,6 @@ class ImageTester(unittest.TestCase):
with open(pil_jpeg, 'rb') as f: with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read() pil_bytes = f.read()
os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes) self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self): def test_decode_png(self):
...@@ -216,17 +211,16 @@ class ImageTester(unittest.TestCase): ...@@ -216,17 +211,16 @@ class ImageTester(unittest.TestCase):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
def test_write_png(self): def test_write_png(self):
with get_tmp_dir() as d:
for img_path in get_images(IMAGE_DIR, '.png'): for img_path in get_images(IMAGE_DIR, '.png'):
pil_image = Image.open(img_path) pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename)) torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6) write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
os.remove(torch_png)
saved_image = saved_image.permute(2, 0, 1) saved_image = saved_image.permute(2, 0, 1)
self.assertTrue(img_pil.equal(saved_image)) self.assertTrue(img_pil.equal(saved_image))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment