caltech.py 8.72 KB
Newer Older
1
2
import os
import os.path
limm's avatar
limm committed
3
4
5
6
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image
7

8
from .utils import download_and_extract_archive, verify_str_arg
limm's avatar
limm committed
9
from .vision import VisionDataset
10
11


12
class Caltech101(VisionDataset):
    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
        It is only imported when an ``annotation`` target is actually requested.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types.  ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

            .. warning::

                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    def __init__(
        self,
        root: Union[str, Path],
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        images_dir = os.path.join(self.root, "101_ObjectCategories")
        # Keep only real category directories: "BACKGROUND_Google" is not a real
        # class, and stray files (e.g. ".DS_Store") must not be treated as
        # categories because they cannot be listed below.
        self.categories = sorted(
            entry
            for entry in os.listdir(images_dir)
            if entry != "BACKGROUND_Google" and os.path.isdir(os.path.join(images_dir, entry))
        )

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # ``index`` holds the 1-based image number inside its category and ``y``
        # the category position in ``self.categories``; both are sample-aligned.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Count only ".jpg" files (same rule as Caltech256) so that any other
            # entry in the folder does not produce indices without an image.
            n = len([item for item in os.listdir(os.path.join(images_dir, c)) if item.endswith(".jpg")])
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """
        category = self.y[index]
        number = self.index[index]

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[category],
                f"image_{number:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(category)
            elif t == "annotation":
                # Imported lazily so that category-only usage does not require scipy.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[category],
                        f"annotation_{number:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        # A single requested target is returned bare, several as a tuple.
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the extracted ``101_ObjectCategories`` folder exists under ``self.root``."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image and annotation archives into ``self.root``.

        Does nothing (besides printing a message) when the images are already present.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            self.root,
            filename="Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        """Return the extra line describing the configured target type(s)."""
        return f"Target type: {self.target_type}"
151
152


153
class Caltech256(VisionDataset):
    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        images_dir = os.path.join(self.root, "256_ObjectCategories")
        # Only directories are categories. A stray file (e.g. ".DS_Store") cannot
        # be listed below and would also shift the class numbers that
        # ``__getitem__`` uses to build image file names.
        self.categories = sorted(
            entry for entry in os.listdir(images_dir) if os.path.isdir(os.path.join(images_dir, entry))
        )

        # ``index`` holds the 1-based image number inside its category and ``y``
        # the category position in ``self.categories``; both are sample-aligned.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Only ".jpg" entries are counted as images of the category.
            n = len([item for item in os.listdir(os.path.join(images_dir, c)) if item.endswith(".jpg")])
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        target = self.y[index]

        # File names are "<class number, 1-based, 3 digits>_<image number, 4 digits>.jpg".
        img = Image.open(
            os.path.join(
                self.root,
                "256_ObjectCategories",
                self.categories[target],
                f"{target + 1:03d}_{self.index[index]:04d}.jpg",
            )
        )

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the extracted ``256_ObjectCategories`` folder exists under ``self.root``."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive into ``self.root``.

        Does nothing (besides printing a message) when the images are already present.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )