import os
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import urljoin

from .folder import default_loader
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class Places365(VisionDataset):
    r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
            ``val``.
        small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
            high resolution ones.
        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
            downloaded archives are not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names, ordered by class index.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
        RuntimeError: If ``download is True`` and the image archive is already extracted.
    """

    _SPLITS = ("train-standard", "train-challenge", "val")
    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
    # {variant: (archive, md5)}
    _DEVKIT_META = {
        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
    }
    # (file, md5)
    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
    # {split: (file, md5)}
    _FILE_LIST_META = {
        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
    }
    # {(split, small): (file, md5)}
    _IMAGES_META = {
        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
    }

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train-standard",
        small: bool = False,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.split = self._verify_split(split)
        self.small = small
        self.loader = loader

        # The meta files are loaded first; both loaders fetch the devkit on demand when ``download`` is set.
        self.classes, self.class_to_idx = self.load_categories(download)
        self.imgs, self.targets = self.load_file_list(download)

        if download:
            self.download_images()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the sample at ``index``.

        Returns:
            tuple: ``(image, target)`` where ``image`` is whatever ``self.loader`` returns for the image path and
            ``target`` is the class index, both passed through ``self.transforms`` when it is not ``None``.
        """
        file, target = self.imgs[index]
        image = self.loader(file)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of entries in the file list of the selected split."""
        return len(self.imgs)

    @property
    def variant(self) -> str:
        """Dataset variant of the split: ``"challenge"`` if the split name contains it, else ``"standard"``."""
        return "challenge" if "challenge" in self.split else "standard"

    @property
    def images_dir(self) -> str:
        """Directory below ``root`` that holds the images of the selected split and size."""
        size = "256" if self.small else "large"
        if self.split.startswith("train"):
            dirname = f"data_{size}_{self.variant}"
        else:
            dirname = f"{self.split}_{size}"
        return path.join(self.root, dirname)

    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
        """Read the categories file located in ``root``.

        Each non-blank line is expected to hold a class name and its integer index separated by whitespace.

        Args:
            download (bool, optional): If ``True``, the devkit is downloaded when the categories file is missing
                or fails the MD5 check.

        Returns:
            tuple: ``(classes, class_to_idx)`` where ``classes`` is ordered by class index, i.e. ``classes[i]`` is
            the name of class ``i`` whenever the indices are ``0..N-1``.

        Raises:
            RuntimeError: If the file is missing or corrupted and ``download`` is ``False``.
        """
        file, md5 = self._CATEGORIES_META
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        class_to_idx: Dict[str, int] = {}
        with open(file) as fh:
            for line in fh:
                if not line.strip():
                    # Tolerate blank lines instead of failing on tuple unpacking.
                    continue
                cls, idx = line.split()
                class_to_idx[cls] = int(idx)

        # Order by index rather than by name so that the position in ``classes`` matches the class index
        # regardless of how the names happen to sort.
        classes = sorted(class_to_idx, key=class_to_idx.__getitem__)
        return classes, class_to_idx

    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
        """Read the file list of the selected split located in ``root``.

        Each non-blank line is expected to hold a ``/``-separated relative image path and its integer class index
        separated by whitespace.

        Args:
            download (bool, optional): If ``True``, the devkit is downloaded when the file list is missing or
                fails the MD5 check.

        Returns:
            tuple: ``(images, targets)`` where ``images`` is a list of ``(image path, class_index)`` tuples with the
            paths resolved against :attr:`images_dir`, and ``targets`` is the list of the class indices alone.
            Both lists are empty if the file list has no entries.

        Raises:
            RuntimeError: If the file is missing or corrupted and ``download`` is ``False``.
        """
        file, md5 = self._FILE_LIST_META[self.split]
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        # ``images_dir`` is a computed property; evaluate it once instead of once per line of the file list.
        images_dir = self.images_dir
        sep = "/"

        images: List[Tuple[str, int]] = []
        with open(file) as fh:
            for line in fh:
                if not line.strip():
                    continue
                image, idx = line.split()
                # Entries may start with a separator; strip it so the join below stays relative to
                # ``images_dir`` and convert the separators to the ones of the host OS.
                relative = image.lstrip(sep).replace(sep, os.sep)
                images.append((path.join(images_dir, relative), int(idx)))

        targets = [target for _, target in images]
        return images, targets

    def download_devkit(self) -> None:
        """Download the devkit archive of the current variant and extract it into ``root``."""
        file, md5 = self._DEVKIT_META[self.variant]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

    def download_images(self) -> None:
        """Download the image archive of the selected split and size and extract it into ``root``.

        Raises:
            RuntimeError: If :attr:`images_dir` already exists.
        """
        if path.exists(self.images_dir):
            raise RuntimeError(
                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
                f"delete the directory."
            )

        file, md5 = self._IMAGES_META[(self.split, self.small)]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

        if self.split.startswith("train"):
            # Move ``data_<size>`` to the variant-specific ``data_<size>_<variant>``. Presumably the train
            # archives of both variants extract into the same unsuffixed directory -- TODO confirm.
            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)

    def extra_repr(self) -> str:
        """Return the split and size flag, one per line, for the dataset's ``repr``."""
        return f"Split: {self.split}\nSmall: {self.small}"

    def _verify_split(self, split: str) -> str:
        """Validate ``split`` against :attr:`_SPLITS` via ``verify_str_arg`` and return the result."""
        return verify_str_arg(split, "split", self._SPLITS)

    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
        """Check ``file`` against ``md5``.

        Returns:
            bool: The result of ``check_integrity``.

        Raises:
            RuntimeError: If the check fails and ``download`` is ``False``, i.e. the file cannot be recovered.
        """
        integrity = check_integrity(file, md5=md5)
        if not integrity and not download:
            raise RuntimeError(
                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
            )
        return integrity