Unverified Commit 508c79de authored by Nicolas Hug, committed by GitHub

Add GTSRB dataset to prototypes (#5214)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 8886a3cf
@@ -1017,6 +1017,76 @@ def fer2013(info, root, config):
return num_samples
@DATASET_MOCKS.set_from_named_callable
def gtsrb(info, root, config):
num_examples_per_class = 5 if config.split == "train" else 3
classes = ("00000", "00042", "00012")
num_examples = num_examples_per_class * len(classes)
csv_columns = ["Filename", "Width", "Height", "Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2", "ClassId"]
def _make_ann_file(path, num_examples, class_idx):
if class_idx == "random":
class_idx = torch.randint(1, len(classes) + 1, size=(1,)).item()
with open(path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=csv_columns, delimiter=";")
writer.writeheader()
for image_idx in range(num_examples):
writer.writerow(
{
"Filename": f"{image_idx:05d}.ppm",
"Width": torch.randint(1, 100, size=()).item(),
"Height": torch.randint(1, 100, size=()).item(),
"Roi.X1": torch.randint(1, 100, size=()).item(),
"Roi.Y1": torch.randint(1, 100, size=()).item(),
"Roi.X2": torch.randint(1, 100, size=()).item(),
"Roi.Y2": torch.randint(1, 100, size=()).item(),
"ClassId": class_idx,
}
)
if config["split"] == "train":
train_folder = root / "GTSRB" / "Training"
train_folder.mkdir(parents=True)
for class_idx in classes:
create_image_folder(
train_folder,
name=class_idx,
file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm",
num_examples=num_examples_per_class,
)
_make_ann_file(
path=train_folder / class_idx / f"GT-{class_idx}.csv",
num_examples=num_examples_per_class,
class_idx=int(class_idx),
)
make_zip(root, "GTSRB-Training_fixed.zip", train_folder)
else:
test_folder = root / "GTSRB" / "Final_Test"
test_folder.mkdir(parents=True)
create_image_folder(
test_folder,
name="Images",
file_name_fn=lambda image_idx: f"{image_idx:05d}.ppm",
num_examples=num_examples,
)
make_zip(root, "GTSRB_Final_Test_Images.zip", test_folder)
_make_ann_file(
path=root / "GT-final_test.csv",
num_examples=num_examples,
class_idx="random",
)
make_zip(root, "GTSRB_Final_Test_GT.zip", "GT-final_test.csv")
return num_examples
@DATASET_MOCKS.set_from_named_callable
def clevr(info, root, config):
data_folder = root / "CLEVR_v1.0"
...
@@ -881,7 +881,7 @@ def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
files, dirs = _split_files_or_dirs(root, *files_or_dirs)
with opener(archive) as fh:
for file in sorted(files):
adder(fh, file, file.relative_to(root))
if remove:
...
import io
from pathlib import Path
import pytest
import torch
@@ -123,7 +124,7 @@ class TestCommon:
if type(dp) is annotation_dp_type:
break
else:
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
@@ -143,3 +144,19 @@ class TestQMNIST:
("unused", bool),
):
assert key in sample and isinstance(sample[key], type)
@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
class TestGTSRB:
def test_label_matches_path(self, dataset_mock, config):
# We read the labels from the csv files rather than from the image paths. But for the trainset,
# the labels are also part of the path. This test makes sure that they're both the same.
if config.split != "train":
return
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
for sample in dataset:
label_from_path = int(Path(sample["image_path"]).parent.name)
assert sample["label"] == label_from_path
# How to add new built-in prototype datasets
As the name implies, the datasets are still in a prototype state and thus
subject to rapid change. This in turn means that this document will also change
a lot.
If you hit a blocker while adding a dataset, please have a look at another
similar dataset to see how it is implemented there. If you can't resolve it
yourself, feel free to send a draft PR in order for us to help you out.
Finally, `from torchvision.prototype import datasets` is implied below.
## Implementation
Before we start with the actual implementation, you should create a module in
`torchvision/prototype/datasets/_builtin` that hints at the dataset you are
going to add. For example `caltech.py` for `caltech101` and `caltech256`. In
that module create a class that inherits from `datasets.utils.Dataset` and
overwrites at minimum three methods that will be discussed in detail below:
```python
import io
@@ -37,27 +45,54 @@ class MyDataset(Dataset):
### `_make_info(self)`
The `DatasetInfo` carries static information about the dataset. There are two
required fields:

- `name`: Name of the dataset. This will be used to load the dataset with
  `datasets.load(name)`. Should only contain lowercase characters.
- `type`: Field of the `datasets.utils.DatasetType` enum. This is used to select
  the default decoder in case the user doesn't pass one. There are currently
  only two options: `IMAGE` and `RAW` ([see
  below](#what-is-the-datasettyperaw-and-when-do-i-use-it) for details).
There are more optional parameters that can be passed:
- `dependencies`: Collection of third-party dependencies that are needed to
  load the dataset, e.g. `("scipy",)`. Their availability will be automatically
  checked if a user tries to load the dataset. Within the implementation, import
  these packages lazily to avoid missing dependencies at import time.
- `categories`: Sequence of human-readable category names for each label. The
  index of each category has to match the corresponding label returned in the
  dataset samples. [See
  below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle
  cases with many categories.
- `valid_options`: Configures valid options that can be passed to the dataset.
  It should be `Dict[str, Sequence[Any]]`. The options are accessible through
  the `config` namespace in the other two functions. The first value of the
  sequence is taken as the default if the user passes no option to
  `torchvision.prototype.datasets.load()`. A short sketch follows this list.
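For orientation, here is a hedged `_make_info()` sketch, loosely modeled on the
GTSRB dataset added in this PR (the name, homepage, and categories below are
made-up example values):

```python
def _make_info(self) -> DatasetInfo:
    # Example values only; replace them with your dataset's actual metadata.
    return DatasetInfo(
        "mydataset",
        type=DatasetType.IMAGE,
        homepage="https://example.com/mydataset",
        categories=("cat", "dog"),
        valid_options=dict(split=("train", "test")),
    )
```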
## `resources(self, config)`
Returns `List[datasets.utils.OnlineResource]` of all the files that need to be
present locally before the dataset with a specific `config` can be built. The
download will happen automatically.
Currently, the following `OnlineResource`'s are supported; a short usage sketch
follows the list:
- `HttpResource`: Used for files that are directly exposed through HTTP(s) and
  only requires the URL.
- `GDriveResource`: Used for files that are hosted on GDrive and requires the
  GDrive ID as well as the `file_name`.
- `ManualDownloadResource`: Used for files that are not publicly accessible and
  requires instructions on how to download them manually. If the file does not
  exist, an error will be raised with the supplied instructions.
- `KaggleDownloadResource`: Used for files that are available on Kaggle. This
  inherits from `ManualDownloadResource`.
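For example, a `resources()` sketch for a dataset that ships as a single HTTP
archive could look like this (the URL and checksum are placeholders):

```python
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
    # Placeholder URL and checksum; a real dataset uses its actual download
    # location and the SHA256 computed as described below.
    return [
        HttpResource(
            "https://example.com/mydataset-images.zip",
            sha256="0000000000000000000000000000000000000000000000000000000000000000",
        )
    ]
```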
Although optional in general, all resources used in the built-in datasets
should include a [SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for
security. It will be automatically checked after the download. You can compute
the checksum with system utilities, e.g. `sha256sum`, or this snippet:
```python
import hashlib
@@ -72,35 +107,84 @@ def sha256sum(path, chunk_size=1024 * 1024):
### `_make_datapipe(resource_dps, *, config, decoder)`
This method is the heart of the dataset, where we transform the raw data into
a usable form. A major difference compared to the current stable datasets is
that everything is performed through `IterDataPipe`'s. From the perspective of
someone that is working with them rather than on them, `IterDataPipe`'s behave
just as generators, i.e. you can't do anything with them besides iterating.
Of course, there are some common building blocks that should suffice in 95% of
the cases. The most used are:
- `Mapper`: Apply a callable to every item in the datapipe.
- `Filter`: Keep only items that satisfy a condition.
- `Demultiplexer`: Split a datapipe into multiple ones.
- `IterKeyZipper`: Merge two datapipes into one.
All of them can be imported `from torchdata.datapipes.iter`. In addition, use
`functools.partial` in case a callable needs extra arguments. If the provided
`IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.
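As a hedged illustration of how these building blocks compose, assuming an
archive datapipe of `(path, file handle)` tuples (the helpers `_is_image` and
`_attach_label` are made up for this example):

```python
from functools import partial

from torchdata.datapipes.iter import Filter, Mapper


def _is_image(data):
    # Items coming from an archive datapipe are (path, file handle) tuples.
    path, _ = data
    return path.endswith(".ppm")


def _attach_label(data, *, label):
    path, handle = data
    return dict(path=path, handle=handle, label=label)


def _sketch(archive_dp):
    # Keep only the image files ...
    dp = Filter(archive_dp, _is_image)
    # ... and attach a fixed label, passing the extra argument via partial.
    return Mapper(dp, partial(_attach_label, label=0))
```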
`make_datapipe()` receives `resource_dps`, which is a list of datapipes that has
a 1-to-1 correspondence with the return value of `resources()`. In case of
archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive.
Otherwise the datapipe will only contain one of such tuples for the file
specified by the resource.
Since the datapipes are iterable in nature, some datapipes feature an in-memory
buffer, e.g. `IterKeyZipper` and `Grouper`. There are two issues with that:
1. If not used carefully, this can easily overflow the host memory, since most
   datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that
   is only needed at runtime.
Thus, all buffered datapipes should be used as early as possible, e.g. zipping
two datapipes of file handles rather than trying to zip already loaded images.
There are two special datapipes that are not used through their class, but
through the functions `hint_sharding` and `hint_shuffling`. As the name implies,
they only mark the part in the datapipe graph where sharding and shuffling
should take place, but are no-ops by default. They can be imported from
`torchvision.prototype.datasets.utils._internal` and are required in each
dataset.
Finally, each item in the final datapipe should be a dictionary with `str` keys.
There is no standardization of the names (yet!).
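Putting the last few points together, the tail end of `_make_datapipe()`
usually looks something like the sketch below (based on the GTSRB implementation
in this PR; the `_finalize` wrapper and `collate_fn` name are only
illustrative):

```python
from functools import partial

from torchdata.datapipes.iter import Mapper
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling


def _finalize(dp, collate_fn, decoder):
    # Hint where sharding and shuffling should take place (no-ops by default) ...
    dp = hint_sharding(dp)
    dp = hint_shuffling(dp)
    # ... and map every item to the final sample dictionary with str keys.
    return Mapper(dp, partial(collate_fn, decoder=decoder))
```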
## FAQ

### How do I start?

Get the skeleton of your dataset class ready with all 3 methods. For
`_make_datapipe()`, you can just do `return resource_dps[0]` to get started.
Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically
register the dataset and it will be instantiable via
`datasets.load("mydataset")`. On a separate script, try something like
```py
from torchvision.prototype import datasets
dataset = datasets.load("mydataset")
for sample in dataset:
print(sample) # this is the content of an item in datapipe returned by _make_datapipe()
break
# Or you can also inspect the sample in a debugger
```
This will give you an idea of what the first datapipe in `resource_dps`
contains. You can also do that with `resource_dps[1]` or `resource_dps[2]`
(etc.) if they exist. Then follow the instructions above to manipulate these
datapipes and return the appropriate dictionary format.
### What is the `DatasetType.RAW` and when do I use it?
`DatasetType.RAW` marks a dataset that provides decoded, i.e. raw pixel values,
rather than encoded image files such as `.jpg` or `.png`. This is usually only
the case for small datasets, since it requires a lot more disk space. The
default decoder `datasets.decoder.raw` is only a sentinel and should not be
called directly. The decoding should look something like
```python
from torchvision.prototype.datasets.decoder import raw
@@ -118,10 +202,28 @@ For examples, have a look at the MNIST, CIFAR, or SEMEION datasets.
### How do I handle a dataset that defines many categories?
As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only
be set directly for ten categories or fewer. If more categories are needed, you
can add a `$NAME.categories` file to the `_builtin` folder in which each line
specifies a category. If `$NAME` matches the name of the dataset (which it
definitely should!) it will be automatically loaded if `categories=` is not
set.
In case the categories can be generated from the dataset files, e.g. the dataset
follows an image folder approach where each folder denotes the name of the
category, the dataset can overwrite the `_generate_categories` method. It gets
passed the `root` path to the resources, but they have to be manually loaded,
e.g. `self.resources(config)[0].load(root)`. The method should return a sequence
of strings representing the category names. To generate the `$NAME.categories`
file, run `python -m torchvision.prototype.datasets.generate_category_files
$NAME`.
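A hedged sketch of such a `_generate_categories` override for an image-folder
style archive (this assumes the default config is available as
`self.default_config` and that the resource yields `(path, handle)` tuples with
one folder per category):

```python
def _generate_categories(self, root: pathlib.Path) -> List[str]:
    # Manually load the first resource, as described above.
    dp = self.resources(self.default_config)[0].load(root)
    # Use the sorted, de-duplicated parent folder names as category names.
    return sorted({pathlib.Path(path).parent.name for path, _ in dp})
```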
### What if a resource file forms an I/O bottleneck?
In general, we are ok with small performance hits of iterating archives rather
than their extracted content. However, if the performance hit becomes
significant, the archives can still be decompressed or extracted. To do this,
the `decompress: bool` and `extract: bool` flags can be used for every
`OnlineResource` individually. For more complex cases, each resource also
accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw
file and should return a `pathlib.Path` of the preprocessed file or folder.
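A small sketch of the flag-based route (the URL and checksum are placeholders,
and `extract` is assumed to be forwarded by `HttpResource` to `OnlineResource`):

```python
from torchvision.prototype.datasets.utils import HttpResource

# Extract the archive once after download so that later iterations read plain
# files instead of going through the archive every time.
resource = HttpResource(
    "https://example.com/mydataset.tar.gz",
    sha256="0000000000000000000000000000000000000000000000000000000000000000",
    extract=True,
)
```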
@@ -6,6 +6,7 @@ from .coco import Coco
from .cub200 import CUB200
from .dtd import DTD
from .fer2013 import FER2013
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
...
import io
import pathlib
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
HttpResource,
)
from torchvision.prototype.datasets.utils._internal import (
path_comparator,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
)
from torchvision.prototype.features import Label, BoundingBox
class GTSRB(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"gtsrb",
type=DatasetType.IMAGE,
homepage="https://benchmark.ini.rub.de",
categories=[f"{label:05d}" for label in range(43)],
valid_options=dict(split=("train", "test")),
)
_URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
_URLS = {
"train": f"{_URL_ROOT}GTSRB-Training_fixed.zip",
"test": f"{_URL_ROOT}GTSRB_Final_Test_Images.zip",
"test_ground_truth": f"{_URL_ROOT}GTSRB_Final_Test_GT.zip",
}
_CHECKSUMS = {
"train": "df4144942083645bd60b594de348aa6930126c3e0e5de09e39611630abf8455a",
"test": "48ba6fab7e877eb64eaf8de99035b0aaecfbc279bee23e35deca4ac1d0a837fa",
"test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
rsrcs: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUMS[config.split])]
if config.split == "test":
rsrcs.append(
HttpResource(
self._URLS["test_ground_truth"],
sha256=self._CHECKSUMS["test_ground_truth"],
)
)
return rsrcs
def _classify_train_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.suffix == ".ppm":
return 0
elif path.suffix == ".csv":
return 1
else:
return None
def _collate_and_decode(
self, data: Tuple[Tuple[str, Any], Dict[str, Any]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
) -> Dict[str, Any]:
(image_path, image_buffer), csv_info = data
label = int(csv_info["ClassId"])
bbox = BoundingBox(
torch.tensor([int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")]),
format="xyxy",
image_size=(int(csv_info["Height"]), int(csv_info["Width"])),
)
return {
"image_path": image_path,
"image": decoder(image_buffer) if decoder else image_buffer,
"label": Label(label, category=self.categories[label]),
"bbox": bbox,
}
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
if config.split == "train":
images_dp, ann_dp = Demultiplexer(
resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
else:
images_dp, ann_dp = resource_dps
images_dp = Filter(images_dp, path_comparator("suffix", ".ppm"))
# The order of the image files in the .zip archives perfectly matches the order of the entries in
# the (possibly concatenated) .csv files. So we're able to use Zipper here instead of an IterKeyZipper.
ann_dp = CSVDictParser(ann_dp, delimiter=";")
dp = Zipper(images_dp, ann_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = Mapper(dp, partial(self._collate_and_decode, decoder=decoder))
return dp