Unverified Commit e402d43f authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Minor test refactorings (#1011)

* Make tests work on fbcode

* Lint

* Fix rebase error

* Properly use get_file_path_2

* Fix wrong use of get_file_path_2 again

* Missing import
parent 67bfb967
import os
import shutil
import tempfile
import contextlib
@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
tmp_dir = tempfile.mkdtemp(**kwargs)
if src is not None:
os.rmdir(tmp_dir)
shutil.copytree(src, tmp_dir)
try:
yield tmp_dir
finally:
shutil.rmtree(tmp_dir)
import os
import sys
import shutil
import contextlib
import tempfile
import unittest
import mock
import numpy as np
import PIL
import torch
import torchvision
from torch._utils_internal import get_file_path_2
PYTHON2 = sys.version_info[0] == 2
if PYTHON2:
......@@ -16,20 +15,10 @@ if PYTHON2:
else:
import pickle
FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'fakedata')
from common_utils import get_tmp_dir
@contextlib.contextmanager
def tmp_dir(src=None, **kwargs):
tmp_dir = tempfile.mkdtemp(**kwargs)
if src is not None:
os.rmdir(tmp_dir)
shutil.copytree(src, tmp_dir)
try:
yield tmp_dir
finally:
shutil.rmtree(tmp_dir)
FAKEDATA_DIR = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')
@contextlib.contextmanager
......@@ -54,17 +43,14 @@ def get_mnist_data(num_images, cls_name, **kwargs):
f.write(_encode(num_images))
f.write(labels.numpy().tobytes())
tmp_dir = tempfile.mkdtemp(**kwargs)
with get_tmp_dir() as tmp_dir:
raw_dir = os.path.join(tmp_dir, cls_name, "raw")
os.makedirs(raw_dir)
_make_image_file(os.path.join(raw_dir, "train-images-idx3-ubyte"), num_images)
_make_label_file(os.path.join(raw_dir, "train-labels-idx1-ubyte"), num_images)
_make_image_file(os.path.join(raw_dir, "t10k-images-idx3-ubyte"), num_images)
_make_label_file(os.path.join(raw_dir, "t10k-labels-idx1-ubyte"), num_images)
try:
yield tmp_dir
finally:
shutil.rmtree(tmp_dir)
@contextlib.contextmanager
......@@ -109,7 +95,7 @@ def cifar_root(version):
_make_pickled_file(obj, file)
params = _get_version_params(version)
with tmp_dir() as root:
with get_tmp_dir() as root:
base_folder = os.path.join(root, params['base_folder'])
os.mkdir(base_folder)
......@@ -124,7 +110,7 @@ def cifar_root(version):
class Tester(unittest.TestCase):
def test_imagefolder(self):
with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
classes = sorted(['a', 'b'])
class_a_image_files = [os.path.join(root, 'a', file)
for file in ('a1.png', 'a2.png', 'a3.png')]
......@@ -200,7 +186,7 @@ class Tester(unittest.TestCase):
@mock.patch('torchvision.datasets.utils.download_url')
def test_imagenet(self, mock_download):
with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root:
with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root:
dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
self.assertEqual(len(dataset), 3)
img, target = dataset[0]
......
import os
import shutil
import sys
import tempfile
import torchvision.datasets.utils as utils
import unittest
import zipfile
import tarfile
import gzip
import warnings
from torch._utils_internal import get_file_path_2
TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'grace_hopper_517x606.jpg')
from common_utils import get_tmp_dir
if sys.version_info < (3,):
from urllib2 import URLError
else:
from urllib.error import URLError
TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase):
......@@ -17,48 +27,55 @@ class Tester(unittest.TestCase):
fpath = TEST_FILE
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = ''
assert utils.check_md5(fpath, correct_md5)
assert not utils.check_md5(fpath, false_md5)
self.assertTrue(utils.check_md5(fpath, correct_md5))
self.assertFalse(utils.check_md5(fpath, false_md5))
def test_check_integrity(self):
existing_fpath = TEST_FILE
nonexisting_fpath = ''
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = ''
assert utils.check_integrity(existing_fpath, correct_md5)
assert not utils.check_integrity(existing_fpath, false_md5)
assert utils.check_integrity(existing_fpath)
assert not utils.check_integrity(nonexisting_fpath)
self.assertTrue(utils.check_integrity(existing_fpath, correct_md5))
self.assertFalse(utils.check_integrity(existing_fpath, false_md5))
self.assertTrue(utils.check_integrity(existing_fpath))
self.assertFalse(utils.check_integrity(nonexisting_fpath))
def test_download_url(self):
temp_dir = tempfile.mkdtemp()
with get_tmp_dir() as temp_dir:
url = "http://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, temp_dir)
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir)
self.assertFalse(len(os.listdir(temp_dir)) == 0)
except URLError:
msg = "could not download test file '{}'".format(url)
warnings.warn(msg, RuntimeWarning)
raise unittest.SkipTest(msg)
def test_download_url_retry_http(self):
temp_dir = tempfile.mkdtemp()
with get_tmp_dir() as temp_dir:
url = "https://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, temp_dir)
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir)
self.assertFalse(len(os.listdir(temp_dir)) == 0)
except URLError:
msg = "could not download test file '{}'".format(url)
warnings.warn(msg, RuntimeWarning)
raise unittest.SkipTest(msg)
def test_extract_zip(self):
temp_dir = tempfile.mkdtemp()
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.zip') as f:
with zipfile.ZipFile(f, 'w') as zf:
zf.writestr('file.tst', 'this is the content')
utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
assert data == 'this is the content'
shutil.rmtree(temp_dir)
self.assertEqual(data, 'this is the content')
def test_extract_tar(self):
for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']):
temp_dir = tempfile.mkdtemp()
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile() as bf:
bf.write("this is the content".encode())
bf.seek(0)
......@@ -66,24 +83,22 @@ class Tester(unittest.TestCase):
with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst')
utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
assert data == 'this is the content', data
shutil.rmtree(temp_dir)
self.assertEqual(data, 'this is the content')
def test_extract_gzip(self):
temp_dir = tempfile.mkdtemp()
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.gz') as f:
with gzip.GzipFile(f.name, 'wb') as zf:
zf.write('this is the content'.encode())
utils.extract_archive(f.name, temp_dir)
f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
assert os.path.exists(f_name)
self.assertTrue(os.path.exists(f_name))
with open(os.path.join(f_name), 'r') as nf:
data = nf.read()
assert data == 'this is the content', data
shutil.rmtree(temp_dir)
self.assertEqual(data, 'this is the content')
if __name__ == '__main__':
......
from __future__ import division
import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
......@@ -18,7 +19,8 @@ try:
except ImportError:
stats = None
GRACE_HOPPER = get_file_path_2('assets/grace_hopper_517x606.jpg')
GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment