"tests/data/config/l1.py" did not exist on "7dac84d08c305ec62d6faa05e5aec5fa9b054b47"
gtsrb.py 3.65 KB
Newer Older
Sumukh Aithal's avatar
Sumukh Aithal committed
1
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple

import PIL
import PIL.Image

from .folder import make_dataset
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class GTSRB(VisionDataset):
    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset. Files are stored under ``root/gtsrb``.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the files required by the selected split are not present
            (after downloading, when ``download=True``).
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "gtsrb"
        self._target_folder = (
            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
        )
        # Ground-truth labels for the test split. They ship in a separate archive from the test images.
        self._test_labels_file = self._base_folder / "GT-final_test.csv"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        if self._split == "train":
            # NOTE(review): make_dataset presumably derives each label from the image's class
            # subdirectory under ``Training``. Confirm against .folder.make_dataset.
            samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
        else:
            # The csv module requires files opened with newline="" so that quoted fields
            # and line endings are handled by the reader itself.
            with open(self._test_labels_file, newline="") as csv_file:
                reader = csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
                samples = [(str(self._target_folder / row["Filename"]), int(row["ClassId"])) for row in reader]

        self._samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at ``index``.

        The image is decoded as an RGB PIL image and passed through ``transform`` if one is set.
        The target is passed through ``target_transform`` if one is set.
        """
        path, target = self._samples[index]
        # The context manager releases the file handle deterministically. convert() returns a
        # new, fully loaded image, so it stays valid after the file is closed.
        with PIL.Image.open(path) as image:
            sample = image.convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _check_exists(self) -> bool:
        """Return True if every file needed by the selected split is on disk."""
        if not self._target_folder.is_dir():
            return False
        # The test split also needs its ground-truth CSV. Without this check, a partially
        # downloaded test split passes here and then fails when the CSV is opened.
        return self._split == "train" or self._test_labels_file.is_file()

    def download(self) -> None:
        """Download and extract the archives for the selected split into ``root/gtsrb``.

        Does nothing if the split is already complete. For the test split, the image archive
        and the ground-truth archive are fetched independently, so a partially completed earlier
        download only fetches the missing piece.
        """
        if self._check_exists():
            return

        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"

        if self._split == "train":
            download_and_extract_archive(
                f"{base_url}GTSRB-Training_fixed.zip",
                download_root=str(self._base_folder),
                md5="513f3c79a4c5141765e10e952eaa2478",
            )
            return

        if not self._target_folder.is_dir():
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_Images.zip",
                download_root=str(self._base_folder),
                md5="c7e4e6327067d32654124b0fe9e82185",
            )
        if not self._test_labels_file.is_file():
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_GT.zip",
                download_root=str(self._base_folder),
                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
            )