Commit 67bfb967 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Add test for CIFAR10(0) (#1010)

* added test for CIFAR10(0)

* create CIFAR data on-the-fly

* flake8

* fixed typo

* removed falsely added import
parent ac2e995a
import os import os
import sys
import shutil import shutil
import contextlib import contextlib
import tempfile import tempfile
import unittest import unittest
import mock import mock
import numpy as np
import PIL import PIL
import torch import torch
import torchvision import torchvision
PYTHON2 = sys.version_info[0] == 2
if PYTHON2:
import cPickle as pickle
else:
import pickle
FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'fakedata') 'assets', 'fakedata')
...@@ -59,8 +67,62 @@ def get_mnist_data(num_images, cls_name, **kwargs): ...@@ -59,8 +67,62 @@ def get_mnist_data(num_images, cls_name, **kwargs):
shutil.rmtree(tmp_dir) shutil.rmtree(tmp_dir)
class Tester(unittest.TestCase): @contextlib.contextmanager
def cifar_root(version):
def _get_version_params(version):
if version == 'CIFAR10':
return {
'base_folder': 'cifar-10-batches-py',
'train_files': ['data_batch_{}'.format(batch) for batch in range(1, 6)],
'test_file': 'test_batch',
'target_key': 'labels',
'meta_file': 'batches.meta',
'classes_key': 'label_names',
}
elif version == 'CIFAR100':
return {
'base_folder': 'cifar-100-python',
'train_files': ['train'],
'test_file': 'test',
'target_key': 'fine_labels',
'meta_file': 'meta',
'classes_key': 'fine_label_names',
}
else:
raise ValueError
def _make_pickled_file(obj, file):
with open(file, 'wb') as fh:
pickle.dump(obj, fh, 2)
def _make_data_file(file, target_key):
obj = {
'data': np.zeros((1, 32 * 32 * 3), dtype=np.uint8),
target_key: [0]
}
_make_pickled_file(obj, file)
def _make_meta_file(file, classes_key):
obj = {
classes_key: ['fakedata'],
}
_make_pickled_file(obj, file)
params = _get_version_params(version)
with tmp_dir() as root:
base_folder = os.path.join(root, params['base_folder'])
os.mkdir(base_folder)
for file in list(params['train_files']) + [params['test_file']]:
_make_data_file(os.path.join(base_folder, file), params['target_key'])
_make_meta_file(os.path.join(base_folder, params['meta_file']),
params['classes_key'])
yield root
class Tester(unittest.TestCase):
def test_imagefolder(self): def test_imagefolder(self):
with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
classes = sorted(['a', 'b']) classes = sorted(['a', 'b'])
...@@ -153,6 +215,46 @@ class Tester(unittest.TestCase): ...@@ -153,6 +215,46 @@ class Tester(unittest.TestCase):
self.assertTrue(isinstance(target, int)) self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['Tinca tinca'], target) self.assertEqual(dataset.class_to_idx['Tinca tinca'], target)
@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar10(self, mock_ext_check, mock_int_check):
mock_ext_check.return_value = True
mock_int_check.return_value = True
with cifar_root('CIFAR10') as root:
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
self.assertEqual(len(dataset), 5)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar100(self, mock_ext_check, mock_int_check):
mock_ext_check.return_value = True
mock_int_check.return_value = True
with cifar_root('CIFAR100') as root:
dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment