Commit f46f2c15 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Remove download for ImageNet (#1457)

* remove download process

* address comments

* fix logic error

* bug fixes

* removed unused import

* add docstrings

* flake8

* remove download BC

* fix test

* removed unused code

* flake 8

* add MD5 verification before extraction

* add mock to test

* * unify _verify_archive() method and function
* remove force flag for parse_*_archive functions
* cleanup

* flake8
parent 371f6c8f
......@@ -108,14 +108,14 @@ class Tester(unittest.TestCase):
img, target = dataset[0]
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.utils.download_url')
@mock.patch('torchvision.datasets.imagenet._verify_archive')
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
def test_imagenet(self, mock_download):
def test_imagenet(self, mock_verify):
with imagenet_root() as root:
dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
dataset = torchvision.datasets.ImageNet(root, split='train')
self.generic_classification_dataset_test(dataset)
dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
dataset = torchvision.datasets.ImageNet(root, split='val')
self.generic_classification_dataset_test(dataset)
@mock.patch('torchvision.datasets.cifar.check_integrity')
......
from __future__ import print_function
import warnings
from contextlib import contextmanager
import os
import shutil
import tempfile
import torch
from .folder import ImageFolder
from .utils import check_integrity, download_and_extract_archive, extract_archive, \
verify_str_arg
ARCHIVE_DICT = {
'train': {
'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
'md5': '1d675b47d978889d74fa0da5fadfb00e',
},
'val': {
'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
'md5': '29b22e2961454d5413ddabcf34fc5622',
},
'devkit': {
'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
'md5': 'fa75699e90414af021442c21a62c3abf',
}
from .utils import check_integrity, extract_archive, verify_str_arg
ARCHIVE_META = {
'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
}
META_FILE = "meta.bin"
class ImageNet(ImageFolder):
"""`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
......@@ -29,9 +22,6 @@ class ImageNet(ImageFolder):
Args:
root (string): Root directory of the ImageNet Dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
......@@ -47,13 +37,22 @@ class ImageNet(ImageFolder):
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, split='train', download=False, **kwargs):
def __init__(self, root, split='train', download=None, **kwargs):
if download is True:
msg = ("The dataset is no longer publicly accessible. You need to "
"download the archives externally and place them in the root "
"directory.")
raise RuntimeError(msg)
elif download is False:
msg = ("The use of the download flag is deprecated, since the dataset "
"is no longer publicly accessible.")
warnings.warn(msg, RuntimeWarning)
root = self.root = os.path.expanduser(root)
self.split = verify_str_arg(split, "split", ("train", "val"))
if download:
self.download()
wnid_to_classes = self._load_meta_file()[0]
self.parse_archives()
wnid_to_classes = load_meta_file(self.root)[0]
super(ImageNet, self).__init__(self.split_folder, **kwargs)
self.root = root
......@@ -65,50 +64,15 @@ class ImageNet(ImageFolder):
for idx, clss in enumerate(self.classes)
for cls in clss}
def download(self):
if not check_integrity(self.meta_file):
tmp_dir = tempfile.mkdtemp()
archive_dict = ARCHIVE_DICT['devkit']
download_and_extract_archive(archive_dict['url'], self.root,
extract_root=tmp_dir,
md5=archive_dict['md5'])
devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
self._save_meta_file(*meta)
shutil.rmtree(tmp_dir)
def parse_archives(self):
if not check_integrity(os.path.join(self.root, META_FILE)):
parse_devkit_archive(self.root)
if not os.path.isdir(self.split_folder):
archive_dict = ARCHIVE_DICT[self.split]
download_and_extract_archive(archive_dict['url'], self.root,
extract_root=self.split_folder,
md5=archive_dict['md5'])
if self.split == 'train':
prepare_train_folder(self.split_folder)
parse_train_archive(self.root)
elif self.split == 'val':
val_wnids = self._load_meta_file()[1]
prepare_val_folder(self.split_folder, val_wnids)
else:
msg = ("You set download=True, but a folder '{}' already exist in "
"the root directory. If you want to re-download or re-extract the "
"archive, delete the folder.")
print(msg.format(self.split))
@property
def meta_file(self):
return os.path.join(self.root, 'meta.bin')
def _load_meta_file(self):
if check_integrity(self.meta_file):
return torch.load(self.meta_file)
else:
raise RuntimeError("Meta file not found or corrupted.",
"You can use download=True to create it.")
def _save_meta_file(self, wnid_to_class, val_wnids):
torch.save((wnid_to_class, val_wnids), self.meta_file)
parse_val_archive(self.root)
@property
def split_folder(self):
......@@ -118,54 +82,137 @@ class ImageNet(ImageFolder):
return "Split: {split}".format(**self.__dict__)
def parse_devkit(root):
idx_to_wnid, wnid_to_classes = parse_meta(root)
val_idcs = parse_val_groundtruth(root)
val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
return wnid_to_classes, val_wnids
def load_meta_file(root, file=None):
if file is None:
file = META_FILE
file = os.path.join(root, file)
if check_integrity(file):
return torch.load(file)
else:
msg = ("The meta file {} is not present in the root directory or is corrupted. "
"This file is automatically created by the ImageNet dataset.")
raise RuntimeError(msg.format(file, root))
def _verify_archive(root, file, md5):
if not check_integrity(os.path.join(root, file), md5):
msg = ("The archive {} is not present in the root directory or is corrupted. "
"You need to download it externally and place it in {}.")
raise RuntimeError(msg.format(file, root))
def parse_meta(devkit_root, path='data', filename='meta.mat'):
def parse_devkit_archive(root, file=None):
"""Parse the devkit archive of the ImageNet2012 classification dataset and save
the meta information in a binary file.
Args:
root (str): Root directory containing the devkit archive
file (str, optional): Name of devkit archive. Defaults to
'ILSVRC2012_devkit_t12.tar.gz'
"""
import scipy.io as sio
metafile = os.path.join(devkit_root, path, filename)
meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
nums_children = list(zip(*meta))[4]
meta = [meta[idx] for idx, num_children in enumerate(nums_children)
if num_children == 0]
idcs, wnids, classes = list(zip(*meta))[:3]
classes = [tuple(clss.split(', ')) for clss in classes]
idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
return idx_to_wnid, wnid_to_classes
def parse_meta_mat(devkit_root):
metafile = os.path.join(devkit_root, "data", "meta.mat")
meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
nums_children = list(zip(*meta))[4]
meta = [meta[idx] for idx, num_children in enumerate(nums_children)
if num_children == 0]
idcs, wnids, classes = list(zip(*meta))[:3]
classes = [tuple(clss.split(', ')) for clss in classes]
idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
return idx_to_wnid, wnid_to_classes
def parse_val_groundtruth_txt(devkit_root):
file = os.path.join(devkit_root, "data",
"ILSVRC2012_validation_ground_truth.txt")
with open(file, 'r') as txtfh:
val_idcs = txtfh.readlines()
return [int(val_idx) for val_idx in val_idcs]
@contextmanager
def get_tmp_dir():
tmp_dir = tempfile.mkdtemp()
try:
yield tmp_dir
finally:
shutil.rmtree(tmp_dir)
archive_meta = ARCHIVE_META["devkit"]
if file is None:
file = archive_meta[0]
md5 = archive_meta[1]
_verify_archive(root, file, md5)
with get_tmp_dir() as tmp_dir:
extract_archive(os.path.join(root, file), tmp_dir)
devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
val_idcs = parse_val_groundtruth_txt(devkit_root)
val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
def parse_train_archive(root, file=None, folder="train"):
"""Parse the train images archive of the ImageNet2012 classification dataset and
prepare it for usage with the ImageNet dataset.
Args:
root (str): Root directory containing the train images archive
file (str, optional): Name of train images archive. Defaults to
'ILSVRC2012_img_train.tar'
folder (str, optional): Optional name for train images folder. Defaults to
'train'
"""
archive_meta = ARCHIVE_META["train"]
if file is None:
file = archive_meta[0]
md5 = archive_meta[1]
def parse_val_groundtruth(devkit_root, path='data',
filename='ILSVRC2012_validation_ground_truth.txt'):
with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
val_idcs = txtfh.readlines()
return [int(val_idx) for val_idx in val_idcs]
_verify_archive(root, file, md5)
train_root = os.path.join(root, folder)
extract_archive(os.path.join(root, file), train_root)
def prepare_train_folder(folder):
for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
for archive in archives:
extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
def prepare_val_folder(folder, wnids):
img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
def parse_val_archive(root, file=None, wnids=None, folder="val"):
"""Parse the validation images archive of the ImageNet2012 classification dataset
and prepare it for usage with the ImageNet dataset.
for wnid in set(wnids):
os.mkdir(os.path.join(folder, wnid))
Args:
root (str): Root directory containing the validation images archive
file (str, optional): Name of validation images archive. Defaults to
'ILSVRC2012_img_val.tar'
wnids (list, optional): List of WordNet IDs of the validation images. If None
is given, the IDs are loaded from the meta file in the root directory
folder (str, optional): Optional name for validation images folder. Defaults to
'val'
"""
archive_meta = ARCHIVE_META["val"]
if file is None:
file = archive_meta[0]
md5 = archive_meta[1]
if wnids is None:
wnids = load_meta_file(root)[1]
_verify_archive(root, file, md5)
for wnid, img_file in zip(wnids, img_files):
shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
val_root = os.path.join(root, folder)
extract_archive(os.path.join(root, file), val_root)
images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
for wnid in set(wnids):
os.mkdir(os.path.join(val_root, wnid))
def _splitexts(root):
exts = []
ext = '.'
while ext:
root, ext = os.path.splitext(root)
exts.append(ext)
return root, ''.join(reversed(exts))
for wnid, img_file in zip(wnids, images):
shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment