Unverified commit 1ac6e8b9, authored by Philip Meier and committed by GitHub

Refactor and simplify prototype datasets (#5778)



* refactor prototype datasets to inherit from IterDataPipe (#5448)

* refactor prototype datasets to inherit from IterDataPipe

* depend on new architecture

* fix missing file detection

* remove unrelated file

* reinstate decorator for mock registering

* options -> config

* remove passing of info to mock data functions

* refactor categories file generation

* fix imagenet

* fix prototype datasets data loading tests (#5711)

* reenable serialization test

* cleanup

* fix dill test

* trigger CI

* patch DILL_AVAILABLE for pickle serialization

* revert CI changes

* remove dill test and traversable test

* add data loader test

* parametrize over only_datapipe

* draw one sample rather than exhaust data loader

* cleanup

* trigger CI

* migrate VOC prototype dataset (#5743)

* migrate VOC prototype dataset

* cleanup

* revert unrelated mock data changes

* remove categories annotations

* move properties to constructor

* readd homepage

* migrate CIFAR prototype datasets (#5751)

* migrate country211 prototype dataset (#5753)

* migrate CLEVR prototype dataset (#5752)

* migrate coco prototype (#5473)

* migrate coco prototype

* revert unrelated change

* add kwargs to super constructor call

* remove unneeded changes

* fix docstring position

* make kwargs explicit

* add dependencies to docstring

* fix missing dependency message

* Migrate PCAM prototype dataset (#5745)

* Port PCAM

* skip_integrity_check

* Update torchvision/prototype/datasets/_builtin/pcam.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Address comments
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate DTD prototype dataset (#5757)

* Migrate DTD prototype dataset

* Docstring

* Apply suggestions from code review
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate GTSRB prototype dataset (#5746)

* Migrate GTSRB prototype dataset

* ufmt

* Address comments

* Apparently mypy doesn't know that __len__ returns ints. How cute.

* why is the CI not triggered??

* Update torchvision/prototype/datasets/_builtin/gtsrb.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate CelebA prototype dataset (#5750)

* migrate CelebA prototype dataset

* inline split_id

* Migrate Food101 prototype dataset (#5758)

* Migrate Food101 dataset

* Added length

* Update torchvision/prototype/datasets/_builtin/food101.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate Fer2013 prototype dataset (#5759)

* Migrate Fer2013 prototype dataset

* Update torchvision/prototype/datasets/_builtin/fer2013.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Migrate EuroSAT prototype dataset (#5760)

* Migrate Semeion prototype dataset (#5761)

* migrate caltech prototype datasets (#5749)

* migrate caltech prototype datasets

* resolve third party dependencies

* Migrate Oxford Pets prototype dataset (#5764)

* Migrate Oxford Pets prototype dataset

* Update torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* migrate mnist prototype datasets (#5480)

* migrate MNIST prototype datasets

* Update torchvision/prototype/datasets/_builtin/mnist.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Migrate Stanford Cars prototype dataset (#5767)

* Migrate Stanford Cars prototype dataset

* Address comments

* fix category file generation (#5770)

* fix category file generation

* revert unrelated change

* revert unrelated change

* migrate cub200 prototype dataset (#5765)

* migrate cub200 prototype dataset

* address comments

* fix category-file-generation

* Migrate USPS prototype dataset (#5771)

* migrate SBD prototype dataset (#5772)

* migrate SBD prototype dataset

* reuse categories

* Migrate SVHN prototype dataset (#5769)

* add test to enforce __len__ is working on prototype datasets (#5742)

* reactivate special dataset tests

* add missing annotation

* Cleanup prototype dataset implementation (#5774)

* Remove Dataset2 class

* Move read_categories_file out of DatasetInfo

* Remove FrozenBunch and FrozenMapping

* Remove test_prototype_datasets_api.py and move missing dep test somewhere else

* ufmt

* Let read_categories_file accept names instead of paths

* Mypy

* flake8

* fix category file reading
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* update prototype dataset README (#5777)

* update prototype dataset README

* fix header level

* Apply suggestions from code review
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
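A rough usage sketch of the API this refactor introduces (names are taken from the diff below; the `"mnist"` name and its `split` argument are only illustrative, and iterating requires the data to be available locally):

```python
from torch.utils.data import DataLoader
from torchvision.prototype import datasets

# Static information is now a plain dict instead of a DatasetInfo object.
print(datasets.list_datasets())
print(datasets.info("mnist"))

# Datasets now inherit from IterDataPipe: they can be iterated directly,
# expose __len__, and can be wrapped in a vanilla DataLoader.
dataset = datasets.load("mnist", split="train")
print(len(dataset))
sample = next(iter(dataset))

loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)
batch = next(iter(loader))
```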
@@ -10,6 +10,7 @@ import lzma
 import pathlib
 import pickle
 import random
+import unittest.mock
 import warnings
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter
@@ -18,11 +19,11 @@ import numpy as np
 import PIL.Image
 import pytest
 import torch
-from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
+from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
 from torchvision._utils import sequence_to_str
-from torchvision.prototype.datasets._api import find
+from torchvision.prototype import datasets

 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
@@ -32,13 +33,11 @@ __all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]

 class DatasetMock:
-    def __init__(self, name, mock_data_fn):
-        self.dataset = find(name)
-        self.info = self.dataset.info
-        self.name = self.info.name
+    def __init__(self, name, *, mock_data_fn, configs):
+        # FIXME: error handling for unknown names
+        self.name = name
         self.mock_data_fn = mock_data_fn
-        self.configs = self.info._configs
+        self.configs = configs

     def _parse_mock_info(self, mock_info):
         if mock_info is None:
@@ -67,10 +66,13 @@ class DatasetMock:
         root = home / self.name
         root.mkdir(exist_ok=True)

-        mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))
+        mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

+        with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"):
+            required_file_names = {
+                resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
+            }
         available_file_names = {path.name for path in root.glob("*")}
-        required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
         missing_file_names = required_file_names - available_file_names
         if missing_file_names:
             raise pytest.UsageError(
@@ -125,10 +127,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
 DATASET_MOCKS = {}


-def register_mock(fn):
-    name = fn.__name__.replace("_", "-")
-    DATASET_MOCKS[name] = DatasetMock(name, fn)
-    return fn
+def register_mock(name=None, *, configs):
+    def wrapper(mock_data_fn):
+        nonlocal name
+        if name is None:
+            name = mock_data_fn.__name__
+
+        DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs)
+        return mock_data_fn
+
+    return wrapper


 class MNISTMockData:
@@ -206,58 +214,64 @@ class MNISTMockData:
         return num_samples


-@register_mock
-def mnist(info, root, config):
-    train = config.split == "train"
-    images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
-    labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
+def mnist(root, config):
+    prefix = "train" if config["split"] == "train" else "t10k"
     return MNISTMockData.generate(
         root,
-        num_categories=len(info.categories),
-        images_file=images_file,
-        labels_file=labels_file,
+        num_categories=10,
+        images_file=f"{prefix}-images-idx3-ubyte.gz",
+        labels_file=f"{prefix}-labels-idx1-ubyte.gz",
     )


-DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
+DATASET_MOCKS.update(
+    {
+        name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test")))
+        for name in ["mnist", "fashionmnist", "kmnist"]
+    }
+)


-@register_mock
-def emnist(info, root, config):
-    # The image sets that merge some lower case letters in their respective upper case variant, still use dense
-    # labels in the data files. Thus, num_categories != len(categories) there.
-    num_categories = defaultdict(
-        lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
-    )
-
+@register_mock(
+    configs=combinations_grid(
+        split=("train", "test"),
+        image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"),
+    )
+)
+def emnist(root, config):
     num_samples_map = {}
     file_names = set()
-    for config_ in info._configs:
-        prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}"
+    for split, image_set in itertools.product(
+        ("train", "test"),
+        ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"),
+    ):
+        prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
         file_names.update({images_file, labels_file})
-        num_samples_map[config_] = MNISTMockData.generate(
+        num_samples_map[(split, image_set)] = MNISTMockData.generate(
             root,
-            num_categories=num_categories[config_.image_set],
+            # The image sets that merge some lower case letters in their respective upper case variant, still use dense
+            # labels in the data files. Thus, num_categories != len(categories) there.
+            num_categories=47 if config["image_set"] in ("Balanced", "By_Merge") else 62,
             images_file=images_file,
             labels_file=labels_file,
         )

     make_zip(root, "emnist-gzip.zip", *file_names)

-    return num_samples_map[config]
+    return num_samples_map[(config["split"], config["image_set"])]


-@register_mock
-def qmnist(info, root, config):
-    num_categories = len(info.categories)
-    if config.split == "train":
+@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist")))
+def qmnist(root, config):
+    num_categories = 10
+    if config["split"] == "train":
         num_samples = num_samples_gen = num_categories + 2
         prefix = "qmnist-train"
         suffix = ".gz"
         compressor = gzip.open
-    elif config.split.startswith("test"):
+    elif config["split"].startswith("test"):
         # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
         # more than 10000 images for the dataset to not be empty.
         num_samples_gen = 10001
@@ -265,11 +279,11 @@ def qmnist(info, root, config):
             "test": num_samples_gen,
             "test10k": min(num_samples_gen, 10_000),
             "test50k": num_samples_gen - 10_000,
-        }[config.split]
+        }[config["split"]]
         prefix = "qmnist-test"
         suffix = ".gz"
         compressor = gzip.open
-    else:  # config.split == "nist"
+    else:  # config["split"] == "nist"
         num_samples = num_samples_gen = num_categories + 3
         prefix = "xnist"
         suffix = ".xz"
@@ -326,8 +340,8 @@ class CIFARMockData:
         make_tar(root, name, folder, compression="gz")


-@register_mock
-def cifar10(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def cifar10(root, config):
     train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
     test_files = ["test_batch"]

@@ -341,11 +355,11 @@ def cifar10(info, root, config):
         labels_key="labels",
     )

-    return len(train_files if config.split == "train" else test_files)
+    return len(train_files if config["split"] == "train" else test_files)


-@register_mock
-def cifar100(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def cifar100(root, config):
     train_files = ["train"]
     test_files = ["test"]

@@ -359,11 +373,11 @@ def cifar100(info, root, config):
         labels_key="fine_labels",
     )

-    return len(train_files if config.split == "train" else test_files)
+    return len(train_files if config["split"] == "train" else test_files)


-@register_mock
-def caltech101(info, root, config):
+@register_mock(configs=[dict()])
+def caltech101(root, config):
     def create_ann_file(root, name):
         import scipy.io

@@ -382,15 +396,17 @@ def caltech101(info, root, config):
     images_root = root / "101_ObjectCategories"
     anns_root = root / "Annotations"

-    ann_category_map = {
-        "Faces_2": "Faces",
-        "Faces_3": "Faces_easy",
-        "Motorbikes_16": "Motorbikes",
-        "Airplanes_Side_2": "airplanes",
+    image_category_map = {
+        "Faces": "Faces_2",
+        "Faces_easy": "Faces_3",
+        "Motorbikes": "Motorbikes_16",
+        "airplanes": "Airplanes_Side_2",
     }

+    categories = ["Faces", "Faces_easy", "Motorbikes", "airplanes", "yin_yang"]
+
     num_images_per_category = 2
-    for category in info.categories:
+    for category in categories:
         create_image_folder(
             root=images_root,
             name=category,
@@ -399,7 +415,7 @@ def caltech101(info, root, config):
         )
         create_ann_folder(
             root=anns_root,
-            name=ann_category_map.get(category, category),
+            name=image_category_map.get(category, category),
             file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat",
             num_examples=num_images_per_category,
         )
@@ -409,19 +425,26 @@ def caltech101(info, root, config):
     make_tar(root, f"{anns_root.name}.tar", anns_root)

-    return num_images_per_category * len(info.categories)
+    return num_images_per_category * len(categories)


-@register_mock
-def caltech256(info, root, config):
+@register_mock(configs=[dict()])
+def caltech256(root, config):
     dir = root / "256_ObjectCategories"
     num_images_per_category = 2

-    for idx, category in enumerate(info.categories, 1):
+    categories = [
+        (1, "ak47"),
+        (127, "laptop-101"),
+        (198, "spider"),
+        (257, "clutter"),
+    ]
+
+    for category_idx, category in categories:
         files = create_image_folder(
             dir,
-            name=f"{idx:03d}.{category}",
-            file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg",
+            name=f"{category_idx:03d}.{category}",
+            file_name_fn=lambda image_idx: f"{category_idx:03d}_{image_idx + 1:04d}.jpg",
             num_examples=num_images_per_category,
         )
         if category == "spider":
@@ -429,21 +452,21 @@ def caltech256(info, root, config):
     make_tar(root, f"{dir.name}.tar", dir)

-    return num_images_per_category * len(info.categories)
+    return num_images_per_category * len(categories)


-@register_mock
-def imagenet(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def imagenet(root, config):
     from scipy.io import savemat

-    categories = info.categories
-    wnids = [info.extra.category_to_wnid[category] for category in categories]
-    if config.split == "train":
-        num_samples = len(wnids)
+    info = datasets.info("imagenet")
+
+    if config["split"] == "train":
+        num_samples = len(info["wnids"])
         archive_name = "ILSVRC2012_img_train.tar"

         files = []
-        for wnid in wnids:
+        for wnid in info["wnids"]:
             create_image_folder(
                 root=root,
                 name=wnid,
@@ -451,7 +474,7 @@ def imagenet(info, root, config):
                 num_examples=1,
             )
             files.append(make_tar(root, f"{wnid}.tar"))
-    elif config.split == "val":
+    elif config["split"] == "val":
         num_samples = 3
         archive_name = "ILSVRC2012_img_val.tar"
         files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
@@ -461,13 +484,13 @@ def imagenet(info, root, config):
         data_root.mkdir(parents=True)

         with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
-            for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
+            for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist():
                 file.write(f"{label}\n")

         num_children = 0
         synsets = [
             (idx, wnid, category, "", num_children, [], 0, 0)
-            for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
+            for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1)
         ]
         num_children = 1
         synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
@@ -477,7 +500,7 @@ def imagenet(info, root, config):
         savemat(data_root / "meta.mat", dict(synsets=synsets))
         make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
-    else:  # config.split == "test"
+    else:  # config["split"] == "test"
         num_samples = 5
         archive_name = "ILSVRC2012_img_test_v10102019.tar"
         files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
@@ -592,9 +615,15 @@ class CocoMockData:
         return num_samples


-@register_mock
-def coco(info, root, config):
-    return CocoMockData.generate(root, year=config.year, num_samples=5)
+@register_mock(
+    configs=combinations_grid(
+        split=("train", "val"),
+        year=("2017", "2014"),
+        annotations=("instances", "captions", None),
+    )
+)
+def coco(root, config):
+    return CocoMockData.generate(root, year=config["year"], num_samples=5)


 class SBDMockData:
@@ -666,15 +695,15 @@ class SBDMockData:
         return num_samples_map


-@register_mock
-def sbd(info, root, config):
-    return SBDMockData.generate(root)[config.split]
+@register_mock(configs=combinations_grid(split=("train", "val", "train_noval")))
+def sbd(root, config):
+    return SBDMockData.generate(root)[config["split"]]


-@register_mock
-def semeion(info, root, config):
+@register_mock(configs=[dict()])
+def semeion(root, config):
     num_samples = 3
-    num_categories = len(info.categories)
+    num_categories = 10

     images = torch.rand(num_samples, 256)
     labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
@@ -784,10 +813,23 @@ class VOCMockData:
         return num_samples_map


-@register_mock
-def voc(info, root, config):
-    trainval = config.split != "test"
-    return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]
+@register_mock(
+    configs=[
+        *combinations_grid(
+            split=("train", "val", "trainval"),
+            year=("2007", "2008", "2009", "2010", "2011", "2012"),
+            task=("detection", "segmentation"),
+        ),
+        *combinations_grid(
+            split=("test",),
+            year=("2007",),
+            task=("detection", "segmentation"),
+        ),
+    ],
+)
+def voc(root, config):
+    trainval = config["split"] != "test"
+    return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]]


 class CelebAMockData:
@@ -878,19 +920,14 @@ class CelebAMockData:
         return num_samples_map


-@register_mock
-def celeba(info, root, config):
-    return CelebAMockData.generate(root)[config.split]
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def celeba(root, config):
+    return CelebAMockData.generate(root)[config["split"]]


-@register_mock
-def country211(info, root, config):
-    split_name_mapper = {
-        "train": "train",
-        "val": "valid",
-        "test": "test",
-    }
-    split_folder = pathlib.Path(root, "country211", split_name_mapper[config["split"]])
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def country211(root, config):
+    split_folder = pathlib.Path(root, "country211", "valid" if config["split"] == "val" else config["split"])
     split_folder.mkdir(parents=True, exist_ok=True)

     num_examples = {
@@ -911,8 +948,8 @@ def country211(info, root, config):
     return num_examples * len(classes)


-@register_mock
-def food101(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def food101(root, config):
     data_folder = root / "food-101"

     num_images_per_class = 3
@@ -946,11 +983,11 @@ def food101(info, root, config):
     make_tar(root, f"{data_folder.name}.tar.gz", compression="gz")

-    return num_samples_map[config.split]
+    return num_samples_map[config["split"]]


-@register_mock
-def dtd(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10)))
+def dtd(root, config):
     data_folder = root / "dtd"

     num_images_per_class = 3
@@ -990,20 +1027,21 @@ def dtd(info, root, config):
         with open(meta_folder / f"{split}{fold}.txt", "w") as file:
             file.write("\n".join(image_ids_in_config) + "\n")

-        num_samples_map[info.make_config(split=split, fold=str(fold))] = len(image_ids_in_config)
+        num_samples_map[(split, fold)] = len(image_ids_in_config)

     make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz")

-    return num_samples_map[config]
+    return num_samples_map[config["split"], config["fold"]]


-@register_mock
-def fer2013(info, root, config):
-    num_samples = 5 if config.split == "train" else 3
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def fer2013(root, config):
+    split = config["split"]
+    num_samples = 5 if split == "train" else 3

-    path = root / f"{config.split}.csv"
+    path = root / f"{split}.csv"
     with open(path, "w", newline="") as file:
-        field_names = ["emotion"] if config.split == "train" else []
+        field_names = ["emotion"] if split == "train" else []
         field_names.append("pixels")
         file.write(",".join(field_names) + "\n")
@@ -1013,7 +1051,7 @@ def fer2013(info, root, config):
            rowdict = {
                "pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)])
            }
-            if config.split == "train":
+            if split == "train":
                rowdict["emotion"] = int(torch.randint(7, ()))
            writer.writerow(rowdict)
@@ -1022,9 +1060,9 @@ def fer2013(info, root, config):
     return num_samples


-@register_mock
-def gtsrb(info, root, config):
-    num_examples_per_class = 5 if config.split == "train" else 3
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def gtsrb(root, config):
+    num_examples_per_class = 5 if config["split"] == "train" else 3
     classes = ("00000", "00042", "00012")
     num_examples = num_examples_per_class * len(classes)
@@ -1092,8 +1130,8 @@ def gtsrb(info, root, config):
     return num_examples


-@register_mock
-def clevr(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def clevr(root, config):
     data_folder = root / "CLEVR_v1.0"

     num_samples_map = {
@@ -1134,7 +1172,7 @@ def clevr(info, root, config):
     make_zip(root, f"{data_folder.name}.zip", data_folder)

-    return num_samples_map[config.split]
+    return num_samples_map[config["split"]]


 class OxfordIIITPetMockData:
@@ -1198,9 +1236,9 @@ class OxfordIIITPetMockData:
         return num_samples_map


-@register_mock
-def oxford_iiit_pet(info, root, config):
-    return OxfordIIITPetMockData.generate(root)[config.split]
+@register_mock(name="oxford-iiit-pet", configs=combinations_grid(split=("trainval", "test")))
+def oxford_iiit_pet(root, config):
+    return OxfordIIITPetMockData.generate(root)[config["split"]]


 class _CUB200MockData:
@@ -1364,14 +1402,14 @@ class CUB2002010MockData(_CUB200MockData):
         return num_samples_map


-@register_mock
-def cub200(info, root, config):
-    num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
-    return num_samples_map[config.split]
+@register_mock(configs=combinations_grid(split=("train", "test"), year=("2010", "2011")))
+def cub200(root, config):
+    num_samples_map = (CUB2002011MockData if config["year"] == "2011" else CUB2002010MockData).generate(root)
+    return num_samples_map[config["split"]]


-@register_mock
-def eurosat(info, root, config):
+@register_mock(configs=[dict()])
+def eurosat(root, config):
     data_folder = root / "2750"
     data_folder.mkdir(parents=True)
@@ -1388,18 +1426,18 @@ def eurosat(info, root, config):
     return len(categories) * num_examples_per_class


-@register_mock
-def svhn(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "test", "extra")))
+def svhn(root, config):
     import scipy.io as sio

     num_samples = {
         "train": 2,
         "test": 3,
         "extra": 4,
-    }[config.split]
+    }[config["split"]]

     sio.savemat(
-        root / f"{config.split}_32x32.mat",
+        root / f"{config['split']}_32x32.mat",
         {
             "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8),
             "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8),
@@ -1408,13 +1446,13 @@ def svhn(info, root, config):
     return num_samples


-@register_mock
-def pcam(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def pcam(root, config):
     import h5py

-    num_images = {"train": 2, "test": 3, "val": 4}[config.split]
+    num_images = {"train": 2, "test": 3, "val": 4}[config["split"]]

-    split = "valid" if config.split == "val" else config.split
+    split = "valid" if config["split"] == "val" else config["split"]

     images_io = io.BytesIO()
     with h5py.File(images_io, "w") as f:
@@ -1435,18 +1473,19 @@ def pcam(info, root, config):
     return num_images


-@register_mock
-def stanford_cars(info, root, config):
+@register_mock(name="stanford-cars", configs=combinations_grid(split=("train", "test")))
+def stanford_cars(root, config):
     import scipy.io as io
     from numpy.core.records import fromarrays

-    num_samples = {"train": 5, "test": 7}[config["split"]]
+    split = config["split"]
+    num_samples = {"train": 5, "test": 7}[split]
     num_categories = 3

     devkit = root / "devkit"
     devkit.mkdir(parents=True)

-    if config["split"] == "train":
+    if split == "train":
         images_folder_name = "cars_train"
         annotations_mat_path = devkit / "cars_train_annos.mat"
     else:
@@ -1460,7 +1499,7 @@ def stanford_cars(info, root, config):
         num_examples=num_samples,
     )

-    make_tar(root, f"cars_{config.split}.tgz", images_folder_name)
+    make_tar(root, f"cars_{split}.tgz", images_folder_name)
     bbox = np.random.randint(1, 200, num_samples, dtype=np.uint8)
     classes = np.random.randint(1, num_categories + 1, num_samples, dtype=np.uint8)
     fnames = [f"{i:5d}.jpg" for i in range(num_samples)]
@@ -1470,17 +1509,17 @@ def stanford_cars(info, root, config):
     )
     io.savemat(annotations_mat_path, {"annotations": rec_array})

-    if config.split == "train":
+    if split == "train":
         make_tar(root, "car_devkit.tgz", devkit, compression="gz")

     return num_samples


-@register_mock
-def usps(info, root, config):
-    num_samples = {"train": 15, "test": 7}[config.split]
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def usps(root, config):
+    num_samples = {"train": 15, "test": 7}[config["split"]]

-    with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh:
+    with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh:
         lines = []
         for _ in range(num_samples):
             label = make_tensor(1, low=1, high=11, dtype=torch.int)
...
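The mocks above now declare their configurations explicitly via `combinations_grid` instead of deriving them from `DatasetInfo`. For illustration only, `combinations_grid` is assumed to be a small cartesian-product helper along these lines (a sketch, not the actual implementation from `datasets_utils.py`):

```python
import itertools


def combinations_grid(**kwargs):
    # combinations_grid(split=("train", "test"), fold=(1, 4)) ->
    # [{"split": "train", "fold": 1}, {"split": "train", "fold": 4},
    #  {"split": "test", "fold": 1}, {"split": "test", "fold": 4}]
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


print(combinations_grid(split=("train", "test"), fold=(1, 4, 10)))
```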
@@ -7,9 +7,10 @@ import pytest
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
 from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
+from torch.utils.data import DataLoader
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
-from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
+from torchdata.datapipes.iter import Shuffler, ShardingFilter
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -42,14 +43,24 @@ def test_coverage():
 @pytest.mark.filterwarnings("error")
 class TestCommon:
+    @pytest.mark.parametrize("name", datasets.list_datasets())
+    def test_info(self, name):
+        try:
+            info = datasets.info(name)
+        except ValueError:
+            raise AssertionError("No info available.") from None
+
+        if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
+            raise AssertionError("Info should be a dictionary with string keys.")
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_smoke(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

-        if not isinstance(dataset, IterDataPipe):
-            raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
+        if not isinstance(dataset, datasets.utils.Dataset):
+            raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")

     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, test_home, dataset_mock, config):
@@ -76,24 +87,7 @@ class TestCommon:
         dataset = datasets.load(dataset_mock.name, **config)

-        num_samples = 0
-        for _ in dataset:
-            num_samples += 1
-
-        assert num_samples == mock_info["num_samples"]
-
-    @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_decoding(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
-
-        undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
-        if undecoded_features:
-            raise AssertionError(
-                f"The values of key(s) "
-                f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
-            )
+        assert len(list(dataset)) == mock_info["num_samples"]

     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
@@ -116,14 +110,36 @@ class TestCommon:
         next(iter(dataset.map(transforms.Identity())))

+    @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_serializable(self, test_home, dataset_mock, config):
+    def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
         dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        traverse(dataset, only_datapipe=only_datapipe)
+
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_serializable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         pickle.dumps(dataset)

+    @pytest.mark.parametrize("num_workers", [0, 1])
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_data_loader(self, test_home, dataset_mock, config, num_workers):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        dl = DataLoader(
+            dataset,
+            batch_size=2,
+            num_workers=num_workers,
+            collate_fn=lambda batch: batch,
+        )
+
+        next(iter(dl))
+
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
     # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@@ -132,7 +148,6 @@ class TestCommon:
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
         dataset_mock.prepare(test_home, config)
         dataset = datasets.load(dataset_mock.name, **config)
-
         if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
@@ -160,6 +175,13 @@ class TestCommon:
         # resolved
         assert dp.buffer_size == INFINITE_BUFFER_SIZE

+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_has_length(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        assert len(dataset) > 0
+

 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
@@ -186,7 +208,7 @@ class TestGTSRB:
     def test_label_matches_path(self, test_home, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
         # This test makes sure that they're both the same
-        if config.split != "train":
+        if config["split"] != "train":
             return

         dataset_mock.prepare(test_home, config)
...
import unittest.mock
import pytest
from torchvision.prototype import datasets
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
def make_minimal_dataset_info(name="name", categories=None, **kwargs):
return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
class TestFrozenMapping:
@pytest.mark.parametrize(
("args", "kwargs"),
[
pytest.param((dict(foo="bar", baz=1),), dict(), id="from_dict"),
pytest.param((), dict(foo="bar", baz=1), id="from_kwargs"),
pytest.param((dict(foo="bar"),), dict(baz=1), id="mixed"),
],
)
def test_instantiation(self, args, kwargs):
FrozenMapping(*args, **kwargs)
def test_unhashable_items(self):
with pytest.raises(TypeError, match="unhashable type"):
FrozenMapping(foo=[])
def test_getitem(self):
options = dict(foo="bar", baz=1)
config = FrozenMapping(options)
for key, value in options.items():
assert config[key] == value
def test_getitem_unknown(self):
with pytest.raises(KeyError):
FrozenMapping()["unknown"]
def test_iter(self):
options = dict(foo="bar", baz=1)
assert set(iter(FrozenMapping(options))) == set(options.keys())
def test_len(self):
options = dict(foo="bar", baz=1)
assert len(FrozenMapping(options)) == len(options)
def test_immutable_setitem(self):
frozen_mapping = FrozenMapping()
with pytest.raises(RuntimeError, match="immutable"):
frozen_mapping["foo"] = "bar"
def test_immutable_delitem(
self,
):
frozen_mapping = FrozenMapping(foo="bar")
with pytest.raises(RuntimeError, match="immutable"):
del frozen_mapping["foo"]
def test_eq(self):
options = dict(foo="bar", baz=1)
assert FrozenMapping(options) == FrozenMapping(options)
def test_ne(self):
options1 = dict(foo="bar", baz=1)
options2 = options1.copy()
options2["baz"] += 1
assert FrozenMapping(options1) != FrozenMapping(options2)
def test_repr(self):
options = dict(foo="bar", baz=1)
output = repr(FrozenMapping(options))
assert isinstance(output, str)
for key, value in options.items():
assert str(key) in output and str(value) in output
class TestFrozenBunch:
def test_getattr(self):
options = dict(foo="bar", baz=1)
config = FrozenBunch(options)
for key, value in options.items():
assert getattr(config, key) == value
def test_getattr_unknown(self):
with pytest.raises(AttributeError, match="no attribute 'unknown'"):
datasets.utils.DatasetConfig().unknown
def test_immutable_setattr(self):
frozen_bunch = FrozenBunch()
with pytest.raises(RuntimeError, match="immutable"):
frozen_bunch.foo = "bar"
def test_immutable_delattr(
self,
):
frozen_bunch = FrozenBunch(foo="bar")
with pytest.raises(RuntimeError, match="immutable"):
del frozen_bunch.foo
def test_repr(self):
options = dict(foo="bar", baz=1)
output = repr(FrozenBunch(options))
assert isinstance(output, str)
assert output.startswith("FrozenBunch")
for key, value in options.items():
assert f"{key}={value}" in output
class TestDatasetInfo:
@pytest.fixture
def info(self):
return make_minimal_dataset_info(valid_options=dict(split=("train", "test"), foo=("bar", "baz")))
def test_default_config(self, info):
valid_options = info._valid_options
default_config = datasets.utils.DatasetConfig({key: values[0] for key, values in valid_options.items()})
assert info.default_config == default_config
@pytest.mark.parametrize(
("valid_options", "options", "expected_error_msg"),
[
(dict(), dict(any_option=None), "does not take any options"),
(dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"),
(dict(split="train"), dict(split="invalid_argument"), "Invalid argument 'invalid_argument'"),
],
)
def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg):
info = make_minimal_dataset_info(valid_options=valid_options)
with pytest.raises(ValueError, match=expected_error_msg):
info.make_config(**options)
def test_check_dependencies(self):
dependency = "fake_dependency"
info = make_minimal_dataset_info(dependencies=(dependency,))
with pytest.raises(ModuleNotFoundError, match=dependency):
info.check_dependencies()
def test_repr(self, info):
output = repr(info)
assert isinstance(output, str)
assert "DatasetInfo" in output
for key, value in info._valid_options.items():
assert f"{key}={str(value)[1:-1]}" in output
@pytest.mark.parametrize("optional_info", ("citation", "homepage", "license"))
def test_repr_optional_info(self, optional_info):
sentinel = "sentinel"
info = make_minimal_dataset_info(**{optional_info: sentinel})
assert f"{optional_info}={sentinel}" in repr(info)
class TestDataset:
class DatasetMock(datasets.utils.Dataset):
def __init__(self, info=None, *, resources=None):
self._info = info or make_minimal_dataset_info(valid_options=dict(split=("train", "test")))
self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources
self._make_datapipe = unittest.mock.Mock()
super().__init__()
def _make_info(self):
return self._info
def resources(self, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass
def _make_datapipe(self, resource_dps, *, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass
def test_name(self):
name = "sentinel"
dataset = self.DatasetMock(make_minimal_dataset_info(name=name))
assert dataset.name == name
def test_default_config(self):
sentinel = "sentinel"
dataset = self.DatasetMock(info=make_minimal_dataset_info(valid_options=dict(split=(sentinel, "train"))))
assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel)
@pytest.mark.parametrize(
("config", "kwarg"),
[
pytest.param(*(datasets.utils.DatasetConfig(split="test"),) * 2, id="specific"),
pytest.param(DatasetMock().default_config, None, id="default"),
],
)
def test_load_config(self, config, kwarg):
dataset = self.DatasetMock()
dataset.load("", config=kwarg)
dataset.resources.assert_called_with(config)
_, call_kwargs = dataset._make_datapipe.call_args
assert call_kwargs["config"] == config
def test_missing_dependencies(self):
dependency = "fake_dependency"
dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,)))
with pytest.raises(ModuleNotFoundError, match=dependency):
dataset.load("root")
def test_resources(self, mocker):
resource_mock = mocker.Mock(spec=["load"])
sentinel = object()
resource_mock.load.return_value = sentinel
dataset = self.DatasetMock(resources=[resource_mock])
root = "root"
dataset.load(root)
(call_args, _) = resource_mock.load.call_args
assert call_args[0] == root
(call_args, _) = dataset._make_datapipe.call_args
assert call_args[0][0] is sentinel
@@ -5,7 +5,7 @@ import pytest
 import torch
 from datasets_utils import make_fake_flo_file
 from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
-from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
 from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
@@ -101,3 +101,21 @@ class TestHttpResource:
         assert redirected_resource.file_name == file_name
         assert redirected_resource.sha256 == sha256_sentinel
         assert redirected_resource._preprocess is preprocess_sentinel
+
+
+def test_missing_dependency_error():
+    class DummyDataset(Dataset):
+        def __init__(self):
+            super().__init__(root="root", dependencies=("fake_dependency",))
+
+        def _resources(self):
+            pass
+
+        def _datapipe(self, resource_dps):
+            pass
+
+        def __len__(self):
+            pass
+
+    with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
+        DummyDataset()
@@ -10,5 +10,6 @@ from . import utils
 from ._home import home

 # Load this last, since some parts depend on the above being loaded first
-from ._api import list_datasets, info, load  # usort: skip
+from ._api import list_datasets, info, load, register_info, register_dataset  # usort: skip
 from ._folder import from_data_folder, from_image_folder
+from ._builtin import *
-import os
-from typing import Any, Dict, List
+import pathlib
+from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar

-from torch.utils.data import IterDataPipe
 from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
+from torchvision.prototype.datasets.utils import Dataset
 from torchvision.prototype.utils._internal import add_suggestion

-from . import _builtin
+T = TypeVar("T")
+D = TypeVar("D", bound=Type[Dataset])

-DATASETS: Dict[str, Dataset] = {}
+BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}


-def register(dataset: Dataset) -> None:
-    DATASETS[dataset.name] = dataset
+def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]:
+    def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
+        BUILTIN_INFOS[name] = fn()
+        return fn
+
+    return wrapper


-for name, obj in _builtin.__dict__.items():
-    if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
-        register(obj())
+BUILTIN_DATASETS = {}
+
+
+def register_dataset(name: str) -> Callable[[D], D]:
+    def wrapper(dataset_cls: D) -> D:
+        BUILTIN_DATASETS[name] = dataset_cls
+        return dataset_cls
+
+    return wrapper


 def list_datasets() -> List[str]:
-    return sorted(DATASETS.keys())
+    return sorted(BUILTIN_DATASETS.keys())


-def find(name: str) -> Dataset:
+def find(dct: Dict[str, T], name: str) -> T:
     name = name.lower()
     try:
-        return DATASETS[name]
+        return dct[name]
     except KeyError as error:
         raise ValueError(
             add_suggestion(
                 f"Unknown dataset '{name}'.",
                 word=name,
-                possibilities=DATASETS.keys(),
+                possibilities=dct.keys(),
                 alternative_hint=lambda _: (
                     "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
                 ),
@@ -41,19 +52,14 @@ def find(name: str) -> Dataset:
         ) from error


-def info(name: str) -> DatasetInfo:
-    return find(name).info
+def info(name: str) -> Dict[str, Any]:
+    return find(BUILTIN_INFOS, name)


-def load(
-    name: str,
-    *,
-    skip_integrity_check: bool = False,
-    **options: Any,
-) -> IterDataPipe[Dict[str, Any]]:
-    dataset = find(name)
-
-    config = dataset.info.make_config(**options)
-    root = os.path.join(home(), dataset.name)
-
-    return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
+def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
+    dataset_cls = find(BUILTIN_DATASETS, name)
+
+    if root is None:
+        root = pathlib.Path(home()) / name
+
+    return dataset_cls(root, **config)
...@@ -12,51 +12,66 @@ Finally, `from torchvision.prototype import datasets` is implied below. ...@@ -12,51 +12,66 @@ Finally, `from torchvision.prototype import datasets` is implied below.
Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin` Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin`
that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that
module create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum three methods that will be module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in
discussed in detail below: detail below:
```python ```python
from typing import Any, Dict, List import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource from torchvision.prototype.datasets.utils import Dataset, OnlineResource
from .._api import register_dataset, register_info
NAME = "my-dataset"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
...
)
@register_dataset(NAME)
class MyDataset(Dataset): class MyDataset(Dataset):
def _make_info(self) -> DatasetInfo: def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None:
... ...
super().__init__(root, skip_integrity_check=skip_integrity_check)
def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _resources(self) -> List[OnlineResource]:
... ...
def _make_datapipe( def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]:
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, ...
) -> IterDataPipe[Dict[str, Any]]:
def __len__(self) -> int:
... ...
``` ```
### `_make_info(self)` In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a
dictionary of static information. The most common use case is to provide human-readable categories.
[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories.
The `DatasetInfo` carries static information about the dataset. There are two required fields: Finally, both the dataset class and the info function need to be registered on the API with the respective decorators.
With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively.
- `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain ### `__init__(self, root, *, ..., skip_integrity_check = False)`
lowercase characters.
Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the
base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as
setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke
the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with
an underscore.

If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base
class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically
checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to
avoid missing dependencies at import time.
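For illustration, a constructor with a hypothetical `split` parameter and a `scipy` dependency could look like the
sketch below. `_verify_str_arg` is the helper the builtin datasets use to validate string parameters:

```py
def __init__(
    self,
    root: Union[str, pathlib.Path],
    *,
    split: str = "train",
    skip_integrity_check: bool = False,
) -> None:
    # check and store all parameters *before* calling super().__init__(),
    # since the base class constructor already invokes the other methods
    self._split = self._verify_str_arg(split, "split", ("train", "test"))
    self._categories = _info()["categories"]

    super().__init__(
        root,
        # only needed if the implementation imports third-party packages
        dependencies=("scipy",),
        skip_integrity_check=skip_integrity_check,
    )
```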
### `_resources(self)`

Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can
be built. The download will happen automatically.

Currently, the following `OnlineResource`'s are supported:
...@@ -81,7 +96,7 @@ def sha256sum(path, chunk_size=1024 * 1024):
    print(checksum.hexdigest())
```
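As a concrete sketch, a dataset that is distributed as a single archive over HTTP could declare its resource like this;
the URL and checksum are placeholders:

```py
def _resources(self) -> List[OnlineResource]:
    archive = HttpResource(
        # placeholder URL; use the official download link of the dataset
        "https://example.com/my-dataset.tar.gz",
        # placeholder checksum; compute it with the sha256sum() snippet above
        sha256="<sha256 checksum of the archive>",
    )
    return [archive]
```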
### `_datapipe(self, resource_dps)`

This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared
to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone
...@@ -99,60 +114,112 @@ All of them can be imported `from torchdata.datapipes.iter`. In addition, use `f
needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.
`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one
of such tuples for the file specified by the resource.

Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and
`Grouper`. There are two issues with that:

1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.
Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
trying to zip already loaded images.
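For example, pairing images with their annotation files is typically done on the path level, before anything is
decoded. A sketch of such a fragment inside `_datapipe()`; the key functions depend on the actual file layout of the
dataset and are assumptions here:

```py
from torchdata.datapipes.iter import IterKeyZipper
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, path_accessor

# zip two datapipes of (path, file handle) tuples by file name, i.e. while they
# still only carry cheap file handles rather than decoded images
dp = IterKeyZipper(
    images_dp,
    anns_dp,
    key_fn=path_accessor("stem"),
    ref_key_fn=path_accessor("stem"),
    buffer_size=INFINITE_BUFFER_SIZE,
)
```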
There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
`hint_sharding`. As the name implies, they only hint at a location in the datapipe graph where shuffling and sharding
should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal`
and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.
Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
names (yet!).
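Putting these pieces together, a minimal `_datapipe()` for a single-archive dataset might look like the sketch below,
with `Filter` and `Mapper` coming from `torchdata.datapipes.iter` and the hint functions from
`torchvision.prototype.datasets.utils._internal`; `_filter_split` and `_prepare_sample` are hypothetical helpers of the
dataset:

```py
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
    # one datapipe per resource returned by _resources(); here a single archive
    dp = resource_dps[0]
    # keep only the files that belong to the requested split
    dp = Filter(dp, self._filter_split)
    # required in every dataset, in exactly this order
    dp = hint_shuffling(dp)
    dp = hint_sharding(dp)
    # turn every (path, file handle) tuple into the final sample dictionary
    return Mapper(dp, self._prepare_sample)
```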
### `__len__`
This returns an integer denoting the number of samples that can be drawn from the dataset. Please use
[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance the
readability. For example, `1_281_167` vs. `1281167`.
If there are only two different numbers, a simple `if` / `else` is fine:
```py
def __len__(self):
return 12_345 if self._split == "train" else 6_789
```
If there are more options, using a dictionary usually is the most readable option:
```py
def __len__(self):
return {
"train": 3,
"val": 2,
"test": 1,
}[self._split]
```
If the number of samples depends on more than one parameter, you can use tuples as dictionary keys:
```py
def __len__(self):
return {
("train", "bar"): 4,
("train", "baz"): 3,
("test", "bar"): 2,
("test", "baz"): 1,
}[(self._split, self._foo)]
```
The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the
development process. Since it is an `@abstractmethod` you still have to implement it from the start. The canonical way
is to define a dummy method like
```py
def __len__(self):
return 1
```
and only fill it with the correct data if the implementation is otherwise finished.
[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples.
## Tests

To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data.
This mock-up should resemble the original data as closely as necessary, while containing only a few examples.
To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the
same name as you have used in `@register_info` and `@register_dataset`. This function is called the "mock data function".
Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset
will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options,
you can use the `combinations_grid()` helper function, e.g.
`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`.

In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass
the `name` parameter to `@register_mock`:
```py
# this is defined in torchvision/prototype/datasets/_builtin
@register_dataset("my-dataset")
class MyDataset(Dataset):
    ...


@register_mock(name="my-dataset", configs=...)
def my_dataset(root, config):
    ...
```
The mock data function receives two arguments:

- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data
  needs to be placed.
- `config`: The configuration to generate the data for. This is one of the dictionaries defined in
  `@register_mock(configs=...)`.
The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if
the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of
the current `config`. Although this seems odd at first, this is important. Consider the following original data setup:

```
root
...@@ -167,9 +234,8 @@ root
For map-style datasets (like the ones currently in `torchvision.datasets`), one explicitly selects the files they want
to load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is
present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in
`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we only generated the data for the
current `config`, the test would also pass if the dataset is missing the filtering, but it would fail on the real data.
For datasets that are ported from the old API, we already have some mock data in
[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the corresponding test case there
...@@ -178,8 +244,6 @@ and have a look at the `inject_fake_data` function. There are a few differences
- `tmp_dir` corresponds to `root`, but is a `str` rather than a
  [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like
  `folder = pathlib.Path(tmp_dir)`. This is not needed.
- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the
  new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files
  specified in the dataset, as sketched below.
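Putting it together, a mock data function for a hypothetical image-folder dataset packed into a single archive could
look roughly like this. `make_tar` is one of the helpers mentioned above (check its exact signature in the test
utilities before copying this), and the function is assumed to return the number of samples for the given `config`, as
the existing mock data functions do:

```py
import pathlib

import PIL.Image


@register_mock(name="my-dataset", configs=combinations_grid(split=("train", "test")))
def my_dataset(root, config):
    num_samples = {"train": 3, "test": 2}

    # generate a tiny version of *every* split, not just config["split"]
    for split, num in num_samples.items():
        split_folder = pathlib.Path(root) / "images" / split
        split_folder.mkdir(parents=True)
        for idx in range(num):
            PIL.Image.new("RGB", (32, 32)).save(split_folder / f"{idx}.jpg")

    # pack the folder into the archive layout the dataset expects
    make_tar(root, "images.tar", root / "images")

    return num_samples[config["split"]]
```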
...@@ -196,9 +260,9 @@ Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets

### How do I start?
Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do
`return resources_dp[0]` to get started. Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be
instantiable via `datasets.load("mydataset")`. In a separate script, try something like
```py
...@@ -206,7 +270,7 @@ from torchvision.prototype import datasets
dataset = datasets.load("mydataset")
for sample in dataset:
    print(sample)  # this is the content of an item in datapipe returned by _datapipe()
    break
# Or you can also inspect the sample in a debugger
```
...@@ -217,15 +281,24 @@ datapipes and return the appropriate dictionary format.

### How do I handle a dataset that defines many categories?
As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more
categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a
category. To load such a file, use the `read_categories_file` function from
`torchvision.prototype.datasets.utils._internal` and pass it `$NAME`.
In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where
each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method
should return a sequence of strings representing the category names. In the method body, you'll have to manually load
the resources, e.g.
```py
resources = self._resources()
dp = resources[0].load(self._root)
```
Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes
sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that.
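A typical implementation for an image-folder style dataset therefore boils down to a few lines; this mirrors the
builtin datasets, e.g. `caltech256` further down in this diff:

```py
def _generate_categories(self) -> List[str]:
    resources = self._resources()

    # each item is a (path, file handle) tuple; the parent folder name is the category
    dp = resources[0].load(self._root)
    return sorted({pathlib.Path(path).parent.name for path, _ in dp})
```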
To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`.

### What if a resource file forms an I/O bottleneck?

...@@ -235,3 +308,33 @@ the performance hit becomes significant, the archives can still be preprocessed.
`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be
preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also
accepts `"decompress"` and `"extract"` to handle these common scenarios.
### How do I compute the number of samples?
Unless the authors of the dataset published the exact numbers (and even in this case we should verify them), there is no
other way than to iterate over the dataset and count the number of samples:
```py
import itertools
from torchvision.prototype import datasets
def combinations_grid(**kwargs):
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there
configs = combinations_grid(split=("train", "test"), foo=("bar", "baz"))
for config in configs:
dataset = datasets.load("my-dataset", **config)
num_samples = 0
for _ in dataset:
num_samples += 1
print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples)
```
To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation
files.
...@@ -12,7 +12,7 @@ from .food101 import Food101
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
...
import pathlib
import re
from typing import Any, Dict, List, Tuple, BinaryIO, Union

import numpy as np

from torchdata.datapipes.iter import (
...@@ -9,26 +9,46 @@ from torchdata.datapipes.iter import (
    Filter,
    IterKeyZipper,
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    read_mat,
    hint_sharding,
    hint_shuffling,
    read_categories_file,
)
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage

from .._api import register_dataset, register_info


@register_info("caltech101")
def _caltech101_info() -> Dict[str, Any]:
    return dict(categories=read_categories_file("caltech101"))


@register_dataset("caltech101")
class Caltech101(Dataset):
    """
    - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
    - **dependencies**:
        - <scipy `https://scipy.org/`>_
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        skip_integrity_check: bool = False,
    ) -> None:
        self._categories = _caltech101_info()["categories"]

        super().__init__(
            root,
            dependencies=("scipy",),
            skip_integrity_check=skip_integrity_check,
        )

    def _resources(self) -> List[OnlineResource]:
        images = HttpResource(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
            sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
...@@ -88,7 +108,7 @@ class Caltech101(Dataset):
        ann = read_mat(ann_buffer)

        return dict(
            label=Label.from_category(category, categories=self._categories),
            image_path=image_path,
            image=image,
            ann_path=ann_path,
...@@ -98,12 +118,7 @@ class Caltech101(Dataset):
            contour=_Feature(ann["obj_contour"].T),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        images_dp, anns_dp = resource_dps

        images_dp = Filter(images_dp, self._is_not_background_image)
...@@ -122,23 +137,39 @@ class Caltech101(Dataset):
        )
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 8677

    def _generate_categories(self) -> List[str]:
        resources = self._resources()

        dp = resources[0].load(self._root)
        dp = Filter(dp, self._is_not_background_image)

        return sorted({pathlib.Path(path).parent.name for path, _ in dp})


@register_info("caltech256")
def _caltech256_info() -> Dict[str, Any]:
    return dict(categories=read_categories_file("caltech256"))


@register_dataset("caltech256")
class Caltech256(Dataset):
    """
    - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        skip_integrity_check: bool = False,
    ) -> None:
        self._categories = _caltech256_info()["categories"]

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
...@@ -156,25 +187,23 @@ class Caltech256(Dataset):
        return dict(
            path=path,
            image=EncodedImage.from_file(buffer),
            label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = Filter(dp, self._is_not_rogue_file)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 30607

    def _generate_categories(self) -> List[str]:
        resources = self._resources()

        dp = resources[0].load(self._root)
        dir_names = {pathlib.Path(path).parent.name for path, _ in dp}

        return [name.split(".")[1] for name in sorted(dir_names)]
import csv
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union

from torchdata.datapipes.iter import (
    IterDataPipe,
...@@ -11,8 +11,6 @@ from torchdata.datapipes.iter import (
)
from torchvision.prototype.datasets.utils import (
    Dataset,
    GDriveResource,
    OnlineResource,
)
...@@ -25,6 +23,7 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox

from .._api import register_dataset, register_info

csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
...@@ -60,15 +59,32 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
                yield line.pop("image_id"), line


NAME = "celeba"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict()


@register_dataset(NAME)
class CelebA(Dataset):
    """
    - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        splits = GDriveResource(
            "0B7EVK8r0v71pY0NSMzRuSXJEVkk",
            sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
...@@ -101,14 +117,13 @@ class CelebA(Dataset):
        )
        return [splits, images, identities, attributes, bounding_boxes, landmarks]

    def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool:
        split_id = {
            "train": "0",
            "val": "1",
            "test": "2",
        }[self._split]
        return data[1]["split_id"] == split_id

    def _prepare_sample(
        self,
...@@ -145,16 +160,11 @@ class CelebA(Dataset):
            },
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps

        splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
        splits_dp = Filter(splits_dp, self._filter_split)
        splits_dp = hint_shuffling(splits_dp)
        splits_dp = hint_sharding(splits_dp)
...@@ -186,3 +196,10 @@ class CelebA(Dataset):
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return {
            "train": 162_770,
            "val": 19_867,
            "test": 19_962,
        }[self._split]
import abc
import io
import pathlib
import pickle
from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, Union

import numpy as np

from torchdata.datapipes.iter import (
...@@ -11,20 +10,17 @@ from torchdata.datapipes.iter import (
    Filter,
    Mapper,
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_shuffling,
    path_comparator,
    hint_sharding,
    read_categories_file,
)
from torchvision.prototype.features import Label, Image

from .._api import register_dataset, register_info


class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
...@@ -44,19 +40,23 @@ class _CifarBase(Dataset):
    _LABELS_KEY: str
    _META_FILE_NAME: str
    _CATEGORIES_KEY: str
    _categories: List[str]

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "test"))
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    @abc.abstractmethod
    def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
        pass

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
...@@ -72,52 +72,72 @@ class _CifarBase(Dataset):
        image_array, category_idx = data
        return dict(
            image=Image(image_array),
            label=Label(category_idx, categories=self._categories),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = Filter(dp, self._is_data_file)
        dp = Mapper(dp, self._unpickle)
        dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 50_000 if self._split == "train" else 10_000

    def _generate_categories(self) -> List[str]:
        resources = self._resources()

        dp = resources[0].load(self._root)
        dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
        dp = Mapper(dp, self._unpickle)

        return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])


@register_info("cifar10")
def _cifar10_info() -> Dict[str, Any]:
    return dict(categories=read_categories_file("cifar10"))


@register_dataset("cifar10")
class Cifar10(_CifarBase):
    """
    - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
    """

    _FILE_NAME = "cifar-10-python.tar.gz"
    _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
    _LABELS_KEY = "labels"
    _META_FILE_NAME = "batches.meta"
    _CATEGORIES_KEY = "label_names"
    _categories = _cifar10_info()["categories"]

    def _is_data_file(self, data: Tuple[str, Any]) -> bool:
        path = pathlib.Path(data[0])
        return path.name.startswith("data" if self._split == "train" else "test")


@register_info("cifar100")
def _cifar100_info() -> Dict[str, Any]:
    return dict(categories=read_categories_file("cifar100"))


@register_dataset("cifar100")
class Cifar100(_CifarBase):
    """
    - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
    """

    _FILE_NAME = "cifar-100-python.tar.gz"
    _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
    _LABELS_KEY = "fine_labels"
    _META_FILE_NAME = "meta"
    _CATEGORIES_KEY = "fine_label_names"
    _categories = _cifar100_info()["categories"]

    def _is_data_file(self, data: Tuple[str, Any]) -> bool:
        path = pathlib.Path(data[0])
        return path.name == self._split
import pathlib
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    hint_sharding,
...@@ -19,16 +13,30 @@ from torchvision.prototype.datasets.utils._internal import (
)
from torchvision.prototype.features import Label, EncodedImage

from .._api import register_dataset, register_info

NAME = "clevr"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict()


@register_dataset(NAME)
class CLEVR(Dataset):
    """
    - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
    """

    def __init__(
        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        archive = HttpResource(
            "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
            sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
...@@ -61,12 +69,7 @@ class CLEVR(Dataset):
            label=Label(len(scenes_data["objects"])) if scenes_data else None,
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        images_dp, scenes_dp = Demultiplexer(
            archive_dp,
...@@ -76,12 +79,12 @@ class CLEVR(Dataset):
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
        images_dp = hint_shuffling(images_dp)
        images_dp = hint_sharding(images_dp)

        if self._split != "test":
            scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
            scenes_dp = JsonParser(scenes_dp)
            scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
            scenes_dp = UnBatcher(scenes_dp)
...@@ -97,3 +100,6 @@ class CLEVR(Dataset):
            dp = Mapper(images_dp, self._add_empty_anns)

        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 70_000 if self._split == "train" else 15_000
import functools
import pathlib import pathlib
import re import re
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union
import torch import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
...@@ -16,43 +16,65 @@ from torchdata.datapipes.iter import ( ...@@ -16,43 +16,65 @@ from torchdata.datapipes.iter import (
UnBatcher, UnBatcher,
) )
from torchvision.prototype.datasets.utils import ( from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource, HttpResource,
OnlineResource, OnlineResource,
Dataset,
) )
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
MappingIterator, MappingIterator,
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
getitem, getitem,
read_categories_file,
path_accessor, path_accessor,
hint_sharding, hint_sharding,
hint_shuffling, hint_shuffling,
) )
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
from torchvision.prototype.utils._internal import FrozenMapping
from .._api import register_dataset, register_info
NAME = "coco"
@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, super_categories = zip(*read_categories_file(NAME))
return dict(categories=categories, super_categories=super_categories)
@register_dataset(NAME)
class Coco(Dataset): class Coco(Dataset):
def _make_info(self) -> DatasetInfo: """
name = "coco" - **homepage**: https://cocodataset.org/
categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) - **dependencies**:
- <pycocotools `https://github.com/cocodataset/cocoapi`>_
return DatasetInfo( """
name,
dependencies=("pycocotools",), def __init__(
categories=categories, self,
homepage="https://cocodataset.org/", root: Union[str, pathlib.Path],
valid_options=dict( *,
split=("train", "val"), split: str = "train",
year=("2017", "2014"), year: str = "2017",
annotations=(*self._ANN_DECODERS.keys(), None), annotations: Optional[str] = "instances",
), skip_integrity_check: bool = False,
extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))), ) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val"})
self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
self._annotations = (
self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
if annotations is not None
else None
) )
info = _info()
categories, super_categories = info["categories"], info["super_categories"]
self._categories = categories
self._category_to_super_category = dict(zip(categories, super_categories))
super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)
_IMAGE_URL_BASE = "http://images.cocodataset.org/zips" _IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
_IMAGES_CHECKSUMS = { _IMAGES_CHECKSUMS = {
...@@ -69,14 +91,14 @@ class Coco(Dataset): ...@@ -69,14 +91,14 @@ class Coco(Dataset):
"2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268", "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
} }
def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _resources(self) -> List[OnlineResource]:
images = HttpResource( images = HttpResource(
f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip", f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)], sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
) )
meta = HttpResource( meta = HttpResource(
f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip", f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
sha256=self._META_CHECKSUMS[config.year], sha256=self._META_CHECKSUMS[self._year],
) )
return [images, meta] return [images, meta]
...@@ -110,10 +132,8 @@ class Coco(Dataset): ...@@ -110,10 +132,8 @@ class Coco(Dataset):
format="xywh", format="xywh",
image_size=image_size, image_size=image_size,
), ),
labels=Label(labels, categories=self.categories), labels=Label(labels, categories=self._categories),
super_categories=[ super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
],
ann_ids=[ann["id"] for ann in anns], ann_ids=[ann["id"] for ann in anns],
) )
...@@ -134,9 +154,14 @@ class Coco(Dataset): ...@@ -134,9 +154,14 @@ class Coco(Dataset):
fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json" fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
) )
def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool: def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name) match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations) return bool(
match
and match["split"] == self._split
and match["year"] == self._year
and match["annotations"] == self._annotations
)
def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
key, _ = data key, _ = data
...@@ -157,38 +182,26 @@ class Coco(Dataset): ...@@ -157,38 +182,26 @@ class Coco(Dataset):
def _prepare_sample( def _prepare_sample(
self, self,
data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]], data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
*,
annotations: str,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
ann_data, image_data = data ann_data, image_data = data
anns, image_meta = ann_data anns, image_meta = ann_data
sample = self._prepare_image(image_data) sample = self._prepare_image(image_data)
# this method is only called if we have annotations
annotations = cast(str, self._annotations)
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
return sample return sample
def _make_datapipe( def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
images_dp, meta_dp = resource_dps images_dp, meta_dp = resource_dps
if config.annotations is None: if self._annotations is None:
dp = hint_shuffling(images_dp) dp = hint_shuffling(images_dp)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._prepare_image) return Mapper(dp, self._prepare_image)
meta_dp = Filter( meta_dp = Filter(meta_dp, self._filter_meta_files)
meta_dp,
functools.partial(
self._filter_meta_files,
split=config.split,
year=config.year,
annotations=config.annotations,
),
)
meta_dp = JsonParser(meta_dp) meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1)) meta_dp = Mapper(meta_dp, getitem(1))
meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp) meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
...@@ -216,7 +229,6 @@ class Coco(Dataset): ...@@ -216,7 +229,6 @@ class Coco(Dataset):
ref_key_fn=getitem("id"), ref_key_fn=getitem("id"),
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
dp = IterKeyZipper( dp = IterKeyZipper(
anns_dp, anns_dp,
images_dp, images_dp,
...@@ -224,18 +236,24 @@ class Coco(Dataset): ...@@ -224,18 +236,24 @@ class Coco(Dataset):
ref_key_fn=path_accessor("name"), ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
}[(self._split, self._year)][
self._annotations # type: ignore[index]
]
return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations)) def _generate_categories(self) -> Tuple[Tuple[str, str]]:
self._annotations = "instances"
def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: resources = self._resources()
config = self.default_config
resources = self.resources(config)
dp = resources[1].load(root) dp = resources[1].load(self._root)
dp = Filter( dp = Filter(dp, self._filter_meta_files)
dp,
functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
)
dp = JsonParser(dp) dp = JsonParser(dp)
_, meta = next(iter(dp)) _, meta = next(iter(dp))
......
import pathlib
from typing import Any, Dict, List, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    path_comparator,
    hint_sharding,
    hint_shuffling,
    read_categories_file,
)
from torchvision.prototype.features import EncodedImage, Label

from .._api import register_dataset, register_info

NAME = "country211"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class Country211(Dataset):
    """
    - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
        self._split_folder_name = "valid" if split == "val" else split

        self._categories = _info()["categories"]

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                "https://openaipublic.azureedge.net/clip/data/country211.tgz",
...@@ -23,17 +49,11 @@ class Country211(Dataset):
            )
        ]

    def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
        path, buffer = data
        category = pathlib.Path(path).parent.name
        return dict(
            label=Label.from_category(category, categories=self._categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )
...@@ -41,16 +61,21 @@ class Country211(Dataset):
    def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
        return pathlib.Path(data[0]).parent.parent.name == split

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name))
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return {
            "train": 31_650,
            "val": 10_550,
            "test": 21_100,
        }[self._split]

    def _generate_categories(self) -> List[str]:
        resources = self._resources()
        dp = resources[0].load(self._root)
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})
torchvision/prototype/datasets/_builtin/cub200.py
 import csv
 import functools
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union

 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -15,8 +15,6 @@ from torchdata.datapipes.iter import (
 )
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetConfig,
-    DatasetInfo,
     HttpResource,
     OnlineResource,
 )
@@ -27,27 +25,52 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_shuffling,
     getitem,
     path_comparator,
+    read_categories_file,
     path_accessor,
 )
 from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage

+from .._api import register_dataset, register_info
+
 csv.register_dialect("cub200", delimiter=" ")

+NAME = "cub200"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=read_categories_file(NAME))
+
+
+@register_dataset(NAME)
 class CUB200(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "cub200",
-            homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html",
-            dependencies=("scipy",),
-            valid_options=dict(
-                split=("train", "test"),
-                year=("2011", "2010"),
-            ),
+    """
+    - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        year: str = "2011",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test"))
+        self._year = self._verify_str_arg(year, "year", ("2010", "2011"))
+        self._categories = _info()["categories"]
+
+        super().__init__(
+            root,
+            # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
+            # dependencies=("scipy",),
+            skip_integrity_check=skip_integrity_check,
         )

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        if config.year == "2011":
+    def _resources(self) -> List[OnlineResource]:
+        if self._year == "2011":
             archive = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
                 sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
@@ -59,7 +82,7 @@ class CUB200(Dataset):
                 preprocess="decompress",
             )
             return [archive, segmentations]
-        else:  # config.year == "2010"
+        else:  # self._year == "2010"
             split = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz",
                 sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
@@ -90,12 +113,12 @@ class CUB200(Dataset):
         else:
             return None

-    def _2011_filter_split(self, row: List[str], *, split: str) -> bool:
+    def _2011_filter_split(self, row: List[str]) -> bool:
         _, split_id = row
         return {
             "0": "test",
             "1": "train",
-        }[split_id] == split
+        }[split_id] == self._split

     def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
         path = pathlib.Path(data[0])
@@ -149,17 +172,12 @@ class CUB200(Dataset):
         return dict(
             prepare_ann_fn(anns_data, image.image_size),
             image=image,
-            label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories),
+            label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories),
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         prepare_ann_fn: Callable
-        if config.year == "2011":
+        if self._year == "2011":
             archive_dp, segmentations_dp = resource_dps
             images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
                 archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
@@ -171,7 +189,7 @@ class CUB200(Dataset):
             )

             split_dp = CSVParser(split_dp, dialect="cub200")
-            split_dp = Filter(split_dp, functools.partial(self._2011_filter_split, split=config.split))
+            split_dp = Filter(split_dp, self._2011_filter_split)
             split_dp = Mapper(split_dp, getitem(0))
             split_dp = Mapper(split_dp, image_files_map.get)
@@ -188,10 +206,10 @@ class CUB200(Dataset):
             )

             prepare_ann_fn = self._2011_prepare_ann
-        else:  # config.year == "2010"
+        else:  # self._year == "2010"
             split_dp, images_dp, anns_dp = resource_dps

-            split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+            split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
             split_dp = LineReader(split_dp, decode=True, return_path=False)
             split_dp = Mapper(split_dp, self._2010_split_key)
@@ -217,11 +235,19 @@ class CUB200(Dataset):
         )
         return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))

-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        config = self.info.make_config(year="2011")
-        resources = self.resources(config)
+    def __len__(self) -> int:
+        return {
+            ("train", "2010"): 3_000,
+            ("test", "2010"): 3_033,
+            ("train", "2011"): 5_994,
+            ("test", "2011"): 5_794,
+        }[(self._split, self._year)]
+
+    def _generate_categories(self) -> List[str]:
+        self._year = "2011"
+        resources = self._resources()

-        dp = resources[0].load(root)
+        dp = resources[0].load(self._root)
         dp = Filter(dp, path_comparator("name", "classes.txt"))
         dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200")
...
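The `Filter(split_dp, self._2011_filter_split)` call above works because the filter is now a bound method that reads `self._split`, so the old `functools.partial(..., split=config.split)` wrapper is no longer needed. A self-contained sketch of that pattern, using a toy class and made-up rows rather than the real CUB200 pipeline:

# Toy illustration of passing a bound method to Filter; all names here are hypothetical.
from torchdata.datapipes.iter import Filter, IterableWrapper

class SplitFilter:
    def __init__(self, split: str) -> None:
        self._split = split

    def keep(self, row) -> bool:
        # Mirrors _2011_filter_split: map the split id to a split name and compare.
        _, split_id = row
        return {"0": "test", "1": "train"}[split_id] == self._split

rows = [("0001.jpg", "1"), ("0002.jpg", "0"), ("0003.jpg", "1")]
dp = Filter(IterableWrapper(rows), SplitFilter("train").keep)
print(list(dp))  # [('0001.jpg', '1'), ('0003.jpg', '1')]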
torchvision/prototype/datasets/_builtin/dtd.py
 import enum
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union

 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetConfig,
-    DatasetInfo,
     HttpResource,
     OnlineResource,
 )
@@ -15,10 +13,16 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     path_comparator,
     getitem,
+    read_categories_file,
     hint_shuffling,
 )
 from torchvision.prototype.features import Label, EncodedImage

+from .._api import register_dataset, register_info
+
+NAME = "dtd"
+

 class DTDDemux(enum.IntEnum):
     SPLIT = 0
@@ -26,18 +30,36 @@ class DTDDemux(enum.IntEnum):
     IMAGES = 2

+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=read_categories_file(NAME))
+
+
+@register_dataset(NAME)
 class DTD(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "dtd",
-            homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
-            valid_options=dict(
-                split=("train", "test", "val"),
-                fold=tuple(str(fold) for fold in range(1, 11)),
-            ),
-        )
+    """DTD Dataset.
+    homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
+    """

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        fold: int = 1,
+        skip_validation_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
+
+        if not (1 <= fold <= 10):
+            raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}")
+        self._fold = fold
+
+        self._categories = _info()["categories"]
+
+        super().__init__(root, skip_integrity_check=skip_validation_check)
+
+    def _resources(self) -> List[OnlineResource]:
         archive = HttpResource(
             "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
             sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
@@ -71,24 +93,19 @@ class DTD(Dataset):
         return dict(
             joint_categories={category for category in joint_categories if category},
-            label=Label.from_category(category, categories=self.categories),
+            label=Label.from_category(category, categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]

         splits_dp, joint_categories_dp, images_dp = Demultiplexer(
             archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
         )

-        splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
+        splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt"))
         splits_dp = LineReader(splits_dp, decode=True, return_path=False)
         splits_dp = hint_shuffling(splits_dp)
         splits_dp = hint_sharding(splits_dp)
@@ -114,10 +131,13 @@ class DTD(Dataset):
     def _filter_images(self, data: Tuple[str, Any]) -> bool:
         return self._classify_archive(data) == DTDDemux.IMAGES

-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        resources = self.resources(self.default_config)
-        dp = resources[0].load(root)
+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()
+        dp = resources[0].load(self._root)
         dp = Filter(dp, self._filter_images)
         return sorted({pathlib.Path(path).parent.name for path, _ in dp})
+
+    def __len__(self) -> int:
+        return 1_880  # All splits have the same length
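A short sketch of the constructor validation added above. The internal module path is an assumption taken from this diff; the failing call raises before any data is touched, since the fold check runs ahead of `super().__init__`.

# Hypothetical sketch; module path assumed from this diff.
from torchvision.prototype.datasets._builtin.dtd import DTD

try:
    DTD("path/to/root", fold=11)  # fold must lie in [1, 10]
except ValueError as err:
    print(err)  # The fold parameter should be an integer in [1, 10]. Got 11

# With valid arguments the archive under root is used (or downloaded first), e.g.:
# dataset = DTD("path/to/root", split="val", fold=3); len(dataset) == 1_880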
torchvision/prototype/datasets/_builtin/eurosat.py
 import pathlib
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union

 from torchdata.datapipes.iter import IterDataPipe, Mapper
-from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 from torchvision.prototype.features import EncodedImage, Label

+from .._api import register_dataset, register_info
+
+NAME = "eurosat"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(
+        categories=(
+            "AnnualCrop",
+            "Forest",
+            "HerbaceousVegetation",
+            "Highway",
+            "Industrial," "Pasture",
+            "PermanentCrop",
+            "Residential",
+            "River",
+            "SeaLake",
+        )
+    )
+
+
+@register_dataset(NAME)
 class EuroSAT(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "eurosat",
-            homepage="https://github.com/phelber/eurosat",
-            categories=(
-                "AnnualCrop",
-                "Forest",
-                "HerbaceousVegetation",
-                "Highway",
-                "Industrial," "Pasture",
-                "PermanentCrop",
-                "Residential",
-                "River",
-                "SeaLake",
-            ),
-        )
+    """EuroSAT Dataset.
+    homepage="https://github.com/phelber/eurosat",
+    """
+
+    def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [
             HttpResource(
                 "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
@@ -37,15 +50,16 @@ class EuroSAT(Dataset):
         path, buffer = data
         category = pathlib.Path(path).parent.name
         return dict(
-            label=Label.from_category(category, categories=self.categories),
+            label=Label.from_category(category, categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )

-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 27_000
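One detail carried over unchanged into `_info()` above: `"Industrial," "Pasture",` is implicit string-literal concatenation, so the tuple holds nine entries and one of them is the single string "Industrial,Pasture". A quick standalone check of that Python behaviour:

# Standalone check of the implicit string concatenation in the categories tuple above.
categories = (
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial," "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
)
print(len(categories))  # 9
print(categories[4])    # Industrial,Pasture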
torchvision/prototype/datasets/_builtin/fer2013.py
-from typing import Any, Dict, List, cast
+import pathlib
+from typing import Any, Dict, List, Union

 import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetConfig,
-    DatasetInfo,
     OnlineResource,
     KaggleDownloadResource,
 )
@@ -15,26 +14,40 @@ from torchvision.prototype.datasets.utils._internal import (
 )
 from torchvision.prototype.features import Label, Image

+from .._api import register_dataset, register_info
+
+NAME = "fer2013"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"))
+
+
+@register_dataset(NAME)
 class FER2013(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "fer2013",
-            homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
-            categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
-            valid_options=dict(split=("train", "test")),
-        )
+    """FER 2013 Dataset
+    homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _CHECKSUMS = {
         "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
         "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
     }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         archive = KaggleDownloadResource(
-            cast(str, self.info.homepage),
-            file_name=f"{config.split}.csv.zip",
-            sha256=self._CHECKSUMS[config.split],
+            "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
+            file_name=f"{self._split}.csv.zip",
+            sha256=self._CHECKSUMS[self._split],
         )
         return [archive]
@@ -43,17 +56,15 @@ class FER2013(Dataset):
         return dict(
             image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
-            label=Label(int(label_id), categories=self.categories) if label_id is not None else None,
+            label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVDictParser(dp)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 28_709 if self._split == "train" else 3_589
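The `_prepare_sample` above rebuilds each 48x48 grayscale image from the space-separated `pixels` CSV field. A sketch of that parsing with a dummy pixel string standing in for a real CSV row:

# Dummy stand-in for data["pixels"]; the parsing mirrors _prepare_sample above.
import torch

pixels = " ".join(str(i % 256) for i in range(48 * 48))
image = torch.tensor([int(v) for v in pixels.split()], dtype=torch.uint8).reshape(48, 48)
print(image.shape, image.dtype)  # torch.Size([48, 48]) torch.uint8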
torchvision/prototype/datasets/_builtin/food101.py
 from pathlib import Path
-from typing import Any, Tuple, List, Dict, Optional, BinaryIO
+from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union

 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -9,26 +9,41 @@ from torchdata.datapipes.iter import (
     Demultiplexer,
     IterKeyZipper,
 )
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_shuffling,
     hint_sharding,
     path_comparator,
     getitem,
     INFINITE_BUFFER_SIZE,
+    read_categories_file,
 )
 from torchvision.prototype.features import Label, EncodedImage

+from .._api import register_dataset, register_info
+
+NAME = "food101"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=read_categories_file(NAME))
+
+
+@register_dataset(NAME)
 class Food101(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "food101",
-            homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
-            valid_options=dict(split=("train", "test")),
-        )
+    """Food 101 dataset
+    homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
+    """
+
+    def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [
             HttpResource(
                 url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
@@ -49,7 +64,7 @@ class Food101(Dataset):
     def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         id, (path, buffer) = data
         return dict(
-            label=Label.from_category(id.split("/", 1)[0], categories=self.categories),
+            label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )
@@ -58,17 +73,12 @@ class Food101(Dataset):
         path = Path(data[0])
         return path.relative_to(path.parents[1]).with_suffix("").as_posix()

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, split_dp = Demultiplexer(
             archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
         )
-        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True, return_path=False)
         split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
@@ -83,9 +93,12 @@ class Food101(Dataset):
         return Mapper(dp, self._prepare_sample)

-    def _generate_categories(self, root: Path) -> List[str]:
-        resources = self.resources(self.default_config)
-        dp = resources[0].load(root)
+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()
+        dp = resources[0].load(self._root)
         dp = Filter(dp, path_comparator("name", "classes.txt"))
         dp = LineReader(dp, decode=True, return_path=False)
         return list(dp)
+
+    def __len__(self) -> int:
+        return 75_750 if self._split == "train" else 25_250
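The path-to-id helper body shown above (`path.relative_to(path.parents[1]).with_suffix("").as_posix()`) produces the `<class>/<id>` string used by the Food-101 split files, which is what lets the split entries be matched against image paths. A standalone check with a made-up archive path:

# Made-up archive path; the key computation mirrors the helper body above.
from pathlib import Path

path = Path("food-101/images/apple_pie/1005649.jpg")
key = path.relative_to(path.parents[1]).with_suffix("").as_posix()
print(key)  # apple_pie/1005649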
torchvision/prototype/datasets/_builtin/gtsrb.py
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetConfig,
-    DatasetInfo,
     OnlineResource,
     HttpResource,
 )
@@ -17,15 +15,31 @@ from torchvision.prototype.datasets.utils._internal import (
 )
 from torchvision.prototype.features import Label, BoundingBox, EncodedImage

+from .._api import register_dataset, register_info
+
+NAME = "gtsrb"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(
+        categories=[f"{label:05d}" for label in range(43)],
+    )
+
+
+@register_dataset(NAME)
 class GTSRB(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "gtsrb",
-            homepage="https://benchmark.ini.rub.de",
-            categories=[f"{label:05d}" for label in range(43)],
-            valid_options=dict(split=("train", "test")),
-        )
+    """GTSRB Dataset
+
+    homepage="https://benchmark.ini.rub.de"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
     _URLS = {
@@ -39,10 +53,10 @@ class GTSRB(Dataset):
         "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d",
     }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        rsrcs: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUMS[config.split])]
+    def _resources(self) -> List[OnlineResource]:
+        rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])]

-        if config.split == "test":
+        if self._split == "test":
             rsrcs.append(
                 HttpResource(
                     self._URLS["test_ground_truth"],
@@ -74,14 +88,12 @@ class GTSRB(Dataset):
         return {
             "path": path,
             "image": EncodedImage.from_file(buffer),
-            "label": Label(label, categories=self.categories),
+            "label": Label(label, categories=self._categories),
             "bounding_box": bounding_box,
         }

-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
-        if config.split == "train":
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
+        if self._split == "train":
             images_dp, ann_dp = Demultiplexer(
                 resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
             )
@@ -98,3 +110,6 @@ class GTSRB(Dataset):
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 26_640 if self._split == "train" else 12_630
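For completeness: the registered GTSRB categories are simply the 43 class ids rendered as zero-padded strings, and `_resources()` above adds the ground-truth CSV only for the test split. A quick check of the category naming pattern:

# Reproduces the category naming used in _info() above.
categories = [f"{label:05d}" for label in range(43)]
print(categories[0], categories[-1], len(categories))  # 00000 00042 43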