"""Places365 classification dataset."""
import os
from os import path
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import urljoin

from .folder import default_loader
from .utils import verify_str_arg, check_integrity, download_and_extract_archive
from .vision import VisionDataset


class Places365(VisionDataset):
    r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.

    Args:
        root (string): Root directory of the Places365 dataset.
        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
            ``val``.
        small (bool, optional): If ``True``, uses the small images, i. e. resized to 256 x 256 pixels, instead of the
            high resolution ones.
        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
            downloaded archives are not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: If ``download is False`` and the meta files, i. e. the devkit, are not present or corrupted.
        RuntimeError: If ``download is True`` and the image archive is already extracted.
    """

    _SPLITS = ("train-standard", "train-challenge", "val")
    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
    # {variant: (archive, md5)}
    _DEVKIT_META = {
        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
    }
    # (file, md5)
    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
    # {split: (file, md5)}
    _FILE_LIST_META = {
        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
    }
    # {(split, small): (file, md5)}
    _IMAGES_META = {
        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train-standard",
        small: bool = False,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.split = self._verify_split(split)
        self.small = small
        self.loader = loader

        # The metadata (categories + per-split file list) is loaded first; the devkit is
        # downloaded on demand by these calls when ``download`` is set.
        self.classes, self.class_to_idx = self.load_categories(download)
        self.imgs, self.targets = self.load_file_list(download)

        if download:
            self.download_images()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the sample at ``index`` and return ``(image, target)`` after the configured transforms."""
        file, target = self.imgs[index]
        image = self.loader(file)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.imgs)

    @property
    def variant(self) -> str:
        """Dataset variant of the current split: ``"challenge"`` or ``"standard"`` (``val`` maps to ``"standard"``)."""
        return "challenge" if "challenge" in self.split else "standard"

    @property
    def images_dir(self) -> str:
        """Directory under ``root`` holding the extracted images for the current split and image size."""
        size = "256" if self.small else "large"
        if self.split.startswith("train"):
            dir_name = f"data_{size}_{self.variant}"
        else:
            dir_name = f"{self.split}_{size}"
        return path.join(self.root, dir_name)

    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
        """Read the category file, downloading the devkit first if it is missing/corrupted and ``download`` is set.

        Args:
            download: Whether a missing or corrupted category file may be fetched via the devkit.

        Returns:
            The alphabetically sorted class names and a mapping from class name to class index.

        Raises:
            RuntimeError: If the file is missing or corrupted and ``download`` is ``False``.
        """
        file, md5 = self._CATEGORIES_META
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        class_to_idx: Dict[str, int] = {}
        with open(file, encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    # Tolerate blank lines instead of failing on the two-field unpacking below.
                    continue
                class_name, idx = line.split()
                class_to_idx[class_name] = int(idx)

        return sorted(class_to_idx.keys()), class_to_idx

    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
        """Read the file list of the current split, downloading the devkit first if needed and allowed.

        Args:
            download: Whether a missing or corrupted file list may be fetched via the devkit.

        Returns:
            A list of ``(image_path, class_index)`` tuples, with paths resolved under :attr:`images_dir`, and the
            class indices alone in the same order.

        Raises:
            RuntimeError: If the file is missing or corrupted and ``download`` is ``False``.
        """
        file, md5 = self._FILE_LIST_META[self.split]
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        # ``images_dir`` is a computed property; evaluate it once rather than once per line.
        images_dir = self.images_dir

        def process(line: str, sep: str = "/") -> Tuple[str, int]:
            image, idx = line.split()
            # Entries use "/" as separator and may start with one: make them relative and OS-native.
            return path.join(images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)

        with open(file, encoding="utf-8") as fh:
            images = [process(line) for line in fh if line.strip()]

        # Unlike ``zip(*images)``, this also works for an empty list.
        targets = [target for _, target in images]
        return images, targets

    def download_devkit(self) -> None:
        """Download and extract the devkit archive of the current variant into ``root``."""
        file, md5 = self._DEVKIT_META[self.variant]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

    def download_images(self) -> None:
        """Download and extract the image archive for the current split and size into ``root``.

        Raises:
            RuntimeError: If :attr:`images_dir` already exists.
        """
        if path.exists(self.images_dir):
            raise RuntimeError(
                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
                f"delete the directory."
            )

        file, md5 = self._IMAGES_META[(self.split, self.small)]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

        if self.split.startswith("train"):
            # Assumes the train archives extract to ``data_<size>``; rename to the variant-specific
            # directory so that :attr:`images_dir` points at it.
            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)

    def extra_repr(self) -> str:
        """Extra lines shown in the dataset's repr: the split and whether small images are used."""
        return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)

    def _verify_split(self, split: str) -> str:
        """Validate ``split`` against :attr:`_SPLITS` and return it."""
        return verify_str_arg(split, "split", self._SPLITS)

    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
        """Return whether ``file`` exists and matches ``md5``.

        Raises:
            RuntimeError: If the check fails and ``download`` is ``False``.
        """
        integrity = check_integrity(file, md5=md5)
        if not integrity and not download:
            raise RuntimeError(
                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
            )
        return integrity