Unverified Commit 3e79d149 authored by Philip Meier, committed by GitHub

refactor prototype datasets tests (#5136)

* refactor prototype datasets tests

* skip tests with insufficient third-party dependencies
parent a47c46cb
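In short, the refactor replaces the `DatasetMocks` registry (its `register_mock_data_fn` decorator and module-level `load` helper) with a list of `DatasetMock` objects, `DATASET_MOCKS`, plus a `parametrize_dataset_mocks` pytest helper: each test now receives a `dataset_mock` / `config` pair and loads the mocked dataset itself, and `DatasetMock.load` skips the test via `pytest.skip` when third-party dependencies are missing (`info.check_dependencies()`). A minimal sketch of the intended usage, based only on the helpers introduced in this diff; the class name `TestSketch` and the assertion mirror `TestCommon.test_num_samples` below and are illustrative, not part of the commit:

    # Registration side: a mock data function is named after its dataset, writes fake
    # files for the given config into `root`, and returns the number of samples (or a
    # dict of mock infos), e.g.
    #
    #     @DATASET_MOCKS.append_named_callable
    #     def mnist(info, root, config):
    #         ...  # write fake idx archives into `root`
    #         return num_samples
    #
    # Consumption side: tests are parametrized over every (dataset mock, config) pair.
    from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks

    @parametrize_dataset_mocks(DATASET_MOCKS)
    class TestSketch:
        def test_num_samples(self, dataset_mock, config):
            dataset, mock_info = dataset_mock.load(config)  # generates the mock data on first use
            assert sum(1 for _ in dataset) == mock_info["num_samples"]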
import contextlib
import functools
import gzip
import itertools
import json
import lzma
import pathlib
import pickle
import tempfile
from collections import defaultdict
from typing import Any, Dict, Tuple
from collections import defaultdict, UserList
import numpy as np
import PIL.Image
import pytest
import torch
from datasets_utils import create_image_folder, make_tar, make_zip
from datasets_utils import make_zip, make_tar, create_image_folder
from torch.testing import make_tensor as _make_tensor
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER
from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import add_suggestion
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ())
__all__ = ["load"]
TEST_HOME = pathlib.Path(tempfile.mkdtemp())
DEFAULT_TEST_DECODER = object()
__all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]
class TestResource(datasets.utils.OnlineResource):
class ResourceMock(datasets.utils.OnlineResource):
def __init__(self, *, dataset_name, dataset_config, **kwargs):
super().__init__(**kwargs)
self.dataset_name = dataset_name
@@ -42,96 +40,106 @@ class TestResource(datasets.utils.OnlineResource):
)
class DatasetMocks:
def __init__(self):
self._mock_data_fns = {}
self._tmp_home = pathlib.Path(tempfile.mkdtemp())
class DatasetMock:
def __init__(self, name, mock_data_fn, *, configs=None):
self.dataset = find(name)
self.root = TEST_HOME / self.dataset.name
self.mock_data_fn = self._parse_mock_data(mock_data_fn)
self.configs = configs or self.info._configs
self._cache = {}
def register_mock_data_fn(self, mock_data_fn):
name = mock_data_fn.__name__
if name not in datasets.list():
raise pytest.UsageError(
add_suggestion(
f"The name of the mock data function '{name}' has no corresponding dataset.",
word=name,
possibilities=datasets.list(),
close_match_hint=lambda close_match: f"Did you mean to name it '{close_match}'?",
alternative_hint=lambda _: "",
)
)
self._mock_data_fns[name] = mock_data_fn
return mock_data_fn
def _parse_mock_info(self, mock_info, *, name):
if mock_info is None:
raise pytest.UsageError(
f"The mock data function for dataset '{name}' returned nothing. It needs to at least return an integer "
f"indicating the number of samples for the current `config`."
)
elif isinstance(mock_info, int):
mock_info = dict(num_samples=mock_info)
elif not isinstance(mock_info, dict):
raise pytest.UsageError(
f"The mock data function for dataset '{name}' returned a {type(mock_info)}. The returned object should "
f"be a dictionary containing at least the number of samples for the current `config` for the key "
f"`'num_samples'`. If no additional information is required for specific tests, the number of samples "
f"can also be returned as an integer."
)
elif "num_samples" not in mock_info:
raise pytest.UsageError(
f"The dictionary returned by the mock data function for dataset '{name}' must contain a `'num_samples'` "
f"entry indicating the number of samples for the current `config`."
)
return mock_info
@property
def info(self):
return self.dataset.info
def _get(self, dataset, config, root):
name = dataset.info.name
resources_and_mock_info = self._cache.get((name, config))
if resources_and_mock_info:
return resources_and_mock_info
@property
def name(self):
return self.info.name
try:
fakedata_fn = self._mock_data_fns[name]
except KeyError:
raise pytest.UsageError(
f"No mock data available for dataset '{name}'. "
f"Did you add a new dataset, but forget to provide mock data for it? "
f"Did you register the mock data function with `@DatasetMocks.register_mock_data_fn`?"
)
def _parse_mock_data(self, mock_data_fn):
def wrapper(info, root, config):
mock_infos = mock_data_fn(info, root, config)
mock_resources = [
TestResource(dataset_name=name, dataset_config=config, file_name=resource.file_name)
for resource in dataset.resources(config)
]
mock_info = self._parse_mock_info(fakedata_fn(dataset.info, root, config), name=name)
self._cache[(name, config)] = mock_resources, mock_info
return mock_resources, mock_info
if mock_infos is None:
raise pytest.UsageError(
f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
f"integer indicating the number of samples for the current `config`."
)
def load(
self, name: str, decoder=DEFAULT_DECODER, split="train", **options: Any
) -> Tuple[IterDataPipe, Dict[str, Any]]:
dataset = find(name)
config = dataset.info.make_config(split=split, **options)
key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
if datasets.utils.DatasetConfig not in key_types:
mock_infos = {config: mock_infos}
elif len(key_types) > 1:
raise pytest.UsageError(
f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
)
root = self._tmp_home / name
root.mkdir(exist_ok=True)
resources, mock_info = self._get(dataset, config, root)
for config_, mock_info in list(mock_infos.items()):
if config_ in self._cache:
raise pytest.UsageError(
f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
f"already exists in the cache."
)
if isinstance(mock_info, int):
mock_infos[config_] = dict(num_samples=mock_info)
elif not isinstance(mock_info, dict):
raise pytest.UsageError(
f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
f"{config_}. The returned object should be a dictionary containing at least the number of "
f"samples for the key `'num_samples'`. If no additional information is required for specific "
f"tests, the number of samples can also be returned as an integer."
)
elif "num_samples" not in mock_info:
raise pytest.UsageError(
f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
)
return mock_infos
return wrapper
def _load_mock(self, config):
with contextlib.suppress(KeyError):
return self._cache[config]
self.root.mkdir(exist_ok=True)
for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
mock_resources = [
ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
for resource in self.dataset.resources(config_)
]
self._cache[config_] = (mock_resources, mock_info)
return self._cache[config]
def load(self, config, *, decoder=DEFAULT_DECODER):
try:
self.info.check_dependencies()
except ModuleNotFoundError as error:
pytest.skip(str(error))
datapipe = dataset._make_datapipe(
[resource.load(root) for resource in resources],
mock_resources, mock_info = self._load_mock(config)
datapipe = self.dataset._make_datapipe(
[resource.load(self.root) for resource in mock_resources],
config=config,
decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
)
return datapipe, mock_info
dataset_mocks = DatasetMocks()
load = dataset_mocks.load
class DatasetMocks(UserList):
def append_named_callable(self, fn):
mock_data_fn = fn.__func__ if isinstance(fn, classmethod) else fn
self.data.append(DatasetMock(mock_data_fn.__name__, mock_data_fn))
return fn
DATASET_MOCKS = DatasetMocks()
class MNISTFakedata:
class MNISTMockData:
_DTYPES_ID = {
torch.uint8: 8,
torch.int8: 9,
@@ -206,12 +214,12 @@ class MNISTFakedata:
return num_samples
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def mnist(info, root, config):
train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
return MNISTFakedata.generate(
return MNISTMockData.generate(
root,
num_categories=len(info.categories),
images_file=images_file,
@@ -219,60 +227,39 @@ def mnist(info, root, config):
)
@dataset_mocks.register_mock_data_fn
def fashionmnist(info, root, config):
train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
return MNISTFakedata.generate(
root,
num_categories=len(info.categories),
images_file=images_file,
labels_file=labels_file,
)
@dataset_mocks.register_mock_data_fn
def kmnist(info, root, config):
train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
return MNISTFakedata.generate(
root,
num_categories=len(info.categories),
images_file=images_file,
labels_file=labels_file,
)
DATASET_MOCKS.extend([DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]])
@dataset_mocks.register_mock_data_fn
def emnist(info, root, config):
@DATASET_MOCKS.append_named_callable
def emnist(info, root, _):
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
# labels in the data files. Thus, num_categories != len(categories) there.
num_categories = defaultdict(
lambda: len(info.categories), **{image_set: 47 for image_set in ("Balanced", "By_Merge")}
lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
)
num_samples = {}
mock_infos = {}
file_names = set()
for _config in info._configs:
prefix = f"emnist-{_config.image_set.replace('_', '').lower()}-{_config.split}"
for config in info._configs:
prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}"
images_file = f"{prefix}-images-idx3-ubyte.gz"
labels_file = f"{prefix}-labels-idx1-ubyte.gz"
file_names.update({images_file, labels_file})
num_samples[_config.image_set] = MNISTFakedata.generate(
root,
num_categories=num_categories[_config.image_set],
images_file=images_file,
labels_file=labels_file,
mock_infos[config] = dict(
num_samples=MNISTMockData.generate(
root,
num_categories=num_categories[config.image_set],
images_file=images_file,
labels_file=labels_file,
)
)
make_zip(root, "emnist-gzip.zip", *file_names)
return num_samples[config.image_set]
return mock_infos
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def qmnist(info, root, config):
num_categories = len(info.categories)
if config.split == "train":
@@ -280,24 +267,27 @@ def qmnist(info, root, config):
prefix = "qmnist-train"
suffix = ".gz"
compressor = gzip.open
mock_infos = num_samples
elif config.split.startswith("test"):
# The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create more
# than 10000 images for the dataset to not be empty.
num_samples = num_samples_gen = 10001
if config.split == "test10k":
num_samples = min(num_samples, 10000)
if config.split == "test50k":
num_samples -= 10000
# The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
# more than 10000 images for the dataset to not be empty.
num_samples_gen = 10001
prefix = "qmnist-test"
suffix = ".gz"
compressor = gzip.open
mock_infos = {
info.make_config(split="test"): num_samples_gen,
info.make_config(split="test10k"): min(num_samples_gen, 10_000),
info.make_config(split="test50k"): num_samples_gen - 10_000,
}
else: # config.split == "nist"
num_samples = num_samples_gen = num_categories + 3
prefix = "xnist"
suffix = ".xz"
compressor = lzma.open
mock_infos = num_samples
MNISTFakedata.generate(
MNISTMockData.generate(
root,
num_categories=num_categories,
num_samples=num_samples_gen,
@@ -307,11 +297,10 @@ def qmnist(info, root, config):
label_dtype=torch.int32,
compressor=compressor,
)
return num_samples
return mock_infos
class CIFARFakedata:
class CIFARMockData:
NUM_PIXELS = 32 * 32 * 3
@classmethod
@@ -349,12 +338,12 @@ class CIFARFakedata:
make_tar(root, name, folder, compression="gz")
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def cifar10(info, root, config):
train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
test_files = ["test_batch"]
CIFARFakedata.generate(
CIFARMockData.generate(
root=root,
name="cifar-10-python.tar.gz",
folder=pathlib.Path("cifar-10-batches-py"),
@@ -367,12 +356,12 @@ def cifar10(info, root, config):
return len(train_files if config.split == "train" else test_files)
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def cifar100(info, root, config):
train_files = ["train"]
test_files = ["test"]
CIFARFakedata.generate(
CIFARMockData.generate(
root=root,
name="cifar-100-python.tar.gz",
folder=pathlib.Path("cifar-100-python"),
@@ -385,7 +374,7 @@ def cifar100(info, root, config):
return len(train_files if config.split == "train" else test_files)
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def caltech101(info, root, config):
def create_ann_file(root, name):
import scipy.io
@@ -435,7 +424,7 @@ def caltech101(info, root, config):
return num_images_per_category * len(info.categories)
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def caltech256(info, root, config):
dir = root / "256_ObjectCategories"
num_images_per_category = 2
@@ -455,7 +444,7 @@ def caltech256(info, root, config):
return num_images_per_category * len(info.categories)
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def imagenet(info, root, config):
wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train":
@@ -610,6 +599,33 @@ class CocoMockData:
return num_samples
@dataset_mocks.register_mock_data_fn
@DATASET_MOCKS.append_named_callable
def coco(info, root, config):
return CocoMockData.generate(root, year=config.year, num_samples=5)
return dict(
zip(
[config_ for config_ in info._configs if config_.year == config.year],
itertools.repeat(CocoMockData.generate(root, year=config.year, num_samples=5)),
)
)
def config_id(name, config):
parts = [name]
for name, value in config.items():
if isinstance(value, bool):
part = ("" if value else "no_") + name
else:
part = str(value)
parts.append(part)
return "-".join(parts)
def parametrize_dataset_mocks(datasets_mocks):
return pytest.mark.parametrize(
("dataset_mock", "config"),
[
pytest.param(dataset_mock, config, id=config_id(dataset_mock.name, config))
for dataset_mock in datasets_mocks
for config in dataset_mock.configs
],
)
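For reference, the pytest IDs produced by `config_id` above join the dataset name with every config value, separated by `-`, rendering a boolean value as its key name (prefixed with `no_` when false). A small illustration, with plain dicts standing in for the dict-like `DatasetConfig` objects; the dataset name `foo` and its `annotations` flag are made-up examples:

    assert config_id("cifar10", {"split": "train"}) == "cifar10-train"
    assert config_id("qmnist", {"split": "test50k"}) == "qmnist-test50k"
    assert config_id("foo", {"split": "train", "annotations": False}) == "foo-train-no_annotations"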
@@ -867,7 +867,7 @@ def _split_files_or_dirs(root, *files_or_dirs):
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
archive = pathlib.Path(root) / name
if not files_or_dirs:
dir = archive.parent / archive.name.replace("".join(archive.suffixes), "")
dir = archive.with_suffix("")
if dir.exists() and dir.is_dir():
files_or_dirs = (dir,)
else:
......
import io
import builtin_dataset_mocks
import pytest
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision.prototype import datasets, transforms
from torchvision.prototype.datasets._api import DEFAULT_DECODER
from torchvision.prototype import transforms
from torchvision.prototype.utils._internal import sequence_to_str
def to_bytes(file):
return file.read()
def config_id(name, config):
parts = [name]
for name, value in config.items():
if isinstance(value, bool):
part = ("" if value else "no_") + name
else:
part = str(value)
parts.append(part)
return "-".join(parts)
def dataset_parametrization(*names, decoder=to_bytes):
if not names:
# TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
names = (
"mnist",
"fashionmnist",
"kmnist",
"emnist",
"qmnist",
"cifar10",
"cifar100",
"caltech256",
"caltech101",
"imagenet",
"coco",
)
return pytest.mark.parametrize(
("dataset", "mock_info"),
[
pytest.param(*builtin_dataset_mocks.load(name, decoder=decoder, **config), id=config_id(name, config))
for name in names
for config in datasets.info(name)._configs
],
)
@parametrize_dataset_mocks(DATASET_MOCKS)
class TestCommon:
@dataset_parametrization()
def test_smoke(self, dataset, mock_info):
def test_smoke(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
@dataset_parametrization()
def test_sample(self, dataset, mock_info):
def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
try:
sample = next(iter(dataset))
except Exception as error:
@@ -72,16 +31,18 @@ class TestCommon:
if not sample:
raise AssertionError("Sample dictionary is empty.")
@dataset_parametrization()
def test_num_samples(self, dataset, mock_info):
def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config)
num_samples = 0
for _ in dataset:
num_samples += 1
assert num_samples == mock_info["num_samples"]
@dataset_parametrization()
def test_decoding(self, dataset, mock_info):
def test_decoding(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features:
raise AssertionError(
@@ -89,8 +50,9 @@ class TestCommon:
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
)
@dataset_parametrization(decoder=DEFAULT_DECODER)
def test_no_vanilla_tensors(self, dataset, mock_info):
def test_no_vanilla_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
raise AssertionError(
@@ -98,22 +60,25 @@ class TestCommon:
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)
@dataset_parametrization()
def test_transformable(self, dataset, mock_info):
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
next(iter(dataset.map(transforms.Identity())))
@dataset_parametrization()
def test_traversable(self, dataset, mock_info):
def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
traverse(dataset)
@dataset_parametrization()
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
def test_has_annotations(self, dataset, mock_info, annotation_dp_type):
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
def scan(graph):
for node, sub_graph in graph.items():
yield node
yield from scan(sub_graph)
dataset, _ = dataset_mock.load(config)
for dp in scan(traverse(dataset)):
if type(dp) is annotation_dp_type:
break
@@ -122,14 +87,10 @@ class TestCommon:
class TestQMNIST:
@pytest.mark.parametrize(
"dataset",
[
pytest.param(builtin_dataset_mocks.load("qmnist", split=split)[0], id=split)
for split in ("train", "test", "test10k", "test50k", "nist")
],
)
def test_extra_label(self, dataset):
@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"])
def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
sample = next(iter(dataset))
for key, type in (
("nist_hsf_series", int),
......