Unverified Commit 8e244797 authored by Milos's avatar Milos Committed by GitHub
Browse files

Refactors test_image.py so tests don't write files to assets folder (#3018)

* Fix writing to files by using get_tmp_dir()

* Add ImageReadMode to imports

* Fix failing test due to incorrect image path
parent 4ab46e5f
import os
import io
import glob import glob
import io
import os
import unittest import unittest
import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from common_utils import get_tmp_dir
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode) encode_png, write_png, write_file, ImageReadMode)
import numpy as np
from common_utils import get_tmp_dir
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
...@@ -22,14 +21,10 @@ ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") ...@@ -22,14 +21,10 @@ ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
def get_images(directory, img_ext): def get_images(directory, img_ext):
assert os.path.isdir(directory) assert os.path.isdir(directory)
for root, _, files in os.walk(directory): image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}: for path in image_paths:
continue if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']:
yield path
for fl in files:
_, ext = os.path.splitext(fl)
if ext == img_ext:
yield os.path.join(root, fl)
def pil_read_image(img_path): def pil_read_image(img_path):
...@@ -75,7 +70,7 @@ class ImageTester(unittest.TestCase): ...@@ -75,7 +70,7 @@ class ImageTester(unittest.TestCase):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100, ), dtype=torch.float16)) decode_jpeg(torch.empty((100,), dtype=torch.float16))
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8)) decode_jpeg(torch.empty((100), dtype=torch.uint8))
...@@ -119,12 +114,12 @@ class ImageTester(unittest.TestCase): ...@@ -119,12 +114,12 @@ class ImageTester(unittest.TestCase):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number " ValueError, "Image quality should be a positive number "
"between 1 and 100"): "between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number " ValueError, "Image quality should be a positive number "
"between 1 and 100"): "between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
with self.assertRaisesRegex( with self.assertRaisesRegex(
...@@ -140,27 +135,27 @@ class ImageTester(unittest.TestCase): ...@@ -140,27 +135,27 @@ class ImageTester(unittest.TestCase):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def test_write_jpeg(self): def test_write_jpeg(self):
for img_path in get_images(ENCODE_JPEG, ".jpg"): with get_tmp_dir() as d:
data = read_file(img_path) for img_path in get_images(ENCODE_JPEG, ".jpg"):
img = decode_jpeg(data) data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path) basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join( torch_jpeg = os.path.join(
basedir, '{0}_torch.jpg'.format(filename)) d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join( pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
write_jpeg(img, torch_jpeg, quality=75) write_jpeg(img, torch_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f: with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read() torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f: with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read() pil_bytes = f.read()
os.remove(torch_jpeg) self.assertEqual(torch_bytes, pil_bytes)
self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self): def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
...@@ -216,20 +211,19 @@ class ImageTester(unittest.TestCase): ...@@ -216,20 +211,19 @@ class ImageTester(unittest.TestCase):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
def test_write_png(self): def test_write_png(self):
for img_path in get_images(IMAGE_DIR, '.png'): with get_tmp_dir() as d:
pil_image = Image.open(img_path) for img_path in get_images(IMAGE_DIR, '.png'):
img_pil = torch.from_numpy(np.array(pil_image)) pil_image = Image.open(img_path)
img_pil = img_pil.permute(2, 0, 1) img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename)) torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6) write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
os.remove(torch_png) saved_image = saved_image.permute(2, 0, 1)
saved_image = saved_image.permute(2, 0, 1)
self.assertTrue(img_pil.equal(saved_image))
self.assertTrue(img_pil.equal(saved_image))
def test_read_file(self): def test_read_file(self):
with get_tmp_dir() as d: with get_tmp_dir() as d:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment