caltech.py 8.72 KB
Newer Older
1
2
import os
import os.path
3
from pathlib import Path
4
from typing import Any, Callable, List, Optional, Tuple, Union
5

6
7
from PIL import Image

8
from .utils import download_and_extract_archive, verify_str_arg
9
from .vision import VisionDataset
10
11


12
class Caltech101(VisionDataset):
    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types.  ``category`` represents the target class index, and
            ``annotation`` is the ``obj_contour`` entry of the matching annotation
            ``.mat`` file (the points of a hand-generated outline).
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

            .. warning::

                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    def __init__(
        self,
        root: Union[str, Path],
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        images_dir = os.path.join(self.root, "101_ObjectCategories")
        self.categories = sorted(os.listdir(images_dir))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # Sample i is file number self.index[i] (1-based) inside category self.y[i].
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Count only the images, as Caltech256 does. A stray non-image file in the
            # directory would otherwise yield an index for a nonexistent image_XXXX.jpg.
            n = sum(1 for f in os.listdir(os.path.join(images_dir, c)) if f.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target has the type(s) selected by ``target_type``;
            it is a tuple when more than one target type was requested.
        """
        import scipy.io

        label = self.y[index]
        file_no = self.index[index]

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[label],
                f"image_{file_no:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(label)
            elif t == "annotation":
                # Annotation folders may be named differently, hence annotation_categories.
                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[label],
                        f"annotation_{file_no:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        # A single requested target type is returned unwrapped.
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image and annotation archives into ``self.root``."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            self.root,
            filename="Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        return f"Target type: {self.target_type}"


153
class Caltech256(VisionDataset):
    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

            .. warning::

                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        images_dir = os.path.join(self.root, "256_ObjectCategories")
        self.categories = sorted(os.listdir(images_dir))
        # Sample i is file number self.index[i] (1-based) inside category self.y[i].
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Only .jpg entries are images; anything else in the directory is ignored.
            n = sum(1 for item in os.listdir(os.path.join(images_dir, c)) if item.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        label = self.y[index]
        # File names are "<category number:03d>_<image number:04d>.jpg". This relies on the
        # sorted category order matching each directory's 1-based numeric prefix.
        img = Image.open(
            os.path.join(
                self.root,
                "256_ObjectCategories",
                self.categories[label],
                f"{label + 1:03d}_{self.index[index]:04d}.jpg",
            )
        )

        target = label

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive into ``self.root``."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )