# sun397.py
# Committed by Saswat Das.
from pathlib import Path
from typing import Any, Tuple, Callable, Optional

import PIL.Image

from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset


class SUN397(VisionDataset):
    """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.

    The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
    397 categories with 108'754 images. The dataset also provides 10 partitions for training
    and testing, with each partition consisting of 50 images per class.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        partition (int, optional): A valid partition can be an integer from 1 to 10 or None,
            for the entire dataset.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Raises:
        RuntimeError: If ``partition`` is outside ``[1, 10]`` (and not None), or if the
            ``SUN397`` directory is missing under ``root`` after the optional download step.
    """

    _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
    _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
    _PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip"
    _PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881"

    def __init__(
        self,
        root: str,
        split: str = "train",
        partition: Optional[int] = 1,
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = verify_str_arg(split, "split", ("train", "test"))
        self.partition = partition
        self._data_dir = Path(self.root) / "SUN397"

        # Validate before any download or file access. Partitions are numbered 1..10,
        # so 0 must be rejected here rather than surfacing later as a missing
        # "Training_00.txt" file.
        if self.partition is not None:
            if self.partition < 1 or self.partition > 10:
                raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.")

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each line of ClassName.txt has its first three characters dropped to obtain the
        # class name (presumably a leading "/<letter>/" prefix -- confirm against the file).
        with open(self._data_dir / "ClassName.txt") as f:
            self.classes = [c[3:].strip() for c in f]

        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
        if self.partition is not None:
            # Resolves to e.g. "Training_01.txt" or "Testing_01.txt". Each line is a
            # "/"-separated path whose first character is dropped before it is joined
            # onto the dataset directory.
            with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f:
                self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f]
        else:
            # No partition: use every matching image found under the dataset directory.
            self._image_files = list(self._data_dir.rglob("sun_*.jpg"))

        # The class key is the image's path relative to the dataset directory with the
        # first directory component and the file name removed, re-joined with "/".
        self._labels = [
            self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
        ]

    def __len__(self) -> int:
        """Return the number of images selected by the split/partition."""
        return len(self._image_files)

    def __getitem__(self, idx) -> Tuple[Any, Any]:
        """Load the image at ``idx`` as RGB and return ``(image, label)`` after applying the transforms."""
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        """Return True if the ``SUN397`` directory exists under ``root`` (contents are not verified)."""
        return self._data_dir.exists() and self._data_dir.is_dir()

    def extra_repr(self) -> str:
        """Return the split description used in the dataset's repr."""
        return "Split: {split}".format(**self.__dict__)

    def _download(self) -> None:
        """Fetch the image archive and the partition files unless the data directory already exists."""
        if self._check_exists():
            return
        # The image archive is extracted under ``root``; the partition files go inside
        # the ``SUN397`` directory, next to the images.
        download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
        download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5)