Unverified Commit e402d43f authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Minor test refactorings (#1011)

* Make tests work on fbcode

* Lint

* Fix rebase error

* Properly use get_file_path_2

* Fix wrong use of get_file_path_2 again

* Missing import
parent 67bfb967
import os
import shutil
import tempfile
import contextlib
@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
tmp_dir = tempfile.mkdtemp(**kwargs)
if src is not None:
os.rmdir(tmp_dir)
shutil.copytree(src, tmp_dir)
try:
yield tmp_dir
finally:
shutil.rmtree(tmp_dir)
import os import os
import sys import sys
import shutil
import contextlib import contextlib
import tempfile
import unittest import unittest
import mock import mock
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
import torchvision import torchvision
from torch._utils_internal import get_file_path_2
PYTHON2 = sys.version_info[0] == 2 PYTHON2 = sys.version_info[0] == 2
if PYTHON2: if PYTHON2:
...@@ -16,20 +15,10 @@ if PYTHON2: ...@@ -16,20 +15,10 @@ if PYTHON2:
else: else:
import pickle import pickle
FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), from common_utils import get_tmp_dir
'assets', 'fakedata')
FAKEDATA_DIR = get_file_path_2(
@contextlib.contextmanager os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')
def tmp_dir(src=None, **kwargs):
tmp_dir = tempfile.mkdtemp(**kwargs)
if src is not None:
os.rmdir(tmp_dir)
shutil.copytree(src, tmp_dir)
try:
yield tmp_dir
finally:
shutil.rmtree(tmp_dir)
@contextlib.contextmanager @contextlib.contextmanager
...@@ -54,17 +43,14 @@ def get_mnist_data(num_images, cls_name, **kwargs): ...@@ -54,17 +43,14 @@ def get_mnist_data(num_images, cls_name, **kwargs):
f.write(_encode(num_images)) f.write(_encode(num_images))
f.write(labels.numpy().tobytes()) f.write(labels.numpy().tobytes())
tmp_dir = tempfile.mkdtemp(**kwargs) with get_tmp_dir() as tmp_dir:
raw_dir = os.path.join(tmp_dir, cls_name, "raw") raw_dir = os.path.join(tmp_dir, cls_name, "raw")
os.makedirs(raw_dir) os.makedirs(raw_dir)
_make_image_file(os.path.join(raw_dir, "train-images-idx3-ubyte"), num_images) _make_image_file(os.path.join(raw_dir, "train-images-idx3-ubyte"), num_images)
_make_label_file(os.path.join(raw_dir, "train-labels-idx1-ubyte"), num_images) _make_label_file(os.path.join(raw_dir, "train-labels-idx1-ubyte"), num_images)
_make_image_file(os.path.join(raw_dir, "t10k-images-idx3-ubyte"), num_images) _make_image_file(os.path.join(raw_dir, "t10k-images-idx3-ubyte"), num_images)
_make_label_file(os.path.join(raw_dir, "t10k-labels-idx1-ubyte"), num_images) _make_label_file(os.path.join(raw_dir, "t10k-labels-idx1-ubyte"), num_images)
try:
yield tmp_dir yield tmp_dir
finally:
shutil.rmtree(tmp_dir)
@contextlib.contextmanager @contextlib.contextmanager
...@@ -109,7 +95,7 @@ def cifar_root(version): ...@@ -109,7 +95,7 @@ def cifar_root(version):
_make_pickled_file(obj, file) _make_pickled_file(obj, file)
params = _get_version_params(version) params = _get_version_params(version)
with tmp_dir() as root: with get_tmp_dir() as root:
base_folder = os.path.join(root, params['base_folder']) base_folder = os.path.join(root, params['base_folder'])
os.mkdir(base_folder) os.mkdir(base_folder)
...@@ -124,7 +110,7 @@ def cifar_root(version): ...@@ -124,7 +110,7 @@ def cifar_root(version):
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def test_imagefolder(self): def test_imagefolder(self):
with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
classes = sorted(['a', 'b']) classes = sorted(['a', 'b'])
class_a_image_files = [os.path.join(root, 'a', file) class_a_image_files = [os.path.join(root, 'a', file)
for file in ('a1.png', 'a2.png', 'a3.png')] for file in ('a1.png', 'a2.png', 'a3.png')]
...@@ -200,7 +186,7 @@ class Tester(unittest.TestCase): ...@@ -200,7 +186,7 @@ class Tester(unittest.TestCase):
@mock.patch('torchvision.datasets.utils.download_url') @mock.patch('torchvision.datasets.utils.download_url')
def test_imagenet(self, mock_download): def test_imagenet(self, mock_download):
with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root: with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root:
dataset = torchvision.datasets.ImageNet(root, split='train', download=True) dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
self.assertEqual(len(dataset), 3) self.assertEqual(len(dataset), 3)
img, target = dataset[0] img, target = dataset[0]
......
import os import os
import shutil import sys
import tempfile import tempfile
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
import unittest import unittest
import zipfile import zipfile
import tarfile import tarfile
import gzip import gzip
import warnings
from torch._utils_internal import get_file_path_2
TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), from common_utils import get_tmp_dir
'assets', 'grace_hopper_517x606.jpg')
if sys.version_info < (3,):
from urllib2 import URLError
else:
from urllib.error import URLError
TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -17,73 +27,78 @@ class Tester(unittest.TestCase): ...@@ -17,73 +27,78 @@ class Tester(unittest.TestCase):
fpath = TEST_FILE fpath = TEST_FILE
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = '' false_md5 = ''
assert utils.check_md5(fpath, correct_md5) self.assertTrue(utils.check_md5(fpath, correct_md5))
assert not utils.check_md5(fpath, false_md5) self.assertFalse(utils.check_md5(fpath, false_md5))
def test_check_integrity(self): def test_check_integrity(self):
existing_fpath = TEST_FILE existing_fpath = TEST_FILE
nonexisting_fpath = '' nonexisting_fpath = ''
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = '' false_md5 = ''
assert utils.check_integrity(existing_fpath, correct_md5) self.assertTrue(utils.check_integrity(existing_fpath, correct_md5))
assert not utils.check_integrity(existing_fpath, false_md5) self.assertFalse(utils.check_integrity(existing_fpath, false_md5))
assert utils.check_integrity(existing_fpath) self.assertTrue(utils.check_integrity(existing_fpath))
assert not utils.check_integrity(nonexisting_fpath) self.assertFalse(utils.check_integrity(nonexisting_fpath))
def test_download_url(self): def test_download_url(self):
temp_dir = tempfile.mkdtemp() with get_tmp_dir() as temp_dir:
url = "http://github.com/pytorch/vision/archive/master.zip" url = "http://github.com/pytorch/vision/archive/master.zip"
utils.download_url(url, temp_dir) try:
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.' utils.download_url(url, temp_dir)
shutil.rmtree(temp_dir) self.assertFalse(len(os.listdir(temp_dir)) == 0)
except URLError:
msg = "could not download test file '{}'".format(url)
warnings.warn(msg, RuntimeWarning)
raise unittest.SkipTest(msg)
def test_download_url_retry_http(self): def test_download_url_retry_http(self):
temp_dir = tempfile.mkdtemp() with get_tmp_dir() as temp_dir:
url = "https://github.com/pytorch/vision/archive/master.zip" url = "https://github.com/pytorch/vision/archive/master.zip"
utils.download_url(url, temp_dir) try:
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.' utils.download_url(url, temp_dir)
shutil.rmtree(temp_dir) self.assertFalse(len(os.listdir(temp_dir)) == 0)
except URLError:
msg = "could not download test file '{}'".format(url)
warnings.warn(msg, RuntimeWarning)
raise unittest.SkipTest(msg)
def test_extract_zip(self): def test_extract_zip(self):
temp_dir = tempfile.mkdtemp() with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.zip') as f: with tempfile.NamedTemporaryFile(suffix='.zip') as f:
with zipfile.ZipFile(f, 'w') as zf: with zipfile.ZipFile(f, 'w') as zf:
zf.writestr('file.tst', 'this is the content') zf.writestr('file.tst', 'this is the content')
utils.extract_archive(f.name, temp_dir) utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst')) self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read() data = nf.read()
assert data == 'this is the content' self.assertEqual(data, 'this is the content')
shutil.rmtree(temp_dir)
def test_extract_tar(self): def test_extract_tar(self):
for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']): for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']):
temp_dir = tempfile.mkdtemp() with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile() as bf: with tempfile.NamedTemporaryFile() as bf:
bf.write("this is the content".encode()) bf.write("this is the content".encode())
bf.seek(0) bf.seek(0)
with tempfile.NamedTemporaryFile(suffix=ext) as f: with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf: with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst') zf.add(bf.name, arcname='file.tst')
utils.extract_archive(f.name, temp_dir) utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst')) self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read() data = nf.read()
assert data == 'this is the content', data self.assertEqual(data, 'this is the content')
shutil.rmtree(temp_dir)
def test_extract_gzip(self): def test_extract_gzip(self):
temp_dir = tempfile.mkdtemp() with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.gz') as f: with tempfile.NamedTemporaryFile(suffix='.gz') as f:
with gzip.GzipFile(f.name, 'wb') as zf: with gzip.GzipFile(f.name, 'wb') as zf:
zf.write('this is the content'.encode()) zf.write('this is the content'.encode())
utils.extract_archive(f.name, temp_dir) utils.extract_archive(f.name, temp_dir)
f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
assert os.path.exists(f_name) self.assertTrue(os.path.exists(f_name))
with open(os.path.join(f_name), 'r') as nf: with open(os.path.join(f_name), 'r') as nf:
data = nf.read() data = nf.read()
assert data == 'this is the content', data self.assertEqual(data, 'this is the content')
shutil.rmtree(temp_dir)
if __name__ == '__main__': if __name__ == '__main__':
......
from __future__ import division from __future__ import division
import os
import torch import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
...@@ -18,7 +19,8 @@ try: ...@@ -18,7 +19,8 @@ try:
except ImportError: except ImportError:
stats = None stats = None
GRACE_HOPPER = get_file_path_2('assets/grace_hopper_517x606.jpg') GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment