Unverified Commit 240792d4 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

New tests for ImageNet dataset (#3543)

parent 814c4f08
......@@ -312,7 +312,8 @@ class DatasetTestCase(unittest.TestCase):
patch_checks = inject_fake_data
special_kwargs, other_kwargs = self._split_kwargs(kwargs)
if "download" in self._HAS_SPECIAL_KWARG:
if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
# override the download param to False if its default is truthy
special_kwargs["download"] = False
config.update(other_kwargs)
......
......@@ -143,76 +143,6 @@ def cifar_root(version):
yield root
@contextlib.contextmanager
def imagenet_root():
    """Yield a temporary directory holding fake ImageNet 2012 archives.

    Three archives are written into the yielded directory:

    * ``ILSVRC2012_img_train.tar`` -- a tar that contains one inner tar
      (``<WNID>.tar``) holding a single 32x32 black JPEG.
    * ``ILSVRC2012_img_val.tar`` -- a tar with a single 32x32 black JPEG.
    * ``ILSVRC2012_devkit_t12.tar.gz`` -- a gzipped tar with
      ``data/meta.mat`` and ``data/ILSVRC2012_validation_ground_truth.txt``.

    The directory is obtained from ``get_tmp_dir()`` and the ``yield`` happens
    inside that context, so the path should only be used within the ``with``
    block of the caller.

    Yields:
        str: path of the directory that contains the three archives.
    """
    # Imported lazily so that merely importing this module does not need scipy.
    import scipy.io as sio

    WNID = 'n01234567'
    CLS = 'fakedata'

    def _make_image(file):
        # Save an all-zero (black) 32x32 RGB image to ``file``.
        PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)

    def _make_tar(archive, content, arcname=None, compress=False):
        # Pack ``content`` into the tar ``archive``; gzip it if ``compress``.
        mode = 'w:gz' if compress else 'w'
        if arcname is None:
            # Store the entry under its base name instead of the full tmp path.
            arcname = os.path.basename(content)
        with tarfile.open(archive, mode) as fh:
            fh.add(content, arcname=arcname)

    def _make_train_archive(root):
        # The train archive is a tar of per-class tars: build the inner
        # ``<WNID>.tar`` first, then wrap it into the outer archive.
        with get_tmp_dir() as tmp:
            wnid_dir = os.path.join(tmp, WNID)
            os.mkdir(wnid_dir)

            _make_image(os.path.join(wnid_dir, WNID + '_1.JPEG'))

            wnid_archive = wnid_dir + '.tar'
            _make_tar(wnid_archive, wnid_dir)

            train_archive = os.path.join(root, 'ILSVRC2012_img_train.tar')
            _make_tar(train_archive, wnid_archive)

    def _make_val_archive(root):
        # The val archive holds the image directly, without class folders.
        with get_tmp_dir() as tmp:
            val_image = os.path.join(tmp, 'ILSVRC2012_val_00000001.JPEG')
            _make_image(val_image)

            val_archive = os.path.join(root, 'ILSVRC2012_img_val.tar')
            _make_tar(val_archive, val_image)

    def _make_devkit_archive(root):
        with get_tmp_dir() as tmp:
            data_dir = os.path.join(tmp, 'data')
            os.mkdir(data_dir)

            meta_file = os.path.join(data_dir, 'meta.mat')
            # ``np.rec`` is the public home of ``fromarrays``; going through
            # ``np.core`` is deprecated and warns on recent NumPy releases.
            # Each inner tuple is one field with two records; the second
            # record has num_children == 1.0 and empty strings.
            synsets = np.rec.fromarrays([
                (0.0, 1.0),
                (WNID, ''),
                (CLS, ''),
                ('fakedata for the torchvision testsuite', ''),
                (0.0, 1.0),
            ], names=['ILSVRC2012_ID', 'WNID', 'words', 'gloss', 'num_children'])
            sio.savemat(meta_file, {'synsets': synsets})

            groundtruth_file = os.path.join(data_dir,
                                            'ILSVRC2012_validation_ground_truth.txt')
            with open(groundtruth_file, 'w') as fh:
                fh.write('0\n')

            devkit_name = 'ILSVRC2012_devkit_t12'
            devkit_archive = os.path.join(root, devkit_name + '.tar.gz')
            # Archive the whole tmp dir under the devkit name so that the
            # entries read ``ILSVRC2012_devkit_t12/data/...``.
            _make_tar(devkit_archive, tmp, arcname=devkit_name, compress=True)

    with get_tmp_dir() as root:
        _make_train_archive(root)
        _make_val_archive(root)
        _make_devkit_archive(root)

        yield root
@contextlib.contextmanager
def widerface_root():
"""
......
......@@ -10,7 +10,7 @@ from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, imagenet_root, \
from fakedata_generation import mnist_root, \
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
......@@ -146,16 +146,6 @@ class Tester(DatasetTestcase):
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.imagenet._verify_archive')
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
def test_imagenet(self, mock_verify):
    """Run the generic classification checks on both ImageNet splits.

    ``_verify_archive`` is patched out for the duration of the test, and the
    test is skipped when scipy is not available. The dataset is built from
    the fake archives provided by ``imagenet_root()``, first for the
    ``'train'`` split and then for the ``'val'`` split.
    """
    with imagenet_root() as fake_root:
        for split_name in ('train', 'val'):
            imagenet = torchvision.datasets.ImageNet(fake_root, split=split_name)
            self.generic_classification_dataset_test(imagenet)
@mock.patch('torchvision.datasets.WIDERFace._check_integrity')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_widerface(self, mock_check_integrity):
......@@ -490,6 +480,37 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
return num_images_per_category * len(categories)
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.ImageNet`` driven by fake on-disk data.

    The class attributes select the dataset under test, declare scipy as a
    required package, and request one configuration per split
    (``'train'`` and ``'val'``). They are presumably consumed by the
    ``datasets_utils.ImageDatasetTestCase`` machinery -- confirm there.
    """

    DATASET_CLASS = datasets.ImageNet
    REQUIRED_PACKAGES = ('scipy',)
    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))

    def inject_fake_data(self, tmpdir, config):
        """Create fake images and a ``meta.bin`` file for one split.

        Args:
            tmpdir: directory in which the fake dataset is created.
            config: mapping with at least the key ``'split'``; the value
                ``'train'`` selects the train layout, anything else the
                val layout.

        Returns:
            int: number of images created (3 for train, 1 for val).
        """
        tmpdir = pathlib.Path(tmpdir)

        wnid = 'n01234567'
        if config['split'] == 'train':
            num_examples = 3
            datasets_utils.create_image_folder(
                root=tmpdir,
                name=tmpdir / 'train' / wnid / wnid,
                file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG",
                num_examples=num_examples,
            )
        else:
            num_examples = 1
            datasets_utils.create_image_folder(
                root=tmpdir,
                name=tmpdir / 'val' / wnid,
                # Must be an f-string with a matching parameter name so the
                # index is interpolated instead of leaving a literal
                # "{image_idx}" in the file name.
                file_name_fn=lambda image_idx: f"ILSVRC2012_val_0000000{image_idx}.JPEG",
                num_examples=num_examples,
            )

        # NOTE(review): presumably the cached metadata that ImageNet loads
        # instead of parsing the devkit archive -- confirm against the dataset.
        wnid_to_classes = {wnid: [1]}
        torch.save((wnid_to_classes, None), tmpdir / 'meta.bin')
        return num_examples
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CIFAR10
CONFIGS = datasets_utils.combinations_grid(train=(True, False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment