imagenet.py 6.79 KB
Newer Older
Philip Meier's avatar
Philip Meier committed
1
2
3
from __future__ import print_function
import os
import shutil
4
import tempfile
Philip Meier's avatar
Philip Meier committed
5
6
import torch
from .folder import ImageFolder
7
from .utils import check_integrity, download_and_extract_archive, extract_archive
# Download URL and MD5 checksum of each ILSVRC2012 archive, keyed by the role the
# archive plays: the two image splits ('train', 'val') and the development kit
# ('devkit') that holds the class metadata and validation labels.
ARCHIVE_DICT = dict(
    train=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
        md5='1d675b47d978889d74fa0da5fadfb00e',
    ),
    val=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
        md5='29b22e2961454d5413ddabcf34fc5622',
    ),
    devkit=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
        md5='fa75699e90414af021442c21a62c3abf',
    ),
)


class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        root = self.root = os.path.expanduser(root)
        self.split = self._verify_split(split)

        if download:
            self.download()
        # Index 0 of the meta file is the wnid -> class-name-tuple mapping.
        wnid_to_classes = self._load_meta_file()[0]

        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        # Restore the dataset root; the parent constructor was handed the split
        # folder and presumably stored that as ``self.root``.
        self.root = root

        # The parent's classes are the split folder's sub-directory names, which the
        # download/prepare steps create per WordNet ID. Keep them under wnid names ...
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        # ... and expose human-readable names instead: each class is a tuple of
        # synonyms, and every synonym maps to that class's index.
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        self.class_to_idx = {cls: idx
                             for idx, clss in enumerate(self.classes)
                             for cls in clss}

    def download(self):
        """Fetch the devkit metadata and the archive of the selected split.

        The devkit is only downloaded when ``meta.bin`` is missing or fails the
        integrity check; it is extracted into a temporary directory that is always
        removed afterwards. The split archive is only downloaded when the split
        folder does not exist yet; otherwise a notice is printed.
        """
        if not check_integrity(self.meta_file):
            tmp_dir = tempfile.mkdtemp()
            try:
                archive_dict = ARCHIVE_DICT['devkit']
                download_and_extract_archive(archive_dict['url'], self.root,
                                             extract_root=tmp_dir,
                                             md5=archive_dict['md5'])
                # The archive extracts into a folder named after the archive
                # without its (possibly multi-part, e.g. ``.tar.gz``) extension.
                devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
                meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
                self._save_meta_file(*meta)
            finally:
                # Clean up the scratch directory even if download/parsing failed.
                shutil.rmtree(tmp_dir)

        if not os.path.isdir(self.split_folder):
            archive_dict = ARCHIVE_DICT[self.split]
            download_and_extract_archive(archive_dict['url'], self.root,
                                         extract_root=self.split_folder,
                                         md5=archive_dict['md5'])

            if self.split == 'train':
                prepare_train_folder(self.split_folder)
            elif self.split == 'val':
                # Index 1 of the meta file lists the wnid of every val image.
                val_wnids = self._load_meta_file()[1]
                prepare_val_folder(self.split_folder, val_wnids)
        else:
            msg = ("You set download=True, but a folder '{}' already exists in "
                   "the root directory. If you want to re-download or re-extract the "
                   "archive, delete the folder.")
            print(msg.format(self.split))

    @property
    def meta_file(self):
        """Path of the cached metadata file ``<root>/meta.bin``."""
        return os.path.join(self.root, 'meta.bin')

    def _load_meta_file(self):
        """Return the ``(wnid_to_classes, val_wnids)`` tuple stored in the meta file.

        Raises:
            RuntimeError: if the meta file is missing or fails the integrity check.
        """
        if check_integrity(self.meta_file):
            # NOTE: torch.load unpickles the file; only the locally generated
            # meta file should ever be placed at this path.
            return torch.load(self.meta_file)
        raise RuntimeError("Meta file not found or corrupted. "
                           "You can use download=True to create it.")

    def _save_meta_file(self, wnid_to_class, val_wnids):
        """Persist the parsed devkit metadata to ``self.meta_file``."""
        torch.save((wnid_to_class, val_wnids), self.meta_file)

    def _verify_split(self, split):
        """Return ``split`` unchanged if it is supported, else raise ``ValueError``."""
        if split not in self.valid_splits:
            msg = "Unknown split {}. Valid splits are {{{}}}.".format(
                split, ", ".join(self.valid_splits))
            raise ValueError(msg)
        return split

    @property
    def valid_splits(self):
        """Names of the supported dataset splits."""
        return 'train', 'val'

    @property
    def split_folder(self):
        """Directory holding the images of the selected split: ``<root>/<split>``."""
        return os.path.join(self.root, self.split)

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)


def parse_devkit(root):
    """Parse the metadata and validation labels of an extracted devkit at ``root``.

    Returns:
        tuple: ``(wnid_to_classes, val_wnids)`` -- the WordNet-ID to class-name-tuple
        mapping from :func:`parse_meta`, and the WordNet ID of each entry of the
        validation ground-truth file, in file order.
    """
    idx_to_wnid, wnid_to_classes = parse_meta(root)
    ground_truth = parse_val_groundtruth(root)
    val_wnids = [idx_to_wnid[class_idx] for class_idx in ground_truth]
    return wnid_to_classes, val_wnids


def parse_meta(devkit_root, path='data', filename='meta.mat'):
    """Read the ``synsets`` table of the devkit's ``meta.mat``.

    Only synsets with zero children (record field 4) are kept. For each of them,
    fields 0-2 are read as (index, WordNet ID, comma-separated class names).

    Returns:
        tuple: ``(idx_to_wnid, wnid_to_classes)`` where ``wnid_to_classes`` maps each
        WordNet ID to the tuple obtained by splitting its names on ``', '``.
    """
    import scipy.io as sio

    mat_path = os.path.join(devkit_root, path, filename)
    synsets = sio.loadmat(mat_path, squeeze_me=True)['synsets']

    num_children_column = list(zip(*synsets))[4]
    leaves = [synset for synset, num_children in zip(synsets, num_children_column)
              if num_children == 0]

    leaf_idcs, leaf_wnids, leaf_names = list(zip(*leaves))[:3]
    idx_to_wnid = dict(zip(leaf_idcs, leaf_wnids))
    wnid_to_classes = {wnid: tuple(names.split(', '))
                       for wnid, names in zip(leaf_wnids, leaf_names)}
    return idx_to_wnid, wnid_to_classes


def parse_val_groundtruth(devkit_root, path='data',
                          filename='ILSVRC2012_validation_ground_truth.txt'):
    """Read the validation ground-truth class indices from the devkit.

    The file is expected to hold one integer per line; the integers are returned
    in file order. Blank lines (e.g. an extra trailing newline) are skipped rather
    than crashing ``int()``.

    Args:
        devkit_root (str): root directory of the extracted devkit.
        path (str): sub-directory of ``devkit_root`` containing the file.
        filename (str): name of the ground-truth file.

    Returns:
        list of int: the class index of each validation entry.
    """
    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
        # int() tolerates the surrounding whitespace/newline of each line.
        return [int(line) for line in txtfh if line.strip()]


def prepare_train_folder(folder):
    """Extract every archive found directly inside ``folder``.

    ``<folder>/<name><ext>`` is extracted into ``<folder>/<name>``, and
    ``remove_finished=True`` is passed so the archive is presumably deleted once
    extracted. The directory listing is taken once, before any extraction.
    """
    archives = [os.path.join(folder, entry) for entry in os.listdir(folder)]
    for archive_path in archives:
        target_dir, _ = os.path.splitext(archive_path)
        extract_archive(archive_path, target_dir, remove_finished=True)


def prepare_val_folder(folder, wnids):
    """Move the flat validation images in ``folder`` into per-WordNet-ID sub-folders.

    Args:
        folder (str): directory containing the extracted validation images.
        wnids (list): WordNet ID for each image, aligned with the folder's entries
            sorted by path. One sub-folder is created per distinct ID.
    """
    images = sorted(os.path.join(folder, name) for name in os.listdir(folder))

    for class_dir in {os.path.join(folder, wnid) for wnid in wnids}:
        os.mkdir(class_dir)

    for image_path, wnid in zip(images, wnids):
        destination = os.path.join(folder, wnid, os.path.basename(image_path))
        shutil.move(image_path, destination)


def _splitexts(root):
    """Strip every extension from ``root`` (e.g. ``a.tar.gz`` -> ``('a', '.tar.gz')``).

    Returns:
        tuple: ``(stem, extensions)`` with the extensions concatenated in their
        original left-to-right order.
    """
    suffixes = []
    while True:
        root, suffix = os.path.splitext(root)
        suffixes.insert(0, suffix)
        if not suffix:
            break
    return root, ''.join(suffixes)