Commit 7716aba5 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

[WIP] Add test for ImageNet (#976)

* added fake data

* fixed fake data

* renamed extract and download methods and added functionality

* added raw fake data

* refactored imagenet and added test

* flake8

* added fake devkit and mocked download_url

* reversed uncommenting

* added mock to CI

* fixed tests for imagefolder

* flake8
parent 3d561039
...@@ -34,6 +34,7 @@ before_install: ...@@ -34,6 +34,7 @@ before_install:
fi fi
- pip install future - pip install future
- pip install pytest pytest-cov codecov - pip install pytest pytest-cov codecov
- pip install mock
install: install:
......
import PIL import os
import shutil import shutil
import contextlib
import tempfile import tempfile
import unittest import unittest
import mock
import PIL
import torchvision import torchvision
FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'fakedata')
@contextlib.contextmanager
def tmp_dir(src=None, **kwargs):
    """Yield a temporary directory that is removed again on exit.

    If *src* is given, the yielded directory is a copy of that tree.
    Extra keyword arguments are forwarded to ``tempfile.mkdtemp``.
    """
    path = tempfile.mkdtemp(**kwargs)
    if src is not None:
        # shutil.copytree requires the destination not to exist yet,
        # so drop the freshly created (empty) directory first.
        os.rmdir(path)
        shutil.copytree(src, path)
    try:
        yield path
    finally:
        shutil.rmtree(path)
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def test_imagefolder(self):
    """Check ImageFolder class discovery, image listing and ``is_valid_file``."""
    # Work on a scratch copy of the fake dataset so tests never mutate assets.
    with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
        classes = sorted(['a', 'b'])
        class_a_image_files = [os.path.join(root, 'a', file)
                               for file in ('a1.png', 'a2.png', 'a3.png')]
        class_b_image_files = [os.path.join(root, 'b', file)
                               for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]
        # identity loader: dataset items become (path, class_index) pairs
        dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

        # test if all classes are present
        self.assertEqual(classes, sorted(dataset.classes))

        # test if combination of classes and class_to_index functions correctly
        for cls in classes:
            self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

        # test if all images were detected correctly
        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files]
        imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files]
        imgs = sorted(imgs_a + imgs_b)
        self.assertEqual(imgs, dataset.imgs)

        # test if the datasets outputs all images correctly
        outputs = sorted([dataset[i] for i in range(len(dataset))])
        self.assertEqual(imgs, outputs)

        # redo all tests with specified valid image files
        # (only files whose path contains the character '3' are accepted)
        dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x,
                                                   is_valid_file=lambda x: '3' in x)
        self.assertEqual(classes, sorted(dataset.classes))

        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files
                  if '3' in img_file]
        imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files
                  if '3' in img_file]
        imgs = sorted(imgs_a + imgs_b)
        self.assertEqual(imgs, dataset.imgs)
        outputs = sorted([dataset[i] for i in range(len(dataset))])
        self.assertEqual(imgs, outputs)
def test_mnist(self): def test_mnist(self):
tmp_dir = tempfile.mkdtemp() with tmp_dir() as root:
dataset = torchvision.datasets.MNIST(tmp_dir, download=True) dataset = torchvision.datasets.MNIST(root, download=True)
self.assertEqual(len(dataset), 60000) self.assertEqual(len(dataset), 60000)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int)) self.assertTrue(isinstance(target, int))
shutil.rmtree(tmp_dir)
def test_kmnist(self): def test_kmnist(self):
tmp_dir = tempfile.mkdtemp() with tmp_dir() as root:
dataset = torchvision.datasets.KMNIST(tmp_dir, download=True) dataset = torchvision.datasets.KMNIST(root, download=True)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int)) self.assertTrue(isinstance(target, int))
shutil.rmtree(tmp_dir)
def test_fashionmnist(self): def test_fashionmnist(self):
tmp_dir = tempfile.mkdtemp() with tmp_dir() as root:
dataset = torchvision.datasets.FashionMNIST(tmp_dir, download=True) dataset = torchvision.datasets.FashionMNIST(root, download=True)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int)) self.assertTrue(isinstance(target, int))
shutil.rmtree(tmp_dir)
@mock.patch('torchvision.datasets.utils.download_url')
def test_imagenet(self, mock_download):
    """Build ImageNet from the fake local archives.

    ``download_url`` is mocked out so no network access happens; the
    archives the dataset expects are already present in FAKEDATA_DIR.
    """
    with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root:
        # train split of the fake data contains exactly 3 samples
        dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
        self.assertEqual(len(dataset), 3)
        img, target = dataset[0]
        self.assertTrue(isinstance(img, PIL.Image.Image))
        self.assertTrue(isinstance(target, int))
        # first sample of the fake data belongs to class 'Tinca tinca'
        self.assertEqual(dataset.class_to_idx['Tinca tinca'], target)

        # val split is checked the same way
        dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
        self.assertEqual(len(dataset), 3)
        img, target = dataset[0]
        self.assertTrue(isinstance(img, PIL.Image.Image))
        self.assertTrue(isinstance(target, int))
        self.assertEqual(dataset.class_to_idx['Tinca tinca'], target)
if __name__ == '__main__': if __name__ == '__main__':
......
import os
import shutil
import contextlib
import tempfile
import unittest
from torchvision.datasets import ImageFolder
FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'fakedata')
@contextlib.contextmanager
def tmp_dir(src=None, **kwargs):
    """Context manager yielding a throwaway directory.

    With *src* set, the yielded directory is a full copy of that tree;
    ``**kwargs`` are passed through to ``tempfile.mkdtemp``. The directory
    and its contents are deleted when the context exits.
    """
    root = tempfile.mkdtemp(**kwargs)
    if src is not None:
        # copytree refuses to write into an existing directory
        os.rmdir(root)
        shutil.copytree(src, root)
    try:
        yield root
    finally:
        shutil.rmtree(root)
def mock_transform(return_value, arg_list):
    """Build a stub transform for tests.

    The returned callable records every argument it receives in *arg_list*
    (so a test can inspect what the dataset passed in) and always returns
    the constant *return_value*.
    """
    def _recording_transform(arg):
        arg_list.append(arg)
        return return_value
    return _recording_transform
class Tester(unittest.TestCase):
    """Tests for the transform/target_transform plumbing of ImageFolder."""

    def test_transform(self):
        """transform must be applied to every sample and see the loader output."""
        with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            class_a_image_files = [os.path.join(root, 'a', file)
                                   for file in ('a1.png', 'a2.png', 'a3.png')]
            class_b_image_files = [os.path.join(root, 'b', file)
                                   for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]
            return_value = os.path.join(root, 'a', 'a1.png')
            args = []
            transform = mock_transform(return_value, args)
            # identity loader: the transform receives raw file paths
            dataset = ImageFolder(root, loader=lambda x: x, transform=transform)

            outputs = [dataset[i][0] for i in range(len(dataset))]
            # every sample is replaced by the transform's fixed return value
            self.assertEqual([return_value] * len(outputs), outputs)

            # the transform was called exactly once per image, with its path
            imgs = sorted(class_a_image_files + class_b_image_files)
            self.assertEqual(imgs, sorted(args))

    def test_target_transform(self):
        """target_transform must be applied to every target (class index)."""
        with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            class_a_image_files = [os.path.join(root, 'a', file)
                                   for file in ('a1.png', 'a2.png', 'a3.png')]
            class_b_image_files = [os.path.join(root, 'b', file)
                                   for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]
            return_value = os.path.join(root, 'a', 'a1.png')
            args = []
            target_transform = mock_transform(return_value, args)
            dataset = ImageFolder(root, loader=lambda x: x,
                                  target_transform=target_transform)

            outputs = [dataset[i][1] for i in range(len(dataset))]
            # every target is replaced by the fixed return value
            self.assertEqual([return_value] * len(outputs), outputs)

            # the target_transform saw each sample's original class index
            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            targets = sorted([class_a_idx] * len(class_a_image_files) +
                             [class_b_idx] * len(class_b_image_files))
            self.assertEqual(targets, sorted(args))
# Run the tests when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
...@@ -49,7 +49,7 @@ class Tester(unittest.TestCase): ...@@ -49,7 +49,7 @@ class Tester(unittest.TestCase):
with tempfile.NamedTemporaryFile(suffix='.zip') as f: with tempfile.NamedTemporaryFile(suffix='.zip') as f:
with zipfile.ZipFile(f, 'w') as zf: with zipfile.ZipFile(f, 'w') as zf:
zf.writestr('file.tst', 'this is the content') zf.writestr('file.tst', 'this is the content')
utils.extract_file(f.name, temp_dir) utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst')) assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read() data = nf.read()
...@@ -65,7 +65,7 @@ class Tester(unittest.TestCase): ...@@ -65,7 +65,7 @@ class Tester(unittest.TestCase):
with tempfile.NamedTemporaryFile(suffix=ext) as f: with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf: with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst') zf.add(bf.name, arcname='file.tst')
utils.extract_file(f.name, temp_dir) utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst')) assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read() data = nf.read()
...@@ -77,7 +77,7 @@ class Tester(unittest.TestCase): ...@@ -77,7 +77,7 @@ class Tester(unittest.TestCase):
with tempfile.NamedTemporaryFile(suffix='.gz') as f: with tempfile.NamedTemporaryFile(suffix='.gz') as f:
with gzip.GzipFile(f.name, 'wb') as zf: with gzip.GzipFile(f.name, 'wb') as zf:
zf.write('this is the content'.encode()) zf.write('this is the content'.encode())
utils.extract_file(f.name, temp_dir) utils.extract_archive(f.name, temp_dir)
f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
assert os.path.exists(f_name) assert os.path.exists(f_name)
with open(os.path.join(f_name), 'r') as nf: with open(os.path.join(f_name), 'r') as nf:
......
import unittest
import os
from torchvision.datasets import ImageFolder
from torch._utils_internal import get_file_path_2
def mock_transform(return_value, arg_list):
    """Create a fake transform.

    Each call appends its argument to *arg_list* so the caller can verify
    what was passed in; the call always yields the constant *return_value*.
    """
    def fake(arg):
        arg_list.append(arg)
        return return_value
    return fake
class Tester(unittest.TestCase):
    """Tests for ImageFolder against the checked-in dataset under test/assets.

    Fixture layout (resolved via ``get_file_path_2``): two class folders,
    'a' with three images and 'b' with four.
    """

    # path of the fixture dataset root
    root = os.path.normpath(get_file_path_2('test/assets/dataset/'))
    # expected class names (the two folder names)
    classes = ['a', 'b']
    # absolute paths of all images per class
    class_a_images = [os.path.normpath(get_file_path_2(os.path.join('test/assets/dataset/a/', path)))
                      for path in ['a1.png', 'a2.png', 'a3.png']]
    class_b_images = [os.path.normpath(get_file_path_2(os.path.join('test/assets/dataset/b/', path)))
                      for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']]

    def test_image_folder(self):
        """Class discovery, image listing and the ``is_valid_file`` filter."""
        # identity loader: dataset items are (path, class_index) pairs
        dataset = ImageFolder(Tester.root, loader=lambda x: x)
        self.assertEqual(sorted(Tester.classes), sorted(dataset.classes))
        for cls in Tester.classes:
            self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        imgs_a = [(img_path, class_a_idx)for img_path in Tester.class_a_images]
        imgs_b = [(img_path, class_b_idx)for img_path in Tester.class_b_images]
        imgs = sorted(imgs_a + imgs_b)
        self.assertEqual(imgs, dataset.imgs)

        outputs = sorted([dataset[i] for i in range(len(dataset))])
        self.assertEqual(imgs, outputs)

        # repeat with a filter that only accepts paths containing '3'
        dataset = ImageFolder(Tester.root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
        self.assertEqual(sorted(Tester.classes), sorted(dataset.classes))
        for cls in Tester.classes:
            self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        imgs_a = [(img_path, class_a_idx)for img_path in Tester.class_a_images if '3' in img_path]
        imgs_b = [(img_path, class_b_idx)for img_path in Tester.class_b_images if '3' in img_path]
        imgs = sorted(imgs_a + imgs_b)
        self.assertEqual(imgs, dataset.imgs)

        outputs = sorted([dataset[i] for i in range(len(dataset))])
        self.assertEqual(imgs, outputs)

    def test_transform(self):
        """transform is applied to every sample and receives the loader output."""
        return_value = os.path.normpath(get_file_path_2('test/assets/dataset/a/a1.png'))
        args = []
        transform = mock_transform(return_value, args)
        dataset = ImageFolder(Tester.root, loader=lambda x: x, transform=transform)

        outputs = [dataset[i][0] for i in range(len(dataset))]
        # every sample is replaced by the transform's fixed return value
        self.assertEqual([return_value] * len(outputs), outputs)

        # the transform was called once per image, with that image's path
        imgs = sorted(Tester.class_a_images + Tester.class_b_images)
        self.assertEqual(imgs, sorted(args))

    def test_target_transform(self):
        """target_transform is applied to every target (class index)."""
        return_value = 1
        args = []
        target_transform = mock_transform(return_value, args)
        dataset = ImageFolder(Tester.root, loader=lambda x: x, target_transform=target_transform)

        outputs = [dataset[i][1] for i in range(len(dataset))]
        self.assertEqual([return_value] * len(outputs), outputs)

        # the target_transform saw each sample's original class index
        class_a_idx = dataset.class_to_idx['a']
        class_b_idx = dataset.class_to_idx['b']
        targets = sorted([class_a_idx] * len(Tester.class_a_images) +
                         [class_b_idx] * len(Tester.class_b_images))
        self.assertEqual(targets, sorted(args))
# Run the tests when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import os.path import os.path
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_and_extract, makedir_exist_ok from .utils import download_and_extract_archive, makedir_exist_ok
class Caltech101(VisionDataset): class Caltech101(VisionDataset):
...@@ -113,12 +113,12 @@ class Caltech101(VisionDataset): ...@@ -113,12 +113,12 @@ class Caltech101(VisionDataset):
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_and_extract( download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root, self.root,
"101_ObjectCategories.tar.gz", "101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9") "b224c7392d521a49829488ab0f1120d9")
download_and_extract( download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root, self.root,
"101_Annotations.tar", "101_Annotations.tar",
...@@ -201,7 +201,7 @@ class Caltech256(VisionDataset): ...@@ -201,7 +201,7 @@ class Caltech256(VisionDataset):
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_and_extract( download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root, self.root,
"256_ObjectCategories.tar", "256_ObjectCategories.tar",
......
...@@ -11,7 +11,7 @@ else: ...@@ -11,7 +11,7 @@ else:
import pickle import pickle
from .vision import VisionDataset from .vision import VisionDataset
from .utils import check_integrity, download_and_extract from .utils import check_integrity, download_and_extract_archive
class CIFAR10(VisionDataset): class CIFAR10(VisionDataset):
...@@ -147,7 +147,7 @@ class CIFAR10(VisionDataset): ...@@ -147,7 +147,7 @@ class CIFAR10(VisionDataset):
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_and_extract(self.url, self.root, self.filename, self.tgz_md5) download_and_extract_archive(self.url, self.root, self.filename, self.tgz_md5)
def extra_repr(self): def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test") return "Split: {}".format("Train" if self.train is True else "Test")
......
from __future__ import print_function from __future__ import print_function
import os import os
import shutil import shutil
import tempfile
import torch import torch
from .folder import ImageFolder from .folder import ImageFolder
from .utils import check_integrity, download_url from .utils import check_integrity, download_and_extract_archive, extract_archive
ARCHIVE_DICT = { ARCHIVE_DICT = {
'train': { 'train': {
...@@ -66,23 +67,23 @@ class ImageNet(ImageFolder): ...@@ -66,23 +67,23 @@ class ImageNet(ImageFolder):
def download(self): def download(self):
if not check_integrity(self.meta_file): if not check_integrity(self.meta_file):
tmpdir = os.path.join(self.root, 'tmp') tmp_dir = tempfile.mkdtemp()
archive_dict = ARCHIVE_DICT['devkit'] archive_dict = ARCHIVE_DICT['devkit']
download_and_extract_tar(archive_dict['url'], self.root, download_and_extract_archive(archive_dict['url'], self.root,
extract_root=tmpdir, extract_root=tmp_dir,
md5=archive_dict['md5']) md5=archive_dict['md5'])
devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0] devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
meta = parse_devkit(os.path.join(tmpdir, devkit_folder)) meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
self._save_meta_file(*meta) self._save_meta_file(*meta)
shutil.rmtree(tmpdir) shutil.rmtree(tmp_dir)
if not os.path.isdir(self.split_folder): if not os.path.isdir(self.split_folder):
archive_dict = ARCHIVE_DICT[self.split] archive_dict = ARCHIVE_DICT[self.split]
download_and_extract_tar(archive_dict['url'], self.root, download_and_extract_archive(archive_dict['url'], self.root,
extract_root=self.split_folder, extract_root=self.split_folder,
md5=archive_dict['md5']) md5=archive_dict['md5'])
if self.split == 'train': if self.split == 'train':
prepare_train_folder(self.split_folder) prepare_train_folder(self.split_folder)
...@@ -128,36 +129,6 @@ class ImageNet(ImageFolder): ...@@ -128,36 +129,6 @@ class ImageNet(ImageFolder):
return "Split: {split}".format(**self.__dict__) return "Split: {split}".format(**self.__dict__)
def extract_tar(src, dest=None, gzip=None, delete=False):
    """Extract the tar archive *src* into directory *dest*.

    *dest* defaults to the directory containing *src*. *gzip* selects
    gzip decompression and, when ``None``, is auto-detected from a
    ``.gz`` suffix. With ``delete=True`` the archive file is removed
    after a successful extraction.
    """
    import tarfile

    target = os.path.dirname(src) if dest is None else dest
    compressed = src.lower().endswith('.gz') if gzip is None else gzip

    # NOTE(review): extractall trusts member paths in the archive; fine for
    # torchvision's own downloads, not for untrusted archives.
    with tarfile.open(src, 'r:gz' if compressed else 'r') as archive:
        archive.extractall(path=target)

    if delete:
        os.remove(src)
def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
                             md5=None, **kwargs):
    """Download *url* into *download_root* and extract it into *extract_root*.

    *extract_root* defaults to *download_root*; *filename* defaults to the
    last component of the URL. The download is skipped when a file with a
    matching *md5* already exists. Remaining keyword arguments are
    forwarded to ``extract_tar``.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = os.path.basename(url) if filename is None else filename

    archive = os.path.join(download_root, filename)
    # only fetch when no verified copy of the archive is present
    if not check_integrity(archive, md5):
        download_url(url, download_root, filename=filename, md5=md5)

    extract_tar(archive, extract_root, **kwargs)
def parse_devkit(root): def parse_devkit(root):
idx_to_wnid, wnid_to_classes = parse_meta(root) idx_to_wnid, wnid_to_classes = parse_meta(root)
val_idcs = parse_val_groundtruth(root) val_idcs = parse_val_groundtruth(root)
...@@ -189,7 +160,7 @@ def parse_val_groundtruth(devkit_root, path='data', ...@@ -189,7 +160,7 @@ def parse_val_groundtruth(devkit_root, path='data',
def prepare_train_folder(folder): def prepare_train_folder(folder):
for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]: for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
extract_tar(archive, os.path.splitext(archive)[0], delete=True) extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
def prepare_val_folder(folder, wnids): def prepare_val_folder(folder, wnids):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment