Unverified Commit a7a2ee7c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove old CIFAR tests and fake data generation (#3447)

parent 5266a725
...@@ -88,61 +88,6 @@ def mnist_root(num_images, cls_name): ...@@ -88,61 +88,6 @@ def mnist_root(num_images, cls_name):
yield tmp_dir yield tmp_dir
@contextlib.contextmanager
def cifar_root(version):
def _get_version_params(version):
if version == 'CIFAR10':
return {
'base_folder': 'cifar-10-batches-py',
'train_files': ['data_batch_{}'.format(batch) for batch in range(1, 6)],
'test_file': 'test_batch',
'target_key': 'labels',
'meta_file': 'batches.meta',
'classes_key': 'label_names',
}
elif version == 'CIFAR100':
return {
'base_folder': 'cifar-100-python',
'train_files': ['train'],
'test_file': 'test',
'target_key': 'fine_labels',
'meta_file': 'meta',
'classes_key': 'fine_label_names',
}
else:
raise ValueError
def _make_pickled_file(obj, file):
with open(file, 'wb') as fh:
pickle.dump(obj, fh, 2)
def _make_data_file(file, target_key):
obj = {
'data': np.zeros((1, 32 * 32 * 3), dtype=np.uint8),
target_key: [0]
}
_make_pickled_file(obj, file)
def _make_meta_file(file, classes_key):
obj = {
classes_key: ['fakedata'],
}
_make_pickled_file(obj, file)
params = _get_version_params(version)
with get_tmp_dir() as root:
base_folder = os.path.join(root, params['base_folder'])
os.mkdir(base_folder)
for file in list(params['train_files']) + [params['test_file']]:
_make_data_file(os.path.join(base_folder, file), params['target_key'])
_make_meta_file(os.path.join(base_folder, params['meta_file']),
params['classes_key'])
yield root
@contextlib.contextmanager @contextlib.contextmanager
def imagenet_root(): def imagenet_root():
import scipy.io as sio import scipy.io as sio
......
...@@ -10,7 +10,7 @@ from torch._utils_internal import get_file_path_2 ...@@ -10,7 +10,7 @@ from torch._utils_internal import get_file_path_2
import torchvision import torchvision
from torchvision.datasets import utils from torchvision.datasets import utils
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \ from fakedata_generation import mnist_root, imagenet_root, \
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen from urllib.request import Request, urlopen
...@@ -173,38 +173,6 @@ class Tester(DatasetTestcase): ...@@ -173,38 +173,6 @@ class Tester(DatasetTestcase):
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(img, PIL.Image.Image))
@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar10(self, mock_ext_check, mock_int_check):
mock_ext_check.return_value = True
mock_int_check.return_value = True
with cifar_root('CIFAR10') as root:
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
self.generic_classification_dataset_test(dataset, num_images=5)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
def test_cifar100(self, mock_ext_check, mock_int_check):
mock_ext_check.return_value = True
mock_int_check.return_value = True
with cifar_root('CIFAR100') as root:
dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_cityscapes(self): def test_cityscapes(self):
with cityscapes_root() as root: with cityscapes_root() as root:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment