"test/srt/test_two_batch_overlap.py" did not exist on "46094e0c1b9c81a1f12f356472af694d9ef613cc"
imagenet.py 8.49 KB
Newer Older
Philip Meier's avatar
Philip Meier committed
1
2
import os
import shutil
3
import tempfile
4
from contextlib import contextmanager
5
6
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
7

Philip Meier's avatar
Philip Meier committed
8
import torch
9

Philip Meier's avatar
Philip Meier committed
10
from .folder import ImageFolder
11
12
13
from .utils import check_integrity, extract_archive, verify_str_arg

# File name and MD5 checksum of each archive of the ImageNet 2012 release, keyed by the
# role the archive plays: the "train" and "val" image splits and the "devkit" metadata.
ARCHIVE_META = dict(
    train=("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
    val=("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
    devkit=("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
)

# Name of the file, inside the dataset root, that caches the metadata parsed from the devkit.
META_FILE = "meta.bin"

Philip Meier's avatar
Philip Meier committed
21
22
23
24

class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    .. note::
        Before using this class, it is required to download ImageNet 2012 dataset from
        `here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
        place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
        or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.

    Args:
        root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    All keyword arguments other than ``root`` and ``split`` are forwarded unchanged to
    ``ImageFolder``.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
        dataset_root = os.path.expanduser(root)
        self.root = dataset_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Materialise the meta file and the split folder before ImageFolder scans the disk.
        self.parse_archives()
        wnid_to_names = load_meta_file(self.root)[0]

        super().__init__(self.split_folder, **kwargs)
        # ImageFolder was handed the split folder (and presumably stored it as ``root``);
        # point ``root`` back at the dataset root.
        self.root = dataset_root

        # ImageFolder indexed the per-wnid folders, so its classes are WordNet IDs. Keep
        # them under the wnid* names and expose human-readable names as classes.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_names[wnid] for wnid in self.wnids]

        name_to_idx = {}
        for idx, synonyms in enumerate(self.classes):
            # Every synonym maps to its class index; a name shared by several classes
            # ends up pointing at the last of them.
            for synonym in synonyms:
                name_to_idx[synonym] = idx
        self.class_to_idx = name_to_idx

    def parse_archives(self) -> None:
        """Prepare whatever is missing under ``self.root``.

        The devkit is parsed when ``check_integrity`` rejects the meta file. The archive
        of the selected split is extracted only when its folder does not exist yet.
        """
        meta_path = os.path.join(self.root, META_FILE)
        if not check_integrity(meta_path):
            parse_devkit_archive(self.root)

        if os.path.isdir(self.split_folder):
            return

        split_parsers = {"train": parse_train_archive, "val": parse_val_archive}
        parser = split_parsers.get(self.split)
        if parser is not None:
            parser(self.root)

    @property
    def split_folder(self) -> str:
        """Directory holding the images of the selected split, i.e. ``<root>/<split>``."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        return f"Split: {self.split}"
Philip Meier's avatar
Philip Meier committed
80
81


82
def load_meta_file(
    root: Union[str, Path], file: Optional[str] = None
) -> Tuple[Dict[str, Tuple[str, ...]], List[str]]:
    """Load the metadata cached by ``parse_devkit_archive``.

    Args:
        root (str or ``pathlib.Path``): Directory that contains the meta file.
        file (str, optional): Name of the meta file inside ``root``. Defaults to
            ``META_FILE``.

    Returns:
        The object stored in the meta file. ``parse_devkit_archive`` saves the tuple
        ``(wnid_to_classes, val_wnids)``. Here ``wnid_to_classes`` maps a WordNet ID to
        a tuple of class names, and ``val_wnids`` lists one WordNet ID per validation
        image.

    Raises:
        RuntimeError: If ``check_integrity`` rejects the file (e.g. it does not exist).
    """
    if file is None:
        file = META_FILE
    file = os.path.join(root, file)

    if not check_integrity(file):
        # ``file`` already includes ``root``, so it is the only value the message needs.
        msg = (
            "The meta file {} is not present in the root directory or is corrupted. "
            "This file is automatically created by the ImageNet dataset."
        )
        raise RuntimeError(msg.format(file))

    # weights_only=True restricts unpickling to tensors and plain Python containers
    # (see torch.load docs), so a tampered meta file cannot run arbitrary code.
    return torch.load(file, weights_only=True)

Philip Meier's avatar
Philip Meier committed
96

97
def _verify_archive(root: Union[str, Path], file: str, md5: str) -> None:
    """Raise ``RuntimeError`` unless ``check_integrity`` accepts ``root/file`` for ``md5``.

    Args:
        root: Directory expected to contain the archive.
        file: File name of the archive inside ``root``.
        md5: Expected MD5 checksum of the archive.
    """
    archive_path = os.path.join(root, file)
    if check_integrity(archive_path, md5):
        return

    raise RuntimeError(
        "The archive {} is not present in the root directory or is corrupted. "
        "You need to download it externally and place it in {}.".format(file, root)
    )
Philip Meier's avatar
Philip Meier committed
104

105

106
def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    The archive is extracted into a temporary directory that is removed afterwards.
    The tuple ``(wnid_to_classes, val_wnids)`` is written to ``META_FILE`` inside ``root``.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'

    Raises:
        RuntimeError: If ``_verify_archive`` rejects the archive.
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]:
        # Field 4 of each synset record is its number of children; only childless (leaf)
        # synsets are kept. Fields 0-2 are (index, wnid, comma-separated class names).
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
        idcs, wnids, classes = list(zip(*meta))[:3]
        classes = [tuple(clss.split(", ")) for clss in classes]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, classes))
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        # One integer class index per line. The order is preserved because
        # parse_val_archive pairs these labels positionally with the sorted image names.
        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(file) as txtfh:
            return [int(line) for line in txtfh]

    archive_meta = ARCHIVE_META["devkit"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    # TemporaryDirectory deletes the extracted devkit on exit, including when parsing raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        # NOTE(review): the extracted folder name is fixed even when a custom ``file`` is
        # given; confirm custom devkit archives use the same layout.
        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))


160
def parse_train_archive(root: Union[str, Path], file: Optional[str] = None, folder: str = "train") -> None:
    """Parse the train images archive of the ImageNet2012 classification dataset and
    prepare it for usage with the ImageNet dataset.

    The outer archive is extracted into ``root/folder``. Each regular file found there
    is then treated as a nested archive. It is extracted into a directory named after it
    without the extension, with ``remove_finished=True``.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder. Defaults to
            'train'

    Raises:
        RuntimeError: If ``_verify_archive`` rejects the archive.
    """
    archive_meta = ARCHIVE_META["train"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # Only regular files can be nested archives. Class directories left behind by an
    # interrupted earlier run are skipped rather than passed to ``extract_archive``,
    # which presumably cannot treat a directory as an archive. Sorting makes the
    # processing order deterministic.
    entries = (os.path.join(train_root, name) for name in sorted(os.listdir(train_root)))
    archives = [path for path in entries if os.path.isfile(path)]
    for archive in archives:
        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
Philip Meier's avatar
Philip Meier committed
184
185


186
def parse_val_archive(
    root: Union[str, Path], file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
) -> None:
    """Parse the validation images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    Images are assigned to classes positionally. The i-th image, in sorted path order,
    is moved into ``root/folder/<wnids[i]>``.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images. If None
            is given, the IDs are loaded from the meta file in the root directory
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'

    Raises:
        RuntimeError: If the meta file is needed but cannot be loaded, or if
            ``_verify_archive`` rejects the archive.
    """
    archive_meta = ARCHIVE_META["val"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # Only regular files are images. Directories already present in ``val_root`` (e.g.
    # wnid folders from an interrupted earlier run) would otherwise shift the positional
    # image <-> wnid pairing below.
    entries = (os.path.join(val_root, name) for name in os.listdir(val_root))
    images = sorted(path for path in entries if os.path.isfile(path))

    for wnid in set(wnids):
        # exist_ok lets an interrupted run be resumed instead of failing with FileExistsError.
        os.makedirs(os.path.join(val_root, wnid), exist_ok=True)

    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))