caltech.py 8.65 KB
Newer Older
1
2
import os
import os.path
3
from typing import Any, Callable, List, Optional, Tuple, Union
4

5
6
from PIL import Image

7
from .utils import download_and_extract_archive, verify_str_arg
8
from .vision import VisionDataset
9
10


11
class Caltech101(VisionDataset):
    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load ``annotation`` targets
        from `.mat` format. It is imported lazily, so it is only required when ``annotation``
        is one of the requested target types.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types.  ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

            .. warning::

                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        # "BACKGROUND_Google" is not a real class. Only drop it when present, so a copy
        # without that folder does not fail with an opaque ValueError from list.remove.
        if "BACKGROUND_Google" in self.categories:
            self.categories.remove("BACKGROUND_Google")

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        # Parallel to self.categories: the folder name to use under "Annotations".
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # For sample k: self.index[k] is its 1-based file number within its category
        # folder, and self.y[k] is its class index into self.categories.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            category_dir = os.path.join(self.root, "101_ObjectCategories", c)
            # Count only ``.jpg`` files (consistent with Caltech256) so stray non-image
            # entries do not produce indices that point at files that do not exist.
            n = sum(1 for fname in os.listdir(category_dir) if fname.endswith(".jpg"))
            # Files are assumed to be numbered contiguously image_0001.jpg .. image_{n:04d}.jpg,
            # which is the name __getitem__ rebuilds from self.index.
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target is specified by target_type.
            With a single target type the target is returned bare; with several it is a
            tuple in the order given by target_type.
        """
        category_idx = self.y[index]
        file_number = self.index[index]

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[category_idx],
                f"image_{file_number:04d}.jpg",
            )
        )
        # Image.open is lazy and keeps the file open. Loading the pixels now lets Pillow
        # close the file (single-frame images opened from a path) instead of leaking the
        # handle until the image object is garbage collected.
        img.load()

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(category_idx)
            elif t == "annotation":
                # scipy is only needed for annotation targets; import it lazily so
                # category-only usage does not require scipy to be installed.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[category_idx],
                        f"annotation_{file_number:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the images and annotations into ``self.root`` unless already present."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            self.root,
            filename="Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        return f"Target type: {self.target_type}"
150
151


152
class Caltech256(VisionDataset):
    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
        # For sample k: self.index[k] is its 1-based file number within its category
        # folder, and self.y[k] is its class index into self.categories.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            category_dir = os.path.join(self.root, "256_ObjectCategories", c)
            # Only ``.jpg`` files are counted as samples; other entries are skipped.
            n = sum(1 for item in os.listdir(category_dir) if item.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        label = self.y[index]
        # File names are rebuilt as "<label + 1:03d>_<file number:04d>.jpg". This assumes the
        # sorted folder order matches the folders' numeric prefixes. NOTE(review): confirm.
        img = Image.open(
            os.path.join(
                self.root,
                "256_ObjectCategories",
                self.categories[label],
                f"{label + 1:03d}_{self.index[index]:04d}.jpg",
            )
        )
        # Image.open is lazy and keeps the file open. Loading the pixels now lets Pillow
        # close the file (single-frame images opened from a path) instead of leaking the
        # handle until the image object is garbage collected.
        img.load()

        target = label

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive into ``self.root`` unless already present."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )