"tests/vscode:/vscode.git/clone" did not exist on "6105e441426f97f31d96c54d6f35830028c2b3f6"
Commit 53d04ace authored by Philip Meier, committed by Francisco Massa
Browse files

Create imagenet fakedata on-the-fly (#1012)

* create imagenet fakedata on-the-fly

* flake8

* Minor test refactorings (#1011)

* Make tests work on fbcode

* Lint

* Fix rebase error

* Properly use get_file_path_2

* Fix wrong use of get_file_path_2 again

* Missing import

* create imagenet fakedata on-the-fly
parent e402d43f
import os
import sys
import contextlib
import unittest
import mock
import contextlib
import tarfile
import numpy as np
import PIL
from PIL import Image
import torch
import torchvision
from torch._utils_internal import get_file_path_2
import torchvision
from common_utils import get_tmp_dir
PYTHON2 = sys.version_info[0] == 2
if PYTHON2:
......@@ -15,15 +18,12 @@ if PYTHON2:
else:
import pickle
from common_utils import get_tmp_dir
FAKEDATA_DIR = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')
@contextlib.contextmanager
def get_mnist_data(num_images, cls_name, **kwargs):
def _encode(v):
return torch.tensor(v, dtype=torch.int32).numpy().tobytes()[::-1]
......@@ -108,6 +108,76 @@ def cifar_root(version):
yield root
@contextlib.contextmanager
def imagenet_root():
    """Yield a temporary root directory holding fake ImageNet 2012 archives.

    Three archives are written into the root before it is yielded:

    * ``ILSVRC2012_img_train.tar`` wraps ``n01234567.tar``, which in turn
      holds the directory ``n01234567`` with one image ``n01234567_1.JPEG``.
    * ``ILSVRC2012_img_val.tar`` holds ``ILSVRC2012_val_00000001.JPEG``.
    * ``ILSVRC2012_devkit_t12.tar.gz`` (gzip-compressed) holds
      ``ILSVRC2012_devkit_t12/data/meta.mat`` and
      ``ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt``.

    Every image is a 32x32 all-zero ``uint8`` RGB array saved through PIL.

    The root and all scratch directories come from ``get_tmp_dir()``.
    NOTE(review): ``get_tmp_dir`` lives in ``common_utils``; presumably it
    removes the directory when its ``with`` block exits -- confirm there.

    Yields:
        str: path of the directory containing the three archives.
    """
    # NOTE(review): imported locally rather than at module level, presumably so
    # the rest of this test module stays importable without SciPy -- confirm.
    import scipy.io as sio

    WNID = 'n01234567'
    CLS = 'fakedata'

    def _make_image(file):
        """Save a black 32x32 RGB image to ``file``."""
        Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)

    def _make_tar(archive, content, arcname=None, compress=False):
        """Add ``content`` to a new tar ``archive`` under the name ``arcname``.

        ``arcname`` defaults to the basename of ``content``; ``compress``
        selects gzip compression.
        """
        mode = 'w:gz' if compress else 'w'
        if arcname is None:
            arcname = os.path.basename(content)
        with tarfile.open(archive, mode) as fh:
            fh.add(content, arcname=arcname)

    def _make_train_archive(root):
        """Write the nested train archive (tar of a per-WNID tar) into ``root``."""
        with get_tmp_dir() as tmp:
            wnid_dir = os.path.join(tmp, WNID)
            os.mkdir(wnid_dir)

            _make_image(os.path.join(wnid_dir, WNID + '_1.JPEG'))

            # Inner archive: one tar per WNID directory.
            wnid_archive = wnid_dir + '.tar'
            _make_tar(wnid_archive, wnid_dir)

            # Outer archive must be written while ``tmp`` (and therefore the
            # inner archive) is still in scope.
            train_archive = os.path.join(root, 'ILSVRC2012_img_train.tar')
            _make_tar(train_archive, wnid_archive)

    def _make_val_archive(root):
        """Write the validation archive with a single image into ``root``."""
        with get_tmp_dir() as tmp:
            val_image = os.path.join(tmp, 'ILSVRC2012_val_00000001.JPEG')
            _make_image(val_image)

            val_archive = os.path.join(root, 'ILSVRC2012_img_val.tar')
            _make_tar(val_archive, val_image)

    def _make_devkit_archive(root):
        """Write the gzip-compressed devkit archive into ``root``."""
        with get_tmp_dir() as tmp:
            data_dir = os.path.join(tmp, 'data')
            os.mkdir(data_dir)

            meta_file = os.path.join(data_dir, 'meta.mat')
            # ``np.rec`` is the public home of the record-array helpers;
            # ``np.core`` is a private namespace that NumPy 2.0 deprecates.
            # Each tuple is one field holding the values of the two records.
            # NOTE(review): the second record has num_children == 1 and empty
            # strings; presumably the loader keeps only num_children == 0
            # entries -- confirm against torchvision.datasets.imagenet.
            synsets = np.rec.fromarrays([
                (0.0, 1.0),
                (WNID, ''),
                (CLS, ''),
                ('fakedata for the torchvision testsuite', ''),
                (0.0, 1.0),
            ], names=['ILSVRC2012_ID', 'WNID', 'words', 'gloss', 'num_children'])
            sio.savemat(meta_file, {'synsets': synsets})

            groundtruth_file = os.path.join(
                data_dir, 'ILSVRC2012_validation_ground_truth.txt')
            with open(groundtruth_file, 'w') as fh:
                fh.write('0\n')

            devkit_name = 'ILSVRC2012_devkit_t12'
            devkit_archive = os.path.join(root, devkit_name + '.tar.gz')
            # The whole scratch directory is archived under the devkit name, so
            # its ``data`` folder ends up at ``<devkit_name>/data``.
            _make_tar(devkit_archive, tmp, arcname=devkit_name, compress=True)

    with get_tmp_dir() as root:
        _make_train_archive(root)
        _make_val_archive(root)
        _make_devkit_archive(root)

        yield root
class Tester(unittest.TestCase):
def test_imagefolder(self):
with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
......@@ -186,20 +256,20 @@ class Tester(unittest.TestCase):
# Checks torchvision.datasets.ImageNet for the 'train' and 'val' splits with
# torchvision.datasets.utils.download_url replaced by a mock.
# NOTE(review): this span is a rendered diff hunk whose indentation and +/-
# markers were stripped, so it appears to interleave the pre-change and the
# post-change version of several lines. The hedged notes below mark the
# suspected pre-change lines -- confirm against the repository.
@mock.patch('torchvision.datasets.utils.download_url')
def test_imagenet(self, mock_download):
# NOTE(review): presumably the pre-change line (static fixture under FAKEDATA_DIR).
with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root:
# Presumably its replacement: archives generated on the fly by imagenet_root().
with imagenet_root() as root:
dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
# NOTE(review): the two length assertions contradict each other; the
# 3-sample one presumably belongs to the pre-change fixture.
self.assertEqual(len(dataset), 3)
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
# NOTE(review): 'Tinca tinca' presumably is the pre-change class name;
# 'fakedata' matches the CLS value written by imagenet_root().
self.assertEqual(dataset.class_to_idx['Tinca tinca'], target)
self.assertEqual(dataset.class_to_idx['fakedata'], target)
# Same checks repeated for the validation split.
dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
# NOTE(review): as above, the 3-sample assertion presumably is pre-change.
self.assertEqual(len(dataset), 3)
self.assertEqual(len(dataset), 1)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
# NOTE(review): as above, the 'Tinca tinca' lookup presumably is pre-change.
self.assertEqual(dataset.class_to_idx['Tinca tinca'], target)
self.assertEqual(dataset.class_to_idx['fakedata'], target)
@mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment