caltech.py 8.57 KB
Newer Older
1
2
import os
import os.path
Philip Meier's avatar
Philip Meier committed
3
from typing import Any, Callable, List, Optional, Union, Tuple
4

5
6
from PIL import Image

7
from .utils import download_and_extract_archive, verify_str_arg
8
from .vision import VisionDataset
9
10


11
class Caltech101(VisionDataset):
    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
        It is imported only when an ``annotation`` target is actually loaded.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types.  ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``. If an empty list is given, the returned target
            is ``None`` and ``target_transform`` is not applied.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``101_ObjectCategories`` is not present under the dataset
            root after the optional download step.
    """

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        objects_dir = os.path.join(self.root, "101_ObjectCategories")
        self.categories = sorted(os.listdir(objects_dir))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # ``index`` holds the 1-based image number inside its category and ``y`` the
        # category index; both lists are aligned and have one entry per sample.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Count only the images (same filter as Caltech256) so that a stray
            # non-image entry does not produce an index without a matching file.
            n = sum(1 for item in os.listdir(os.path.join(objects_dir, c)) if item.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
            With several target types the target is a tuple in the requested order;
            with none it is ``None``.
        """
        category = self.y[index]
        number = self.index[index]

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[category],
                f"image_{number:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(category)
            elif t == "annotation":
                # Imported lazily so that category-only usage does not require scipy.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[category],
                        f"annotation_{number:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])

        if self.transform is not None:
            img = self.transform(img)

        if not target:
            # No target type was requested: there is nothing to unwrap or transform.
            return img, None

        target = tuple(target) if len(target) > 1 else target[0]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True if the ``101_ObjectCategories`` path exists under the dataset root."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.index)

    def download(self) -> None:
        """Fetch and extract the image and annotation archives into the dataset root.

        Does nothing except print a message when ``_check_integrity`` already passes.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            self.root,
            filename="Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        """Return the extra line shown in the dataset's repr: the configured target types."""
        return f"Target type: {self.target_type}"
146
147


148
class Caltech256(VisionDataset):
    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

    Args:
        root (string): Directory that contains, or will receive when ``download``
            is True, the ``caltech256`` dataset directory.
        transform (callable, optional): Callable applied to the loaded PIL image;
            its return value is used as the sample image. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): Callable applied to the target;
            its return value is used as the sample target.
        download (bool, optional): If true, fetches the dataset from the internet
            into the root directory. Nothing is downloaded again when the dataset
            is already present.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        dataset_root = os.path.join(root, "caltech256")
        super().__init__(dataset_root, transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        objects_dir = os.path.join(self.root, "256_ObjectCategories")
        self.categories = sorted(os.listdir(objects_dir))

        # Parallel lists, one entry per sample: the 1-based image number within
        # its category, and the index of that category in ``self.categories``.
        self.index: List[int] = []
        self.y = []
        for label, category in enumerate(self.categories):
            entries = os.listdir(os.path.join(objects_dir, category))
            # Only ``.jpg`` entries count as samples.
            jpg_names = [name for name in entries if name.endswith(".jpg")]
            count = len(jpg_names)
            self.index.extend(range(1, count + 1))
            self.y.extend([label] * count)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Position of the sample.

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        label = self.y[index]
        # File names are "<1-based category number>_<image number>.jpg".
        file_name = f"{label + 1:03d}_{self.index[index]:04d}.jpg"
        image_path = os.path.join(
            self.root,
            "256_ObjectCategories",
            self.categories[label],
            file_name,
        )
        img = Image.open(image_path)

        target = label

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Report whether the ``256_ObjectCategories`` path exists under the dataset root."""
        # can be more robust and check hash of files
        objects_dir = os.path.join(self.root, "256_ObjectCategories")
        return os.path.exists(objects_dir)

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.index)

    def download(self) -> None:
        """Fetch and extract the image archive into the dataset root.

        Only prints a message and returns when ``_check_integrity`` already passes.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )