Commit f46f2c15 authored by Philip Meier, committed by Francisco Massa

Remove download for ImageNet (#1457)

* remove download process

* address comments

* fix logic error

* bug fixes

* removed unused import

* add docstrings

* flake8

* remove download BC

* fix test

* removed unused code

* flake8

* add MD5 verification before extraction

* add mock to test

* unify _verify_archive() method and function
* remove force flag for parse_*_archive functions
* cleanup

* flake8
parent 371f6c8f
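
The upshot for users: the archives must now be obtained externally and placed in the dataset root before first use; the dataset then verifies their MD5 checksums and extracts them itself. A rough usage sketch of the new workflow (the directory path is illustrative):

    import torchvision

    # ILSVRC2012_devkit_t12.tar.gz plus ILSVRC2012_img_train.tar and/or
    # ILSVRC2012_img_val.tar must already sit in the root directory; the
    # dataset verifies and extracts them on first use instead of downloading.
    dataset = torchvision.datasets.ImageNet("~/data/imagenet", split="train")

    # download=True now raises a RuntimeError; download=False still works
    # but emits a RuntimeWarning about the deprecated flag.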
test/test_datasets.py
@@ -108,14 +108,14 @@ class Tester(unittest.TestCase):
         img, target = dataset[0]
         self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
 
-    @mock.patch('torchvision.datasets.utils.download_url')
+    @mock.patch('torchvision.datasets.imagenet._verify_archive')
     @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
-    def test_imagenet(self, mock_download):
+    def test_imagenet(self, mock_verify):
         with imagenet_root() as root:
-            dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
+            dataset = torchvision.datasets.ImageNet(root, split='train')
             self.generic_classification_dataset_test(dataset)
-            dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
+            dataset = torchvision.datasets.ImageNet(root, split='val')
             self.generic_classification_dataset_test(dataset)
 
     @mock.patch('torchvision.datasets.cifar.check_integrity')
...
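
The test now stubs out the archive verification rather than the download helper: used as a decorator, mock.patch replaces the named attribute for the duration of the test and passes the mock in as an extra argument, so the MD5 check cannot fail on the synthetic archives the test fixture creates. A minimal self-contained sketch of the pattern (the test class and assertions here are illustrative, not from the commit):

    import unittest
    from unittest import mock

    import torchvision.datasets.imagenet as imagenet


    class VerifySketch(unittest.TestCase):
        # While the test body runs, imagenet._verify_archive is a MagicMock,
        # so no real archive or checksum is needed.
        @mock.patch('torchvision.datasets.imagenet._verify_archive')
        def test_verify_is_stubbed(self, mock_verify):
            imagenet._verify_archive("root", "file.tar", "md5")
            mock_verify.assert_called_once_with("root", "file.tar", "md5")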
torchvision/datasets/imagenet.py
-from __future__ import print_function
+import warnings
+from contextlib import contextmanager
 import os
 import shutil
 import tempfile
 import torch
 
 from .folder import ImageFolder
-from .utils import check_integrity, download_and_extract_archive, extract_archive, \
-    verify_str_arg
+from .utils import check_integrity, extract_archive, verify_str_arg
 
-ARCHIVE_DICT = {
-    'train': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
-        'md5': '1d675b47d978889d74fa0da5fadfb00e',
-    },
-    'val': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
-        'md5': '29b22e2961454d5413ddabcf34fc5622',
-    },
-    'devkit': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
-        'md5': 'fa75699e90414af021442c21a62c3abf',
-    }
+ARCHIVE_META = {
+    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
+    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
+    'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
 }
+
+META_FILE = "meta.bin"
 
 
 class ImageNet(ImageFolder):
     """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
@@ -29,9 +22,6 @@ class ImageNet(ImageFolder):
     Args:
         root (string): Root directory of the ImageNet Dataset.
         split (string, optional): The dataset split, supports ``train``, or ``val``.
-        download (bool, optional): If true, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again.
         transform (callable, optional): A function/transform that takes in an PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
         target_transform (callable, optional): A function/transform that takes in the
@@ -47,13 +37,22 @@ class ImageNet(ImageFolder):
         targets (list): The class_index value for each image in the dataset
     """
 
-    def __init__(self, root, split='train', download=False, **kwargs):
+    def __init__(self, root, split='train', download=None, **kwargs):
+        if download is True:
+            msg = ("The dataset is no longer publicly accessible. You need to "
+                   "download the archives externally and place them in the root "
+                   "directory.")
+            raise RuntimeError(msg)
+        elif download is False:
+            msg = ("The use of the download flag is deprecated, since the dataset "
+                   "is no longer publicly accessible.")
+            warnings.warn(msg, RuntimeWarning)
+
         root = self.root = os.path.expanduser(root)
         self.split = verify_str_arg(split, "split", ("train", "val"))
 
-        if download:
-            self.download()
-        wnid_to_classes = self._load_meta_file()[0]
+        self.parse_archives()
+        wnid_to_classes = load_meta_file(self.root)[0]
 
         super(ImageNet, self).__init__(self.split_folder, **kwargs)
         self.root = root
@@ -65,50 +64,15 @@ class ImageNet(ImageFolder):
                              for idx, clss in enumerate(self.classes)
                              for cls in clss}
 
-    def download(self):
-        if not check_integrity(self.meta_file):
-            tmp_dir = tempfile.mkdtemp()
-
-            archive_dict = ARCHIVE_DICT['devkit']
-            download_and_extract_archive(archive_dict['url'], self.root,
-                                         extract_root=tmp_dir,
-                                         md5=archive_dict['md5'])
-            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
-            meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
-            self._save_meta_file(*meta)
-
-            shutil.rmtree(tmp_dir)
-
-        if not os.path.isdir(self.split_folder):
-            archive_dict = ARCHIVE_DICT[self.split]
-            download_and_extract_archive(archive_dict['url'], self.root,
-                                         extract_root=self.split_folder,
-                                         md5=archive_dict['md5'])
-
-            if self.split == 'train':
-                prepare_train_folder(self.split_folder)
-            elif self.split == 'val':
-                val_wnids = self._load_meta_file()[1]
-                prepare_val_folder(self.split_folder, val_wnids)
-        else:
-            msg = ("You set download=True, but a folder '{}' already exist in "
-                   "the root directory. If you want to re-download or re-extract the "
-                   "archive, delete the folder.")
-            print(msg.format(self.split))
-
-    @property
-    def meta_file(self):
-        return os.path.join(self.root, 'meta.bin')
-
-    def _load_meta_file(self):
-        if check_integrity(self.meta_file):
-            return torch.load(self.meta_file)
-        else:
-            raise RuntimeError("Meta file not found or corrupted.",
-                               "You can use download=True to create it.")
-
-    def _save_meta_file(self, wnid_to_class, val_wnids):
-        torch.save((wnid_to_class, val_wnids), self.meta_file)
+    def parse_archives(self):
+        if not check_integrity(os.path.join(self.root, META_FILE)):
+            parse_devkit_archive(self.root)
+
+        if not os.path.isdir(self.split_folder):
+            if self.split == 'train':
+                parse_train_archive(self.root)
+            elif self.split == 'val':
+                parse_val_archive(self.root)
 
     @property
     def split_folder(self):
@@ -118,54 +82,137 @@ class ImageNet(ImageFolder):
         return "Split: {split}".format(**self.__dict__)
 
 
-def parse_devkit(root):
-    idx_to_wnid, wnid_to_classes = parse_meta(root)
-    val_idcs = parse_val_groundtruth(root)
-    val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
-    return wnid_to_classes, val_wnids
-
-
-def parse_meta(devkit_root, path='data', filename='meta.mat'):
-    import scipy.io as sio
-
-    metafile = os.path.join(devkit_root, path, filename)
-    meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
-    nums_children = list(zip(*meta))[4]
-    meta = [meta[idx] for idx, num_children in enumerate(nums_children)
-            if num_children == 0]
-    idcs, wnids, classes = list(zip(*meta))[:3]
-    classes = [tuple(clss.split(', ')) for clss in classes]
-    idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
-    wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
-    return idx_to_wnid, wnid_to_classes
-
-
-def parse_val_groundtruth(devkit_root, path='data',
-                          filename='ILSVRC2012_validation_ground_truth.txt'):
-    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
-        val_idcs = txtfh.readlines()
-    return [int(val_idx) for val_idx in val_idcs]
-
-
-def prepare_train_folder(folder):
-    for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
-        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
-
-
-def prepare_val_folder(folder, wnids):
-    img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
-
-    for wnid in set(wnids):
-        os.mkdir(os.path.join(folder, wnid))
-
-    for wnid, img_file in zip(wnids, img_files):
-        shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
-
-
-def _splitexts(root):
-    exts = []
-    ext = '.'
-    while ext:
-        root, ext = os.path.splitext(root)
-        exts.append(ext)
-    return root, ''.join(reversed(exts))
+def load_meta_file(root, file=None):
+    if file is None:
+        file = META_FILE
+    file = os.path.join(root, file)
+
+    if check_integrity(file):
+        return torch.load(file)
+    else:
+        msg = ("The meta file {} is not present in the root directory or is corrupted. "
+               "This file is automatically created by the ImageNet dataset.")
+        raise RuntimeError(msg.format(file, root))
+
+
+def _verify_archive(root, file, md5):
+    if not check_integrity(os.path.join(root, file), md5):
+        msg = ("The archive {} is not present in the root directory or is corrupted. "
+               "You need to download it externally and place it in {}.")
+        raise RuntimeError(msg.format(file, root))
+
+
+def parse_devkit_archive(root, file=None):
+    """Parse the devkit archive of the ImageNet2012 classification dataset and save
+    the meta information in a binary file.
+
+    Args:
+        root (str): Root directory containing the devkit archive
+        file (str, optional): Name of devkit archive. Defaults to
+            'ILSVRC2012_devkit_t12.tar.gz'
+    """
+    import scipy.io as sio
+
+    def parse_meta_mat(devkit_root):
+        metafile = os.path.join(devkit_root, "data", "meta.mat")
+        meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
+        nums_children = list(zip(*meta))[4]
+        meta = [meta[idx] for idx, num_children in enumerate(nums_children)
+                if num_children == 0]
+        idcs, wnids, classes = list(zip(*meta))[:3]
+        classes = [tuple(clss.split(', ')) for clss in classes]
+        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+        return idx_to_wnid, wnid_to_classes
+
+    def parse_val_groundtruth_txt(devkit_root):
+        file = os.path.join(devkit_root, "data",
+                            "ILSVRC2012_validation_ground_truth.txt")
+        with open(file, 'r') as txtfh:
+            val_idcs = txtfh.readlines()
+        return [int(val_idx) for val_idx in val_idcs]
+
+    @contextmanager
+    def get_tmp_dir():
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            yield tmp_dir
+        finally:
+            shutil.rmtree(tmp_dir)
+
+    archive_meta = ARCHIVE_META["devkit"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    with get_tmp_dir() as tmp_dir:
+        extract_archive(os.path.join(root, file), tmp_dir)
+
+        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
+        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
+        val_idcs = parse_val_groundtruth_txt(devkit_root)
+        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+
+        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
+
+
+def parse_train_archive(root, file=None, folder="train"):
+    """Parse the train images archive of the ImageNet2012 classification dataset and
+    prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the train images archive
+        file (str, optional): Name of train images archive. Defaults to
+            'ILSVRC2012_img_train.tar'
+        folder (str, optional): Optional name for train images folder. Defaults to
+            'train'
+    """
+    archive_meta = ARCHIVE_META["train"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    train_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), train_root)
+
+    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+    for archive in archives:
+        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
+
+
+def parse_val_archive(root, file=None, wnids=None, folder="val"):
+    """Parse the validation images archive of the ImageNet2012 classification dataset
+    and prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the validation images archive
+        file (str, optional): Name of validation images archive. Defaults to
+            'ILSVRC2012_img_val.tar'
+        wnids (list, optional): List of WordNet IDs of the validation images. If None
+            is given, the IDs are loaded from the meta file in the root directory
+        folder (str, optional): Optional name for validation images folder. Defaults to
+            'val'
+    """
+    archive_meta = ARCHIVE_META["val"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+    if wnids is None:
+        wnids = load_meta_file(root)[1]
+
+    _verify_archive(root, file, md5)
+
+    val_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), val_root)
+
+    images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(val_root, wnid))
+
+    for wnid, img_file in zip(wnids, images):
+        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
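
Because the parse_* helpers are module-level functions, the archives can also be prepared ahead of time without instantiating the dataset. A sketch, assuming all three archives already sit in the (illustrative) root below:

    import os
    from torchvision.datasets.imagenet import (
        parse_devkit_archive, parse_train_archive, parse_val_archive)

    root = os.path.expanduser("~/data/imagenet")

    # Creates meta.bin in the root from the devkit archive (WordNet IDs,
    # class names, and validation ground truth).
    parse_devkit_archive(root)

    # Extracts the image archives into train/ and val/ subfolders; the
    # validation images are sorted into one folder per WordNet ID using
    # the WNIDs stored in meta.bin, so the devkit must be parsed first.
    parse_train_archive(root)
    parse_val_archive(root)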