Unverified Commit 7120024d authored by Sumukh Aithal's avatar Sumukh Aithal Committed by GitHub
Browse files

Add GTSRB dataset (#5117)



* Added GTSRB dataset

* Added unittest for GTSRB dataset

* Apply suggestions from code review

* More changes from code review

* re-add accidentally removed line
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 578c1546
......@@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FlyingChairs
FlyingThings3D
Food101
GTSRB
HD1K
HMDB51
ImageNet
......
......@@ -2275,5 +2275,55 @@ class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
return num_samples
class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.GTSRB
    FEATURE_TYPES = (PIL.Image.Image, int)

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))

    def inject_fake_data(self, tmpdir: str, config):
        """Lay out a fake GTSRB tree (train folders + test images and CSV) under *tmpdir*.

        Returns the number of samples per split (both splits get the same count).
        """
        dataset_root = os.path.join(tmpdir, "GTSRB")
        os.makedirs(dataset_root, exist_ok=True)

        images_per_class = 3
        class_names = ("00000", "00042", "00012")

        # Training split: one sub-folder of .ppm images per class.
        training_dir = os.path.join(dataset_root, "Training")
        os.makedirs(training_dir, exist_ok=True)
        for name in class_names:
            datasets_utils.create_image_folder(
                training_dir,
                name=name,
                # Bind `name` as a default argument so the callback is independent
                # of late-binding closure semantics.
                file_name_fn=lambda idx, name=name: f"{name}_{idx:05d}.ppm",
                num_examples=images_per_class,
            )

        num_samples = images_per_class * len(class_names)

        # Test split: a flat image folder plus a semicolon-separated ground-truth CSV.
        test_dir = os.path.join(dataset_root, "Final_Test", "Images")
        os.makedirs(test_dir, exist_ok=True)

        with open(os.path.join(dataset_root, "GT-final_test.csv"), "w") as gt_file:
            gt_file.write("Filename;Width;Height;Roi.X1;Roi.Y1;Roi.X2;Roi.Y2;ClassId\n")

            for _ in range(num_samples):
                file_name = datasets_utils.create_random_string(5, string.digits) + ".ppm"
                datasets_utils.create_image_file(test_dir, file_name)

                # Filename, six fake ROI/size columns in [1, 100), and a class id in [0, 43).
                fields = [file_name]
                fields.extend(torch.randint(1, 100, size=()).item() for _ in range(6))
                fields.append(torch.randint(0, 43, size=()).item())
                gt_file.write(";".join(map(str, fields)) + "\n")

        return num_samples
# Allow running this test module directly; pytest discovers it as well.
if __name__ == "__main__":
    unittest.main()
......@@ -10,6 +10,7 @@ from .fer2013 import FER2013
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
from .food101 import Food101
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .inaturalist import INaturalist
......@@ -83,4 +84,5 @@ __all__ = (
"Food101",
"DTD",
"FER2013",
"GTSRB",
)
import csv
import os
from typing import Any, Callable, Optional, Tuple
import PIL
from .folder import make_dataset
from .utils import download_and_extract_archive
from .vision import VisionDataset
class GTSRB(VisionDataset):
    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    # Ground Truth for the test set
    _gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
    _gt_csv = "GT-final_test.csv"
    _gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"

    # URLs/md5s for the test (index 0) and train (index 1) archives.
    # Both tuples are indexed with ``self.train`` (False -> 0, True -> 1).
    _urls = (
        "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
        "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
    )
    _md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        # VisionDataset stores ``transform``/``target_transform``; no need to
        # re-assign them below.
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.root = os.path.expanduser(root)
        self.train = train

        self._base_folder = os.path.join(self.root, type(self).__name__)
        # Build paths from individual components so they are correct on every platform
        # (the original hard-coded "Final_Test/Images" with a forward slash).
        self._target_folder = (
            os.path.join(self._base_folder, "Training")
            if self.train
            else os.path.join(self._base_folder, "Final_Test", "Images")
        )

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        if train:
            # Training images are laid out one folder per class; labels come from folder names.
            samples = make_dataset(self._target_folder, extensions=(".ppm",))
        else:
            # Test labels live in a semicolon-separated ground-truth CSV.
            with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file:
                samples = [
                    (os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"]))
                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
                ]

        self._samples = samples

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the sample at ``index`` and return ``(image, class_id)`` after transforms."""
        path, target = self._samples[index]
        sample = PIL.Image.open(path).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _check_exists(self) -> bool:
        # os.path.isdir already returns False for nonexistent paths, so a separate
        # os.path.exists check is redundant.
        return os.path.isdir(self._target_folder)

    def download(self) -> None:
        """Download and extract the archive(s) for the selected split, if not present."""
        if self._check_exists():
            return

        download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train])

        if not self.train:
            # Download Ground Truth for the test set
            download_and_extract_archive(
                self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5
            )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment