imagenet.py 6.86 KB
Newer Older
Philip Meier's avatar
Philip Meier committed
1
2
3
from __future__ import print_function
import os
import shutil
4
import tempfile
Philip Meier's avatar
Philip Meier committed
5
6
import torch
from .folder import ImageFolder
7
from .utils import check_integrity, download_and_extract_archive, extract_archive
Philip Meier's avatar
Philip Meier committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

# Download location and MD5 checksum of every ILSVRC2012 archive used by this
# module, keyed by the name of the split (or 'devkit' for the development kit).
ARCHIVE_DICT = dict(
    train=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
        md5='1d675b47d978889d74fa0da5fadfb00e',
    ),
    val=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
        md5='29b22e2961454d5413ddabcf34fc5622',
    ),
    devkit=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
        md5='fa75699e90414af021442c21a62c3abf',
    ),
)


class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        root = self.root = os.path.expanduser(root)
        self.split = self._verify_split(split)

        if download:
            self.download()
        wnid_to_classes = self._load_meta_file()[0]

        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        # The parent was initialised with the split folder as its root;
        # presumably it overwrites ``self.root``, so point it back at the
        # dataset root that ``meta_file`` and ``split_folder`` are built from.
        self.root = root

        # The class folders are named after WordNet IDs, so what the parent
        # exposes as ``classes`` / ``class_to_idx`` are really the wnids and
        # their indices. Reuse the parent's mapping rather than zipping the
        # per-image indices of ``self.imgs`` against the per-class wnid list,
        # which pairs each wnid with the label of an unrelated image.
        self.wnids = self.classes
        self.wnid_to_idx = dict(self.class_to_idx)
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        # Every human readable name of a synset maps to that synset's index.
        self.class_to_idx = {cls: self.wnid_to_idx[wnid]
                             for wnid, clss in zip(self.wnids, self.classes)
                             for cls in clss}

    def download(self):
        """Fetch the devkit meta data and the archive of the selected split.

        The devkit is only downloaded when the cached meta file is missing,
        and the split archive only when the split folder does not exist yet.
        """
        if not check_integrity(self.meta_file):
            tmp_dir = tempfile.mkdtemp()
            try:
                archive_dict = ARCHIVE_DICT['devkit']
                download_and_extract_archive(archive_dict['url'], self.root,
                                             extract_root=tmp_dir,
                                             md5=archive_dict['md5'])
                # The devkit is looked up in a folder named like the archive
                # without its (possibly compound) extension.
                devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
                meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
                self._save_meta_file(*meta)
            finally:
                # Remove the scratch directory even if the download or the
                # parsing failed, so no temporary files are left behind.
                shutil.rmtree(tmp_dir)

        if not os.path.isdir(self.split_folder):
            archive_dict = ARCHIVE_DICT[self.split]
            download_and_extract_archive(archive_dict['url'], self.root,
                                         extract_root=self.split_folder,
                                         md5=archive_dict['md5'])

            if self.split == 'train':
                prepare_train_folder(self.split_folder)
            elif self.split == 'val':
                val_wnids = self._load_meta_file()[1]
                prepare_val_folder(self.split_folder, val_wnids)
        else:
            msg = ("You set download=True, but a folder '{}' already exist in "
                   "the root directory. If you want to re-download or re-extract the "
                   "archive, delete the folder.")
            print(msg.format(self.split))

    @property
    def meta_file(self):
        """Path of the cached meta data file inside the dataset root."""
        return os.path.join(self.root, 'meta.bin')

    def _load_meta_file(self):
        """Load the cached ``(wnid_to_class, val_wnids)`` tuple.

        Raises:
            RuntimeError: If the meta file does not pass the integrity check.
        """
        if check_integrity(self.meta_file):
            # NOTE: torch.load unpickles the file; only use meta files that
            # were written by ``_save_meta_file`` on a trusted machine.
            return torch.load(self.meta_file)
        else:
            raise RuntimeError("Meta file not found or corrupted. "
                               "You can use download=True to create it.")

    def _save_meta_file(self, wnid_to_class, val_wnids):
        """Store both meta data objects as a single tuple in ``meta_file``."""
        torch.save((wnid_to_class, val_wnids), self.meta_file)

    def _verify_split(self, split):
        """Return ``split`` unchanged if it is supported.

        Raises:
            ValueError: If ``split`` is not one of ``valid_splits``.
        """
        if split not in self.valid_splits:
            msg = "Unknown split {}. ".format(split)
            msg += "Valid splits are {}.".format(", ".join(self.valid_splits))
            raise ValueError(msg)
        return split

    @property
    def valid_splits(self):
        """Names of the supported dataset splits."""
        return 'train', 'val'

    @property
    def split_folder(self):
        """Directory inside the dataset root that holds the selected split."""
        return os.path.join(self.root, self.split)

    def extra_repr(self):
        """Additional line describing the selected split."""
        return "Split: {split}".format(**self.__dict__)
Philip Meier's avatar
Philip Meier committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162


def parse_devkit(root):
    """Collect the meta data of an extracted devkit.

    Args:
        root (string): Root directory of the extracted devkit.

    Returns:
        tuple: ``(wnid_to_classes, val_wnids)``, i.e. the mapping produced by
        :func:`parse_meta` and one WordNet ID per entry of the validation
        ground truth, in file order.
    """
    synset_idx_to_wnid, wnid_to_classes = parse_meta(root)
    ground_truth = parse_val_groundtruth(root)
    return wnid_to_classes, [synset_idx_to_wnid[i] for i in ground_truth]


def parse_meta(devkit_root, path='data', filename='meta.mat'):
    """Read the synset table from the MATLAB meta file of the devkit.

    Args:
        devkit_root (string): Root directory of the extracted devkit.
        path (string, optional): Sub-directory of ``devkit_root`` that holds
            the meta file.
        filename (string, optional): Name of the MATLAB file to load.

    Returns:
        tuple: ``(idx_to_wnid, wnid_to_classes)``. ``idx_to_wnid`` maps the
        first field of every kept record to its second field;
        ``wnid_to_classes`` maps the second field to the tuple obtained by
        splitting the third field on ``', '``.
    """
    # Imported inside the function so scipy is only needed when a devkit is
    # actually parsed.
    import scipy.io as sio

    metafile = os.path.join(devkit_root, path, filename)
    meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
    # Transpose the records into per-field columns; the column at position 4
    # is treated as the number of children of each synset.
    nums_children = list(zip(*meta))[4]
    # Keep only the records without children (presumably the leaf synsets
    # that make up the classification classes -- TODO confirm against the
    # devkit documentation).
    meta = [meta[idx] for idx, num_children in enumerate(nums_children)
            if num_children == 0]
    # Only the first three fields of the remaining records are used.
    idcs, wnids, classes = list(zip(*meta))[:3]
    # One synset can carry several comma separated names.
    classes = [tuple(clss.split(', ')) for clss in classes]
    idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
    wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
    return idx_to_wnid, wnid_to_classes


def parse_val_groundtruth(devkit_root, path='data',
                          filename='ILSVRC2012_validation_ground_truth.txt'):
    """Read the validation ground truth file of the devkit.

    Args:
        devkit_root (string): Root directory of the extracted devkit.
        path (string, optional): Sub-directory of ``devkit_root`` that holds
            the ground truth file.
        filename (string, optional): Name of the ground truth file.

    Returns:
        list: One integer per line of the file, in file order.
    """
    groundtruth_file = os.path.join(devkit_root, path, filename)
    with open(groundtruth_file, 'r') as txtfh:
        return [int(line) for line in txtfh]


def prepare_train_folder(folder):
    """Extract every archive found in ``folder`` into its own sub-folder.

    Each entry of ``folder`` is extracted into a directory named like the
    entry without its last extension; ``remove_finished=True`` is passed on
    to ``extract_archive``.

    Args:
        folder (string): Directory that contains the per-class archives.
    """
    for name in os.listdir(folder):
        archive = os.path.join(folder, name)
        extract_dir, _ = os.path.splitext(archive)
        extract_archive(archive, extract_dir, remove_finished=True)
Philip Meier's avatar
Philip Meier committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182


def prepare_val_folder(folder, wnids):
    """Sort the flat validation images into one sub-folder per WordNet ID.

    The entries of ``folder`` are put in sorted order and paired one to one
    with ``wnids``; every image is then moved into the sub-folder named after
    its WordNet ID.

    Args:
        folder (string): Directory that contains the validation images.
        wnids (list): WordNet ID of every image, in sorted file name order.
    """
    # Snapshot the image names before any class folder is created.
    filenames = sorted(os.listdir(folder))

    for wnid in set(wnids):
        os.mkdir(os.path.join(folder, wnid))

    for wnid, filename in zip(wnids, filenames):
        source = os.path.join(folder, filename)
        target = os.path.join(folder, wnid, filename)
        shutil.move(source, target)


def _splitexts(root):
    exts = []
    ext = '.'
    while ext:
        root, ext = os.path.splitext(root)
        exts.append(ext)
    return root, ''.join(reversed(exts))