Unverified Commit 1ac6e8b9 authored by Philip Meier, committed by GitHub

Refactor and simplify prototype datasets (#5778)



* refactor prototype datasets to inherit from IterDataPipe (#5448)

* refactor prototype datasets to inherit from IterDataPipe

* depend on new architecture

* fix missing file detection

* remove unrelated file

* reinstate decorator for mock registering

* options -> config

* remove passing of info to mock data functions

* refactor categories file generation

* fix imagenet

* fix prototype datasets data loading tests (#5711)

* reenable serialization test

* cleanup

* fix dill test

* trigger CI

* patch DILL_AVAILABLE for pickle serialization

* revert CI changes

* remove dill test and traversable test

* add data loader test

* parametrize over only_datapipe

* draw one sample rather than exhaust data loader

* cleanup

* trigger CI

* migrate VOC prototype dataset (#5743)

* migrate VOC prototype dataset

* cleanup

* revert unrelated mock data changes

* remove categories annotations

* move properties to constructor

* readd homepage

* migrate CIFAR prototype datasets (#5751)

* migrate country211 prototype dataset (#5753)

* migrate CLEVR prototype dataset (#5752)

* migrate coco prototype (#5473)

* migrate coco prototype

* revert unrelated change

* add kwargs to super constructor call

* remove unneeded changes

* fix docstring position

* make kwargs explicit

* add dependencies to docstring

* fix missing dependency message

* Migrate PCAM prototype dataset (#5745)

* Port PCAM

* skip_integrity_check

* Update torchvision/prototype/datasets/_builtin/pcam.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Address comments
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate DTD prototype dataset (#5757)

* Migrate DTD prototype dataset

* Docstring

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate GTSRB prototype dataset (#5746)

* Migrate GTSRB prototype dataset

* ufmt

* Address comments

* Apparently mypy doesn't know that __len__ returns ints. How cute.

* why is the CI not triggered??

* Update torchvision/prototype/datasets/_builtin/gtsrb.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate CelebA prototype dataset (#5750)

* migrate CelebA prototype dataset

* inline split_id

* Migrate Food101 prototype dataset (#5758)

* Migrate Food101 dataset

* Added length

* Update torchvision/prototype/datasets/_builtin/food101.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate Fer2013 prototype dataset (#5759)

* Migrate Fer2013 prototype dataset

* Update torchvision/prototype/datasets/_builtin/fer2013.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate EuroSAT prototype dataset (#5760)

* Migrate Semeion prototype dataset (#5761)

* migrate caltech prototype datasets (#5749)

* migrate caltech prototype datasets

* resolve third party dependencies

* Migrate Oxford Pets prototype dataset (#5764)

* Migrate Oxford Pets prototype dataset

* Update torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate mnist prototype datasets (#5480)

* migrate MNIST prototype datasets

* Update torchvision/prototype/datasets/_builtin/mnist.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Migrate Stanford Cars prototype dataset (#5767)

* Migrate Stanford Cars prototype dataset

* Address comments

* fix category file generation (#5770)

* fix category file generation

* revert unrelated change

* revert unrelated change

* migrate cub200 prototype dataset (#5765)

* migrate cub200 prototype dataset

* address comments

* fix category-file-generation

* Migrate USPS prototype dataset (#5771)

* migrate SBD prototype dataset (#5772)

* migrate SBD prototype dataset

* reuse categories

* Migrate SVHN prototype dataset (#5769)

* add test to enforce __len__ is working on prototype datasets (#5742)

* reactivate special dataset tests

* add missing annotation

* Cleanup prototype dataset implementation (#5774)

* Remove Dataset2 class

* Move read_categories_file out of DatasetInfo

* Remove FrozenBunch and FrozenMapping

* Remove test_prototype_datasets_api.py and move missing dep test somewhere else

* ufmt

* Let read_categories_file accept names instead of paths

* Mypy

* flake8

* fix category file reading
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* update prototype dataset README (#5777)

* update prototype dataset README

* fix header level

* Apply suggestions from code review
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 5f74f031
......@@ -7,9 +7,10 @@ import pytest
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data import DataLoader
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
......@@ -42,14 +43,24 @@ def test_coverage():
@pytest.mark.filterwarnings("error")
class TestCommon:
@pytest.mark.parametrize("name", datasets.list_datasets())
def test_info(self, name):
try:
info = datasets.info(name)
except ValueError:
raise AssertionError("No info available.") from None
if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
raise AssertionError("Info should be a dictionary with string keys.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
if not isinstance(dataset, datasets.utils.Dataset):
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, test_home, dataset_mock, config):
......@@ -76,24 +87,7 @@ class TestCommon:
dataset = datasets.load(dataset_mock.name, **config)
num_samples = 0
for _ in dataset:
num_samples += 1
assert num_samples == mock_info["num_samples"]
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
)
assert len(list(dataset)) == mock_info["num_samples"]
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
......@@ -116,14 +110,36 @@ class TestCommon:
next(iter(dataset.map(transforms.Identity())))
@pytest.mark.parametrize("only_datapipe", [False, True])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
traverse(dataset, only_datapipe=only_datapipe)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
pickle.dumps(dataset)
@pytest.mark.parametrize("num_workers", [0, 1])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_data_loader(self, test_home, dataset_mock, config, num_workers):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
dl = DataLoader(
dataset,
batch_size=2,
num_workers=num_workers,
collate_fn=lambda batch: batch,
)
next(iter(dl))
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
......@@ -132,7 +148,6 @@ class TestCommon:
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
......@@ -160,6 +175,13 @@ class TestCommon:
# resolved
assert dp.buffer_size == INFINITE_BUFFER_SIZE
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_has_length(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
assert len(dataset) > 0
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
......@@ -186,7 +208,7 @@ class TestGTSRB:
def test_label_matches_path(self, test_home, dataset_mock, config):
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
# This test makes sure that they're both the same
if config.split != "train":
if config["split"] != "train":
return
dataset_mock.prepare(test_home, config)
......
import unittest.mock
import pytest
from torchvision.prototype import datasets
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
def make_minimal_dataset_info(name="name", categories=None, **kwargs):
return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
class TestFrozenMapping:
@pytest.mark.parametrize(
("args", "kwargs"),
[
pytest.param((dict(foo="bar", baz=1),), dict(), id="from_dict"),
pytest.param((), dict(foo="bar", baz=1), id="from_kwargs"),
pytest.param((dict(foo="bar"),), dict(baz=1), id="mixed"),
],
)
def test_instantiation(self, args, kwargs):
FrozenMapping(*args, **kwargs)
def test_unhashable_items(self):
with pytest.raises(TypeError, match="unhashable type"):
FrozenMapping(foo=[])
def test_getitem(self):
options = dict(foo="bar", baz=1)
config = FrozenMapping(options)
for key, value in options.items():
assert config[key] == value
def test_getitem_unknown(self):
with pytest.raises(KeyError):
FrozenMapping()["unknown"]
def test_iter(self):
options = dict(foo="bar", baz=1)
assert set(iter(FrozenMapping(options))) == set(options.keys())
def test_len(self):
options = dict(foo="bar", baz=1)
assert len(FrozenMapping(options)) == len(options)
def test_immutable_setitem(self):
frozen_mapping = FrozenMapping()
with pytest.raises(RuntimeError, match="immutable"):
frozen_mapping["foo"] = "bar"
def test_immutable_delitem(
self,
):
frozen_mapping = FrozenMapping(foo="bar")
with pytest.raises(RuntimeError, match="immutable"):
del frozen_mapping["foo"]
def test_eq(self):
options = dict(foo="bar", baz=1)
assert FrozenMapping(options) == FrozenMapping(options)
def test_ne(self):
options1 = dict(foo="bar", baz=1)
options2 = options1.copy()
options2["baz"] += 1
assert FrozenMapping(options1) != FrozenMapping(options2)
def test_repr(self):
options = dict(foo="bar", baz=1)
output = repr(FrozenMapping(options))
assert isinstance(output, str)
for key, value in options.items():
assert str(key) in output and str(value) in output
class TestFrozenBunch:
def test_getattr(self):
options = dict(foo="bar", baz=1)
config = FrozenBunch(options)
for key, value in options.items():
assert getattr(config, key) == value
def test_getattr_unknown(self):
with pytest.raises(AttributeError, match="no attribute 'unknown'"):
datasets.utils.DatasetConfig().unknown
def test_immutable_setattr(self):
frozen_bunch = FrozenBunch()
with pytest.raises(RuntimeError, match="immutable"):
frozen_bunch.foo = "bar"
def test_immutable_delattr(
self,
):
frozen_bunch = FrozenBunch(foo="bar")
with pytest.raises(RuntimeError, match="immutable"):
del frozen_bunch.foo
def test_repr(self):
options = dict(foo="bar", baz=1)
output = repr(FrozenBunch(options))
assert isinstance(output, str)
assert output.startswith("FrozenBunch")
for key, value in options.items():
assert f"{key}={value}" in output
class TestDatasetInfo:
@pytest.fixture
def info(self):
return make_minimal_dataset_info(valid_options=dict(split=("train", "test"), foo=("bar", "baz")))
def test_default_config(self, info):
valid_options = info._valid_options
default_config = datasets.utils.DatasetConfig({key: values[0] for key, values in valid_options.items()})
assert info.default_config == default_config
@pytest.mark.parametrize(
("valid_options", "options", "expected_error_msg"),
[
(dict(), dict(any_option=None), "does not take any options"),
(dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"),
(dict(split="train"), dict(split="invalid_argument"), "Invalid argument 'invalid_argument'"),
],
)
def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg):
info = make_minimal_dataset_info(valid_options=valid_options)
with pytest.raises(ValueError, match=expected_error_msg):
info.make_config(**options)
def test_check_dependencies(self):
dependency = "fake_dependency"
info = make_minimal_dataset_info(dependencies=(dependency,))
with pytest.raises(ModuleNotFoundError, match=dependency):
info.check_dependencies()
def test_repr(self, info):
output = repr(info)
assert isinstance(output, str)
assert "DatasetInfo" in output
for key, value in info._valid_options.items():
assert f"{key}={str(value)[1:-1]}" in output
@pytest.mark.parametrize("optional_info", ("citation", "homepage", "license"))
def test_repr_optional_info(self, optional_info):
sentinel = "sentinel"
info = make_minimal_dataset_info(**{optional_info: sentinel})
assert f"{optional_info}={sentinel}" in repr(info)
class TestDataset:
class DatasetMock(datasets.utils.Dataset):
def __init__(self, info=None, *, resources=None):
self._info = info or make_minimal_dataset_info(valid_options=dict(split=("train", "test")))
self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources
self._make_datapipe = unittest.mock.Mock()
super().__init__()
def _make_info(self):
return self._info
def resources(self, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass
def _make_datapipe(self, resource_dps, *, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass
def test_name(self):
name = "sentinel"
dataset = self.DatasetMock(make_minimal_dataset_info(name=name))
assert dataset.name == name
def test_default_config(self):
sentinel = "sentinel"
dataset = self.DatasetMock(info=make_minimal_dataset_info(valid_options=dict(split=(sentinel, "train"))))
assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel)
@pytest.mark.parametrize(
("config", "kwarg"),
[
pytest.param(*(datasets.utils.DatasetConfig(split="test"),) * 2, id="specific"),
pytest.param(DatasetMock().default_config, None, id="default"),
],
)
def test_load_config(self, config, kwarg):
dataset = self.DatasetMock()
dataset.load("", config=kwarg)
dataset.resources.assert_called_with(config)
_, call_kwargs = dataset._make_datapipe.call_args
assert call_kwargs["config"] == config
def test_missing_dependencies(self):
dependency = "fake_dependency"
dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,)))
with pytest.raises(ModuleNotFoundError, match=dependency):
dataset.load("root")
def test_resources(self, mocker):
resource_mock = mocker.Mock(spec=["load"])
sentinel = object()
resource_mock.load.return_value = sentinel
dataset = self.DatasetMock(resources=[resource_mock])
root = "root"
dataset.load(root)
(call_args, _) = resource_mock.load.call_args
assert call_args[0] == root
(call_args, _) = dataset._make_datapipe.call_args
assert call_args[0][0] is sentinel
......@@ -5,7 +5,7 @@ import pytest
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
......@@ -101,3 +101,21 @@ class TestHttpResource:
assert redirected_resource.file_name == file_name
assert redirected_resource.sha256 == sha256_sentinel
assert redirected_resource._preprocess is preprocess_sentinel
def test_missing_dependency_error():
class DummyDataset(Dataset):
def __init__(self):
super().__init__(root="root", dependencies=("fake_dependency",))
def _resources(self):
pass
def _datapipe(self, resource_dps):
pass
def __len__(self):
pass
with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
DummyDataset()
......@@ -10,5 +10,6 @@ from . import utils
from ._home import home
# Load this last, since some parts depend on the above being loaded first
from ._api import list_datasets, info, load # usort: skip
from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip
from ._folder import from_data_folder, from_image_folder
from ._builtin import *
import os
from typing import Any, Dict, List
import pathlib
from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.datasets.utils import Dataset
from torchvision.prototype.utils._internal import add_suggestion
from . import _builtin
DATASETS: Dict[str, Dataset] = {}
T = TypeVar("T")
D = TypeVar("D", bound=Type[Dataset])
BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}
def register(dataset: Dataset) -> None:
DATASETS[dataset.name] = dataset
def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]:
def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
BUILTIN_INFOS[name] = fn()
return fn
for name, obj in _builtin.__dict__.items():
if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
register(obj())
return wrapper
BUILTIN_DATASETS = {}
def register_dataset(name: str) -> Callable[[D], D]:
def wrapper(dataset_cls: D) -> D:
BUILTIN_DATASETS[name] = dataset_cls
return dataset_cls
return wrapper
def list_datasets() -> List[str]:
return sorted(DATASETS.keys())
return sorted(BUILTIN_DATASETS.keys())
def find(name: str) -> Dataset:
def find(dct: Dict[str, T], name: str) -> T:
name = name.lower()
try:
return DATASETS[name]
return dct[name]
except KeyError as error:
raise ValueError(
add_suggestion(
f"Unknown dataset '{name}'.",
word=name,
possibilities=DATASETS.keys(),
possibilities=dct.keys(),
alternative_hint=lambda _: (
"You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
),
......@@ -41,19 +52,14 @@ def find(name: str) -> Dataset:
) from error
def info(name: str) -> DatasetInfo:
return find(name).info
def info(name: str) -> Dict[str, Any]:
return find(BUILTIN_INFOS, name)
def load(
name: str,
*,
skip_integrity_check: bool = False,
**options: Any,
) -> IterDataPipe[Dict[str, Any]]:
dataset = find(name)
def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
dataset_cls = find(BUILTIN_DATASETS, name)
config = dataset.info.make_config(**options)
root = os.path.join(home(), dataset.name)
if root is None:
root = pathlib.Path(home()) / name
return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
return dataset_cls(root, **config)
......@@ -12,51 +12,66 @@ Finally, `from torchvision.prototype import datasets` is implied below.
Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin`
that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that
module create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum three methods that will be
discussed in detail below:
module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in
detail below:
```python
from typing import Any, Dict, List
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource
from torchvision.prototype.datasets.utils import Dataset, OnlineResource
from .._api import register_dataset, register_info
NAME = "my-dataset"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
...
)
@register_dataset(NAME)
class MyDataset(Dataset):
def _make_info(self) -> DatasetInfo:
def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None:
...
super().__init__(root, skip_integrity_check=skip_integrity_check)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
...
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]:
...
def __len__(self) -> int:
...
```
### `_make_info(self)`
In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a
dictionary of static information. The most common use case is to provide human-readable categories.
[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories.
The `DatasetInfo` carries static information about the dataset. There are two required fields:
Finally, both the dataset class and the info function need to be registered on the API with the respective decorators.
With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively.
- `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain
lowercase characters.
### `__init__(self, root, *, ..., skip_integrity_check = False)`
There are more optional parameters that can be passed:
Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the
base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as
setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke
the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with
an underscore.
- `dependencies`: Collection of third-party dependencies that are needed to load the dataset, e.g. `("scipy",)`. Their
availability will be automatically checked if a user tries to load the dataset. Within the implementation, import
these packages lazily to avoid missing dependencies at import time.
- `categories`: Sequence of human-readable category names for each label. The index of each category has to match the
corresponding label returned in the dataset samples.
[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories.
- `valid_options`: Configures valid options that can be passed to the dataset. It should be `Dict[str, Sequence[Any]]`.
The options are accessible through the `config` namespace in the other two functions. First value of the sequence is
taken as default if the user passes no option to `torchvision.prototype.datasets.load()`.
If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base
class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically
checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to
avoid missing dependencies at import time.
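For illustration, a minimal sketch of such a constructor, assuming a single `split` option with a `"train"` default (the
class name is a placeholder; `_verify_str_arg` is the validation helper the built-in datasets use):

```py
import pathlib
from typing import Union

from torchvision.prototype.datasets.utils import Dataset


class MyDataset(Dataset):
    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        # validate and store all options *before* calling the base constructor,
        # since it invokes the other methods, which may depend on them
        self._split = self._verify_str_arg(split, "split", ("train", "test"))
        super().__init__(root, skip_integrity_check=skip_integrity_check)
```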
## `resources(self, config)`
### `_resources(self)`
Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset with a
specific `config` can be build. The download will happen automatically.
Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can be
built. The download will happen automatically.
Currently, the following `OnlineResource`'s are supported:
......@@ -81,7 +96,7 @@ def sha256sum(path, chunk_size=1024 * 1024):
print(checksum.hexdigest())
```
### `_make_datapipe(resource_dps, *, config)`
### `_datapipe(self, resource_dps)`
This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared
to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone
......@@ -99,60 +114,112 @@ All of them can be imported `from torchdata.datapipes.iter`. In addition, use `f
needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.
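For instance, `Filter` and `Mapper` call the given function with the data item only, so any extra argument has to be
bound up front. A short sketch with a hypothetical predicate:

```py
import functools
from typing import BinaryIO, Tuple

from torchdata.datapipes.iter import Filter, IterDataPipe


def _is_split_file(data: Tuple[str, BinaryIO], *, split: str) -> bool:
    # hypothetical predicate that needs the extra `split` argument
    path, _ = data
    return split in path


def keep_split(dp: IterDataPipe[Tuple[str, BinaryIO]], split: str) -> IterDataPipe[Tuple[str, BinaryIO]]:
    # bind `split` with functools.partial so the datapipe only has to pass the data item
    return Filter(dp, functools.partial(_is_split_file, split=split))
```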
`make_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
value of `resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive. Otherwise the datapipe will only contain one
`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one
of such tuples for the file specified by the resource.
Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and
`Grouper`. There are two issues with that: 1. If not used carefully, this can easily overflow the host memory, since
most datasets will not fit in completely. 2. This can lead to unnecessarily long warm-up times when data is buffered
that is only needed at runtime.
`Grouper`. There are two issues with that:
1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.
Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
trying to zip already loaded images.
There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
`hint_sharding`. As the name implies they only hint part in the datapipe graph where shuffling and sharding should take
place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` and are
required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.
`hint_sharding`. As the name implies, they only hint at a location in the datapipe graph where shuffling and sharding
should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal`
and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.
Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
names (yet!).
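Putting these pieces together, a minimal sketch of a `_datapipe()` for a single-archive dataset (the `_prepare_sample`
method is a placeholder that turns a `(path, file handle)` tuple into the sample dictionary):

```py
from typing import Any, BinaryIO, Dict, List, Tuple

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling


def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]:
    # one datapipe per resource returned by _resources(); here there is only a single archive
    dp = resource_dps[0]
    # hint where shuffling and sharding should happen; shuffling has to come first
    dp = hint_shuffling(dp)
    dp = hint_sharding(dp)
    # every item of the final datapipe is a dictionary with str keys
    return Mapper(dp, self._prepare_sample)
```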
### `__len__`
This returns an integer denoting the number of samples that can be drawn from the dataset. Please use
[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance the
readability. For example, `1_281_167` vs. `1281167`.
If there are only two different numbers, a simple `if` / `else` is fine:
```py
def __len__(self):
return 12_345 if self._split == "train" else 6_789
```
If there are more options, using a dictionary usually is the most readable option:
```py
def __len__(self):
return {
"train": 3,
"val": 2,
"test": 1,
}[self._split]
```
If the number of samples depends on more than one parameter, you can use tuples as dictionary keys:
```py
def __len__(self):
return {
("train", "bar"): 4,
("train", "baz"): 3,
("test", "bar"): 2,
("test", "baz"): 1,
}[(self._split, self._foo)]
```
The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the
development process. Since it is an `@abstractmethod`, you still have to implement it from the start. The canonical way
is to define a dummy method like
```py
def __len__(self):
return 1
```
and only fill it with the correct data if the implementation is otherwise finished.
[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples.
## Tests
To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data.
This mock-up should resemble the original data as closely as necessary, while containing only a few examples.
To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the
same name as you have defined in `_make_config()` (if the name includes hyphens `-`, replace them with underscores `_`)
and decorate it with `@register_mock`:
same name as you have used in `@register_info` and `@register_dataset`. This function is called "mock data function".
Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset
will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options,
you can use the `combinations_grid()` helper function, e.g.
`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`.
In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass
the `name` parameter to `@register_mock`:
```py
# this is defined in torchvision/prototype/datasets/_builtin
@register_dataset("my-dataset")
class MyDataset(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"my-dataset",
...
)
@register_mock
def my_dataset(info, root, config):
...
@register_mock(name="my-dataset", configs=...)
def my_dataset(root, config):
...
```
The function receives three arguments:
The mock data function receives two arguments:
- `info`: The return value of `_make_info()`.
- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data
needs to be placed.
- `config`: The configuration to generate the data for. This is the same value that `_make_datapipe()` receives.
- `config`: The configuration to generate the data for. This is one of the dictionaries defined in
`@register_mock(configs=...)`
The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if
the dataset only has a single archive that contains multiple splits, you need to generate all regardless of the current
`config`. Although this seems odd at first, this is important. Consider the following original data setup:
the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of
the current `config`. Although this seems odd at first, this is important. Consider the following original data setup:
```
root
......@@ -167,9 +234,8 @@ root
For map-style datasets (like the one currently in `torchvision.datasets`), one explicitly selects the files they want to
load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is
present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in
`_make_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data
for the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real
data.
`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we only generated the data for
the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real data.
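For example, a sketch of such a filter for a split-per-folder layout, using the `path_comparator` helper from
`torchvision.prototype.datasets.utils._internal` (`self._split` is the hypothetical option stored in the constructor):

```py
from torchdata.datapipes.iter import Filter
from torchvision.prototype.datasets.utils._internal import path_comparator


def _datapipe(self, resource_dps):
    dp = resource_dps[0]
    # keep only files whose parent folder matches the requested split, e.g. root/train/...
    dp = Filter(dp, path_comparator("parent.name", self._split))
    ...
```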
For datasets that are ported from the old API, we already have some mock data in
[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the corresponding test case there
......@@ -178,8 +244,6 @@ and have a look at the `inject_fake_data` function. There are a few differences
- `tmp_dir` corresponds to `root`, but is a `str` rather than a
[`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like
`folder = pathlib.Path(tmp_dir)`. This is not needed.
- Although both parameters are called `config`, the value in the new tests is a namespace. Thus, please use `config.foo`
over `config["foo"]` to enhance readability.
- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the
new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files
specified in the dataset.
......@@ -196,9 +260,9 @@ Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets
### How do I start?
Get the skeleton of your dataset class ready with all 3 methods. For `_make_datapipe()`, you can just do
Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do
`return resources_dp[0]` to get started. Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset and it will be
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be
instantiable via `datasets.load("mydataset")`. On a separate script, try something like
```py
......@@ -206,7 +270,7 @@ from torchvision.prototype import datasets
dataset = datasets.load("mydataset")
for sample in dataset:
print(sample) # this is the content of an item in datapipe returned by _make_datapipe()
print(sample) # this is the content of an item in datapipe returned by _datapipe()
break
# Or you can also inspect the sample in a debugger
```
......@@ -217,15 +281,24 @@ datapipes and return the appropriate dictionary format.
### How do I handle a dataset that defines many categories?
As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only be set directly for ten categories or
fewer. If more categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line
specifies a category. If `$NAME` matches the name of the dataset (which it definitively should!) it will be
automatically loaded if `categories=` is not set.
As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more
categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a
category. To load such a file, import `read_categories_file` from `torchvision.prototype.datasets.utils._internal` and
pass it `$NAME`.
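A sketch of how this is typically wired into the info function, assuming the file is named `my-dataset.categories`:

```py
from typing import Any, Dict

from torchvision.prototype.datasets.utils._internal import read_categories_file

from .._api import register_info


@register_info("my-dataset")
def _info() -> Dict[str, Any]:
    # one category per line in _builtin/my-dataset.categories
    return dict(categories=read_categories_file("my-dataset"))
```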
In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where
each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. It gets
passed the `root` path to the resources, but they have to be manually loaded, e.g.
`self.resources(config)[0].load(root)`. The method should return a sequence of strings representing the category names.
each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method
should return a sequence of strings representing the category names. In the method body, you'll have to manually load
the resources, e.g.
```py
resources = self._resources()
dp = resources[0].load(self._root)
```
Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes
sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that.
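Putting it together, a sketch of `_generate_categories` for an image-folder style dataset, assuming the first resource
is the image archive and each parent folder name is a category:

```py
import pathlib
from typing import List


def _generate_categories(self) -> List[str]:
    resources = self._resources()
    # load the raw datapipe of (path, file handle) tuples
    dp = resources[0].load(self._root)
    # materialize the unique parent folder names and return them sorted
    return sorted({pathlib.Path(path).parent.name for path, _ in dp})
```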
To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`.
### What if a resource file forms an I/O bottleneck?
......@@ -235,3 +308,33 @@ the performance hit becomes significant, the archives can still be preprocessed.
`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be
preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also
accepts `"decompress"` and `"extract"` to handle these common scenarios.
### How do I compute the number of samples?
Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way
than to iterate over the dataset and count the number of samples:
```py
import itertools
from torchvision.prototype import datasets
def combinations_grid(**kwargs):
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there
configs = combinations_grid(split=("train", "test"), foo=("bar", "baz"))
for config in configs:
dataset = datasets.load("my-dataset", **config)
num_samples = 0
for _ in dataset:
num_samples += 1
print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples)
```
To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation
files.
......@@ -12,7 +12,7 @@ from .food101 import Food101
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
......
import pathlib
import re
from typing import Any, Dict, List, Tuple, BinaryIO
from typing import Any, Dict, List, Tuple, BinaryIO, Union
import numpy as np
from torchdata.datapipes.iter import (
......@@ -9,26 +9,46 @@ from torchdata.datapipes.iter import (
Filter,
IterKeyZipper,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
read_mat,
hint_sharding,
hint_shuffling,
read_categories_file,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
from .._api import register_dataset, register_info
@register_info("caltech101")
def _caltech101_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("caltech101"))
@register_dataset("caltech101")
class Caltech101(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"caltech101",
"""
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
- **dependencies**:
- <scipy `https://scipy.org/`>_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
skip_integrity_check: bool = False,
) -> None:
self._categories = _caltech101_info()["categories"]
super().__init__(
root,
dependencies=("scipy",),
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
skip_integrity_check=skip_integrity_check,
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
images = HttpResource(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
......@@ -88,7 +108,7 @@ class Caltech101(Dataset):
ann = read_mat(ann_buffer)
return dict(
label=Label.from_category(category, categories=self.categories),
label=Label.from_category(category, categories=self._categories),
image_path=image_path,
image=image,
ann_path=ann_path,
......@@ -98,12 +118,7 @@ class Caltech101(Dataset):
contour=_Feature(ann["obj_contour"].T),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
images_dp = Filter(images_dp, self._is_not_background_image)
......@@ -122,23 +137,39 @@ class Caltech101(Dataset):
)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
def __len__(self) -> int:
return 8677
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(root)
dp = resources[0].load(self._root)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
@register_info("caltech256")
def _caltech256_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("caltech256"))
@register_dataset("caltech256")
class Caltech256(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"caltech256",
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
)
"""
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
"""
def __init__(
self,
root: Union[str, pathlib.Path],
skip_integrity_check: bool = False,
) -> None:
self._categories = _caltech256_info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
......@@ -156,25 +187,23 @@ class Caltech256(Dataset):
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self.categories),
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
def __len__(self) -> int:
return 30607
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(root)
dp = resources[0].load(self._root)
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
return [name.split(".")[1] for name in sorted(dir_names)]
import csv
import functools
from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union
from torchdata.datapipes.iter import (
IterDataPipe,
......@@ -11,8 +11,6 @@ from torchdata.datapipes.iter import (
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
GDriveResource,
OnlineResource,
)
......@@ -25,6 +23,7 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox
from .._api import register_dataset, register_info
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
......@@ -60,15 +59,32 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
yield line.pop("image_id"), line
NAME = "celeba"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict()
@register_dataset(NAME)
class CelebA(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"celeba",
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
valid_options=dict(split=("train", "val", "test")),
)
"""
- **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
"""
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
splits = GDriveResource(
"0B7EVK8r0v71pY0NSMzRuSXJEVkk",
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
......@@ -101,14 +117,13 @@ class CelebA(Dataset):
)
return [splits, images, identities, attributes, bounding_boxes, landmarks]
_SPLIT_ID_TO_NAME = {
"0": "train",
"1": "val",
"2": "test",
}
def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool:
split_id = {
"train": "0",
"val": "1",
"test": "2",
}[self._split]
return data[1]["split_id"] == split_id
def _prepare_sample(
self,
......@@ -145,16 +160,11 @@ class CelebA(Dataset):
},
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
splits_dp = Filter(splits_dp, self._filter_split)
splits_dp = hint_shuffling(splits_dp)
splits_dp = hint_sharding(splits_dp)
......@@ -186,3 +196,10 @@ class CelebA(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 162_770,
"val": 19_867,
"test": 19_962,
}[self._split]
import abc
import functools
import io
import pathlib
import pickle
from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO
from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, Union
import numpy as np
from torchdata.datapipes.iter import (
......@@ -11,20 +10,17 @@ from torchdata.datapipes.iter import (
Filter,
Mapper,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_shuffling,
path_comparator,
hint_sharding,
read_categories_file,
)
from torchvision.prototype.features import Label, Image
from .._api import register_dataset, register_info
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
......@@ -44,19 +40,23 @@ class _CifarBase(Dataset):
_LABELS_KEY: str
_META_FILE_NAME: str
_CATEGORIES_KEY: str
_categories: List[str]
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
@abc.abstractmethod
def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]:
def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
pass
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
type(self).__name__.lower(),
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
valid_options=dict(split=("train", "test")),
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
......@@ -72,52 +72,72 @@ class _CifarBase(Dataset):
image_array, category_idx = data
return dict(
image=Image(image_array),
label=Label(category_idx, categories=self.categories),
label=Label(category_idx, categories=self._categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
dp = Filter(dp, self._is_data_file)
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
def __len__(self) -> int:
return 50_000 if self._split == "train" else 10_000
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(root)
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
@register_info("cifar10")
def _cifar10_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("cifar10"))
@register_dataset("cifar10")
class Cifar10(_CifarBase):
"""
- **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
"""
_FILE_NAME = "cifar-10-python.tar.gz"
_SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
_LABELS_KEY = "labels"
_META_FILE_NAME = "batches.meta"
_CATEGORIES_KEY = "label_names"
_categories = _cifar10_info()["categories"]
def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
def _is_data_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name.startswith("data" if split == "train" else "test")
return path.name.startswith("data" if self._split == "train" else "test")
@register_info("cifar100")
def _cifar100_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("cifar100"))
@register_dataset("cifar100")
class Cifar100(_CifarBase):
"""
- **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
"""
_FILE_NAME = "cifar-100-python.tar.gz"
_SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
_LABELS_KEY = "fine_labels"
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"
_categories = _cifar100_info()["categories"]
def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
def _is_data_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name == split
return path.name == self._split
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
hint_sharding,
......@@ -19,16 +13,30 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import Label, EncodedImage
from .._api import register_dataset, register_info
NAME = "clevr"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict()
@register_dataset(NAME)
class CLEVR(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"clevr",
homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
valid_options=dict(split=("train", "val", "test")),
)
"""
- **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
......@@ -61,12 +69,7 @@ class CLEVR(Dataset):
label=Label(len(scenes_data["objects"])) if scenes_data else None,
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, scenes_dp = Demultiplexer(
archive_dp,
......@@ -76,12 +79,12 @@ class CLEVR(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
)
images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
images_dp = hint_shuffling(images_dp)
images_dp = hint_sharding(images_dp)
if config.split != "test":
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
if self._split != "test":
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
scenes_dp = JsonParser(scenes_dp)
scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
scenes_dp = UnBatcher(scenes_dp)
......@@ -97,3 +100,6 @@ class CLEVR(Dataset):
dp = Mapper(images_dp, self._add_empty_anns)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 70_000 if self._split == "train" else 15_000
import functools
import pathlib
import re
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union
import torch
from torchdata.datapipes.iter import (
......@@ -16,43 +16,65 @@ from torchdata.datapipes.iter import (
UnBatcher,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
Dataset,
)
from torchvision.prototype.datasets.utils._internal import (
MappingIterator,
INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
getitem,
read_categories_file,
path_accessor,
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
from torchvision.prototype.utils._internal import FrozenMapping
from .._api import register_dataset, register_info
NAME = "coco"
@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, super_categories = zip(*read_categories_file(NAME))
return dict(categories=categories, super_categories=super_categories)
@register_dataset(NAME)
class Coco(Dataset):
def _make_info(self) -> DatasetInfo:
name = "coco"
categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
return DatasetInfo(
name,
dependencies=("pycocotools",),
categories=categories,
homepage="https://cocodataset.org/",
valid_options=dict(
split=("train", "val"),
year=("2017", "2014"),
annotations=(*self._ANN_DECODERS.keys(), None),
),
extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))),
"""
- **homepage**: https://cocodataset.org/
- **dependencies**:
- `pycocotools <https://github.com/cocodataset/cocoapi>`_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2017",
annotations: Optional[str] = "instances",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val"})
self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
self._annotations = (
self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
if annotations is not None
else None
)
info = _info()
categories, super_categories = info["categories"], info["super_categories"]
self._categories = categories
self._category_to_super_category = dict(zip(categories, super_categories))
super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)
_IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
_IMAGES_CHECKSUMS = {
......@@ -69,14 +91,14 @@ class Coco(Dataset):
"2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
images = HttpResource(
f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip",
sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)],
f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
)
meta = HttpResource(
f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip",
sha256=self._META_CHECKSUMS[config.year],
f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
sha256=self._META_CHECKSUMS[self._year],
)
return [images, meta]
......@@ -110,10 +132,8 @@ class Coco(Dataset):
format="xywh",
image_size=image_size,
),
labels=Label(labels, categories=self.categories),
super_categories=[
self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
],
labels=Label(labels, categories=self._categories),
super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
ann_ids=[ann["id"] for ann in anns],
)
......@@ -134,9 +154,14 @@ class Coco(Dataset):
fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
)
def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool:
def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations)
return bool(
match
and match["split"] == self._split
and match["year"] == self._year
and match["annotations"] == self._annotations
)
def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
key, _ = data
......@@ -157,38 +182,26 @@ class Coco(Dataset):
def _prepare_sample(
self,
data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
*,
annotations: str,
) -> Dict[str, Any]:
ann_data, image_data = data
anns, image_meta = ann_data
sample = self._prepare_image(image_data)
# this method is only called if we have annotations
annotations = cast(str, self._annotations)
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
return sample
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, meta_dp = resource_dps
if config.annotations is None:
if self._annotations is None:
dp = hint_shuffling(images_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._prepare_image)
meta_dp = Filter(
meta_dp,
functools.partial(
self._filter_meta_files,
split=config.split,
year=config.year,
annotations=config.annotations,
),
)
meta_dp = Filter(meta_dp, self._filter_meta_files)
meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1))
meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
......@@ -216,7 +229,6 @@ class Coco(Dataset):
ref_key_fn=getitem("id"),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
anns_dp,
images_dp,
......@@ -224,18 +236,24 @@ class Coco(Dataset):
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
}[(self._split, self._year)][
self._annotations # type: ignore[index]
]
return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))
def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
config = self.default_config
resources = self.resources(config)
def _generate_categories(self) -> Tuple[Tuple[str, str]]:
self._annotations = "instances"
resources = self._resources()
dp = resources[1].load(root)
dp = Filter(
dp,
functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
)
dp = resources[1].load(self._root)
dp = Filter(dp, self._filter_meta_files)
dp = JsonParser(dp)
_, meta = next(iter(dp))
......
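The new __len__ encodes the per-configuration sizes via defaultdict, so the reported length depends on split, year and annotations. A hedged sketch (module path assumed, pycocotools required, data expected under root):

    from torchvision.prototype.datasets._builtin.coco import Coco  # module path assumed

    coco = Coco("/data/coco", split="val", year="2017", annotations="instances")
    assert len(coco) == 4_952  # images without instance annotations are dropped

    coco_plain = Coco("/data/coco", split="val", year="2017", annotations=None)
    assert len(coco_plain) == 5_000  # defaultdict fallback: plain image pipeline, no join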
import pathlib
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
path_comparator,
hint_sharding,
hint_shuffling,
read_categories_file,
)
from torchvision.prototype.features import EncodedImage, Label
from .._api import register_dataset, register_info
NAME = "country211"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class Country211(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"country211",
homepage="https://github.com/openai/CLIP/blob/main/data/country211.md",
valid_options=dict(split=("train", "val", "test")),
)
"""
- **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md
"""
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
self._split_folder_name = "valid" if split == "val" else split
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
"https://openaipublic.azureedge.net/clip/data/country211.tgz",
......@@ -23,17 +49,11 @@ class Country211(Dataset):
)
]
_SPLIT_NAME_MAPPER = {
"train": "train",
"val": "valid",
"test": "test",
}
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self.categories),
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
......@@ -41,16 +61,21 @@ class Country211(Dataset):
def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
return pathlib.Path(data[0]).parent.parent.name == split
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split]))
dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name))
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
dp = resources[0].load(root)
def __len__(self) -> int:
return {
"train": 31_650,
"val": 10_550,
"test": 21_100,
}[self._split]
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
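The old _SPLIT_NAME_MAPPER collapses into a single ternary in __init__, so "val" is looked up in the "valid" folder on disk. A short sketch, with the module path assumed:

    from torchvision.prototype.datasets._builtin.country211 import Country211  # module path assumed

    country211 = Country211("/data/country211", split="val")  # filters on the "valid" folder
    assert len(country211) == 10_550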
import csv
import functools
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union
from torchdata.datapipes.iter import (
IterDataPipe,
......@@ -15,8 +15,6 @@ from torchdata.datapipes.iter import (
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
......@@ -27,27 +25,52 @@ from torchvision.prototype.datasets.utils._internal import (
hint_shuffling,
getitem,
path_comparator,
read_categories_file,
path_accessor,
)
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
from .._api import register_dataset, register_info
csv.register_dialect("cub200", delimiter=" ")
NAME = "cub200"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class CUB200(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"cub200",
homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html",
dependencies=("scipy",),
valid_options=dict(
split=("train", "test"),
year=("2011", "2010"),
),
"""
- **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2011",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
self._year = self._verify_str_arg(year, "year", ("2010", "2011"))
self._categories = _info()["categories"]
super().__init__(
root,
# TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
# dependencies=("scipy",),
skip_integrity_check=skip_integrity_check,
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.year == "2011":
def _resources(self) -> List[OnlineResource]:
if self._year == "2011":
archive = HttpResource(
"http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
......@@ -59,7 +82,7 @@ class CUB200(Dataset):
preprocess="decompress",
)
return [archive, segmentations]
else: # config.year == "2010"
else: # self._year == "2010"
split = HttpResource(
"http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz",
sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
......@@ -90,12 +113,12 @@ class CUB200(Dataset):
else:
return None
def _2011_filter_split(self, row: List[str], *, split: str) -> bool:
def _2011_filter_split(self, row: List[str]) -> bool:
_, split_id = row
return {
"0": "test",
"1": "train",
}[split_id] == split
}[split_id] == self._split
def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
path = pathlib.Path(data[0])
......@@ -149,17 +172,12 @@ class CUB200(Dataset):
return dict(
prepare_ann_fn(anns_data, image.image_size),
image=image,
label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories),
label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
prepare_ann_fn: Callable
if config.year == "2011":
if self._year == "2011":
archive_dp, segmentations_dp = resource_dps
images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
......@@ -171,7 +189,7 @@ class CUB200(Dataset):
)
split_dp = CSVParser(split_dp, dialect="cub200")
split_dp = Filter(split_dp, functools.partial(self._2011_filter_split, split=config.split))
split_dp = Filter(split_dp, self._2011_filter_split)
split_dp = Mapper(split_dp, getitem(0))
split_dp = Mapper(split_dp, image_files_map.get)
......@@ -188,10 +206,10 @@ class CUB200(Dataset):
)
prepare_ann_fn = self._2011_prepare_ann
else: # config.year == "2010"
else: # self._year == "2010"
split_dp, images_dp, anns_dp = resource_dps
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True, return_path=False)
split_dp = Mapper(split_dp, self._2010_split_key)
......@@ -217,11 +235,19 @@ class CUB200(Dataset):
)
return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))
def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.info.make_config(year="2011")
resources = self.resources(config)
def __len__(self) -> int:
return {
("train", "2010"): 3_000,
("test", "2010"): 3_033,
("train", "2011"): 5_994,
("test", "2011"): 5_794,
}[(self._split, self._year)]
def _generate_categories(self) -> List[str]:
self._year = "2011"
resources = self._resources()
dp = resources[0].load(root)
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", "classes.txt"))
dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200")
......
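The year argument now selects which archives _resources returns, and the lengths are keyed on (split, year). A sketch under the same assumptions as above:

    from torchvision.prototype.datasets._builtin.cub200 import CUB200  # module path assumed

    cub = CUB200("/data/cub200", split="test", year="2010")  # uses the CUB-200 (2010) archives
    assert len(cub) == 3_033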
import enum
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
......@@ -15,10 +13,16 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
path_comparator,
getitem,
read_categories_file,
hint_shuffling,
)
from torchvision.prototype.features import Label, EncodedImage
from .._api import register_dataset, register_info
NAME = "dtd"
class DTDDemux(enum.IntEnum):
SPLIT = 0
......@@ -26,18 +30,36 @@ class DTDDemux(enum.IntEnum):
IMAGES = 2
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class DTD(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"dtd",
homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
valid_options=dict(
split=("train", "test", "val"),
fold=tuple(str(fold) for fold in range(1, 11)),
),
)
"""DTD Dataset.
homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
"""
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
fold: int = 1,
skip_validation_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
if not (1 <= fold <= 10):
raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}")
self._fold = fold
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_validation_check)
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
......@@ -71,24 +93,19 @@ class DTD(Dataset):
return dict(
joint_categories={category for category in joint_categories if category},
label=Label.from_category(category, categories=self.categories),
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
splits_dp, joint_categories_dp, images_dp = Demultiplexer(
archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt"))
splits_dp = LineReader(splits_dp, decode=True, return_path=False)
splits_dp = hint_shuffling(splits_dp)
splits_dp = hint_sharding(splits_dp)
......@@ -114,10 +131,13 @@ class DTD(Dataset):
def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == DTDDemux.IMAGES
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(root)
dp = resources[0].load(self._root)
dp = Filter(dp, self._filter_images)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
def __len__(self) -> int:
return 1_880 # All splits have the same length
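Because the fold check now lives in __init__, an out-of-range fold fails immediately, before any resource is touched. A sketch of that error path (module path assumed):

    from torchvision.prototype.datasets._builtin.dtd import DTD  # module path assumed

    try:
        DTD("/data/dtd", split="train", fold=11)
    except ValueError as error:
        print(error)  # The fold parameter should be an integer in [1, 10]. Got 11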
import pathlib
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import EncodedImage, Label
from .._api import register_dataset, register_info
class EuroSAT(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"eurosat",
homepage="https://github.com/phelber/eurosat",
categories=(
"AnnualCrop",
"Forest",
"HerbaceousVegetation",
"Highway",
"Industrial," "Pasture",
"PermanentCrop",
"Residential",
"River",
"SeaLake",
),
NAME = "eurosat"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
categories=(
"AnnualCrop",
"Forest",
"HerbaceousVegetation",
"Highway",
"Industrial," "Pasture",
"PermanentCrop",
"Residential",
"River",
"SeaLake",
)
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
@register_dataset(NAME)
class EuroSAT(Dataset):
"""EuroSAT Dataset.
homepage="https://github.com/phelber/eurosat",
"""
def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
......@@ -37,15 +50,16 @@ class EuroSAT(Dataset):
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self.categories),
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 27_000
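EuroSAT has no split, so the refactored constructor only takes root and skip_integrity_check. A sketch, module path assumed:

    from torchvision.prototype.datasets._builtin.eurosat import EuroSAT  # module path assumed

    eurosat = EuroSAT("/data/eurosat", skip_integrity_check=True)
    assert len(eurosat) == 27_000  # single fixed-size split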
from typing import Any, Dict, List, cast
import pathlib
from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
KaggleDownloadResource,
)
......@@ -15,26 +14,40 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import Label, Image
from .._api import register_dataset, register_info
NAME = "fer2013"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"))
@register_dataset(NAME)
class FER2013(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fer2013",
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
valid_options=dict(split=("train", "test")),
)
"""FER 2013 Dataset
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_CHECKSUMS = {
"train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
"test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
archive = KaggleDownloadResource(
cast(str, self.info.homepage),
file_name=f"{config.split}.csv.zip",
sha256=self._CHECKSUMS[config.split],
"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
file_name=f"{self._split}.csv.zip",
sha256=self._CHECKSUMS[self._split],
)
return [archive]
......@@ -43,17 +56,15 @@ class FER2013(Dataset):
return dict(
image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
label=Label(int(label_id), categories=self.categories) if label_id is not None else None,
label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVDictParser(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 28_709 if self._split == "train" else 3_589
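Since the data comes from a KaggleDownloadResource, the csv archive presumably has to be obtained from the Kaggle challenge page and placed under root before use. A sketch under that assumption:

    from torchvision.prototype.datasets._builtin.fer2013 import FER2013  # module path assumed

    fer2013 = FER2013("/data/fer2013", split="test")  # expects test.csv.zip under root
    assert len(fer2013) == 3_589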
from pathlib import Path
from typing import Any, Tuple, List, Dict, Optional, BinaryIO
from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union
from torchdata.datapipes.iter import (
IterDataPipe,
......@@ -9,26 +9,41 @@ from torchdata.datapipes.iter import (
Demultiplexer,
IterKeyZipper,
)
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_shuffling,
hint_sharding,
path_comparator,
getitem,
INFINITE_BUFFER_SIZE,
read_categories_file,
)
from torchvision.prototype.features import Label, EncodedImage
from .._api import register_dataset, register_info
NAME = "food101"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class Food101(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"food101",
homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
valid_options=dict(split=("train", "test")),
)
"""Food 101 dataset
homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
"""
def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
......@@ -49,7 +64,7 @@ class Food101(Dataset):
def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
id, (path, buffer) = data
return dict(
label=Label.from_category(id.split("/", 1)[0], categories=self.categories),
label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
......@@ -58,17 +73,12 @@ class Food101(Dataset):
path = Path(data[0])
return path.relative_to(path.parents[1]).with_suffix("").as_posix()
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, split_dp = Demultiplexer(
archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True, return_path=False)
split_dp = hint_sharding(split_dp)
split_dp = hint_shuffling(split_dp)
......@@ -83,9 +93,12 @@ class Food101(Dataset):
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: Path) -> List[str]:
resources = self.resources(self.default_config)
dp = resources[0].load(root)
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", "classes.txt"))
dp = LineReader(dp, decode=True, return_path=False)
return list(dp)
def __len__(self) -> int:
return 75_750 if self._split == "train" else 25_250
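The label is derived from the first path component of the split-file id, as _prepare_sample shows. A tiny illustration with a hypothetical id value:

    image_id = "apple_pie/1005649"        # hypothetical id in the Food-101 split-file format
    category = image_id.split("/", 1)[0]  # same expression as in _prepare_sample
    assert category == "apple_pie"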
import pathlib
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
HttpResource,
)
......@@ -17,15 +15,31 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import Label, BoundingBox, EncodedImage
from .._api import register_dataset, register_info
NAME = "gtsrb"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
categories=[f"{label:05d}" for label in range(43)],
)
@register_dataset(NAME)
class GTSRB(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"gtsrb",
homepage="https://benchmark.ini.rub.de",
categories=[f"{label:05d}" for label in range(43)],
valid_options=dict(split=("train", "test")),
)
"""GTSRB Dataset
homepage="https://benchmark.ini.rub.de"
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
_URLS = {
......@@ -39,10 +53,10 @@ class GTSRB(Dataset):
"test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
rsrcs: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUMS[config.split])]
def _resources(self) -> List[OnlineResource]:
rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])]
if config.split == "test":
if self._split == "test":
rsrcs.append(
HttpResource(
self._URLS["test_ground_truth"],
......@@ -74,14 +88,12 @@ class GTSRB(Dataset):
return {
"path": path,
"image": EncodedImage.from_file(buffer),
"label": Label(label, categories=self.categories),
"label": Label(label, categories=self._categories),
"bounding_box": bounding_box,
}
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
if config.split == "train":
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
if self._split == "train":
images_dp, ann_dp = Demultiplexer(
resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
......@@ -98,3 +110,6 @@ class GTSRB(Dataset):
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 26_640 if self._split == "train" else 12_630
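For the test split a second resource, the ground-truth csv archive, is appended in _resources and the labels are joined onto the images. A sketch, module path assumed:

    from torchvision.prototype.datasets._builtin.gtsrb import GTSRB  # module path assumed

    gtsrb_test = GTSRB("/data/gtsrb", split="test")  # pulls the extra ground-truth archive
    assert len(gtsrb_test) == 12_630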