from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import PIL.Image

from .utils import download_and_extract_archive
from .vision import VisionDataset


class SUN397(VisionDataset):
    """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.

    The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
    397 categories with 108,754 images.

    Args:
        root (string): Root directory of the dataset.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Attributes:
        classes (list): Class names read from ``ClassName.txt``, one per non-blank line,
            with the first three characters of each line dropped.
        class_to_idx (dict): Mapping from class name to its position in ``classes``.

    Raises:
        RuntimeError: If the ``SUN397`` directory is not present under ``root``
            (after the optional download step).
    """

    _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
    _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._data_dir = Path(self.root) / "SUN397"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Explicit encoding keeps parsing independent of the platform locale.
        # Blank lines are skipped so they cannot register an empty-string class.
        with open(self._data_dir / "ClassName.txt", encoding="utf-8") as f:
            # The first three characters of every entry are dropped, then
            # surrounding whitespace (including the newline) is removed.
            self.classes = [line[3:].strip() for line in f if line.strip()]

        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}

        # NOTE(review): rglob yields files in whatever order the filesystem reports
        # them, so sample indices may differ between machines - confirm whether a
        # stable (sorted) order is wanted before relying on index-based splits.
        self._image_files = list(self._data_dir.rglob("sun_*.jpg"))

        # The class name is rebuilt from the directory components between the
        # first path component below the data directory and the file name.
        self._labels = [
            self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
        ]

    def __len__(self) -> int:
        """Return the number of image files discovered under the data directory."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            idx (int): Index into the list of discovered image files.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB-converted image
            (after ``transform``, if set) and ``label`` is the class index
            (after ``target_transform``, if set).
        """
        image_file, label = self._image_files[idx], self._labels[idx]

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed as soon as the conversion is done.
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        """Return True when the ``SUN397`` directory exists under ``root``."""
        return self._data_dir.is_dir()

    def _download(self) -> None:
        """Download and extract the archive unless the data directory already exists."""
        if self._check_exists():
            return
        download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)