Unverified commit 3e79d149, authored by Philip Meier and committed by GitHub

refactor prototype datasets tests (#5136)

* refactor prototype datasets tests

* skip tests with insufficient third party dependencies
parent a47c46cb
import contextlib
import functools
import gzip
import itertools
import json
import lzma
import pathlib
import pickle
import tempfile
from collections import defaultdict, UserList

import numpy as np
import PIL.Image
import pytest
import torch
from datasets_utils import make_zip, make_tar, create_image_folder
from torch.testing import make_tensor as _make_tensor
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find

make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ())

TEST_HOME = pathlib.Path(tempfile.mkdtemp())


__all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]


class ResourceMock(datasets.utils.OnlineResource):
    def __init__(self, *, dataset_name, dataset_config, **kwargs):
        super().__init__(**kwargs)
        self.dataset_name = dataset_name
@@ -42,96 +40,106 @@ class TestResource(datasets.utils.OnlineResource):
        )
class DatasetMock:
    def __init__(self, name, mock_data_fn, *, configs=None):
        self.dataset = find(name)
        self.root = TEST_HOME / self.dataset.name
        self.mock_data_fn = self._parse_mock_data(mock_data_fn)
        self.configs = configs or self.info._configs
        self._cache = {}

    @property
    def info(self):
        return self.dataset.info

    @property
    def name(self):
        return self.info.name

    def _parse_mock_data(self, mock_data_fn):
        def wrapper(info, root, config):
            mock_infos = mock_data_fn(info, root, config)

            if mock_infos is None:
                raise pytest.UsageError(
                    f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return "
                    f"an integer indicating the number of samples for the current `config`."
                )

            key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
            if datasets.utils.DatasetConfig not in key_types:
                mock_infos = {config: mock_infos}
            elif len(key_types) > 1:
                raise pytest.UsageError(
                    f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
                    f"the returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
                )

            for config_, mock_info in list(mock_infos.items()):
                if config_ in self._cache:
                    raise pytest.UsageError(
                        f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
                        f"already exists in the cache."
                    )
                if isinstance(mock_info, int):
                    mock_infos[config_] = dict(num_samples=mock_info)
                elif not isinstance(mock_info, dict):
                    raise pytest.UsageError(
                        f"The mock data function for dataset '{self.name}' returned a {type(mock_info)} for `config` "
                        f"{config_}. The returned object should be a dictionary containing at least the number of "
                        f"samples for the key `'num_samples'`. If no additional information is required for specific "
                        f"tests, the number of samples can also be returned as an integer."
                    )
                elif "num_samples" not in mock_info:
                    raise pytest.UsageError(
                        f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
                        f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
                    )

            return mock_infos

        return wrapper

    def _load_mock(self, config):
        with contextlib.suppress(KeyError):
            return self._cache[config]

        self.root.mkdir(exist_ok=True)
        for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
            mock_resources = [
                ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
                for resource in self.dataset.resources(config_)
            ]
            self._cache[config_] = (mock_resources, mock_info)

        return self._cache[config]

    def load(self, config, *, decoder=DEFAULT_DECODER):
        try:
            self.info.check_dependencies()
        except ModuleNotFoundError as error:
            pytest.skip(str(error))

        mock_resources, mock_info = self._load_mock(config)
        datapipe = self.dataset._make_datapipe(
            [resource.load(self.root) for resource in mock_resources],
            config=config,
            decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
        )
        return datapipe, mock_info


class DatasetMocks(UserList):
    def append_named_callable(self, fn):
        mock_data_fn = fn.__func__ if isinstance(fn, classmethod) else fn
        self.data.append(DatasetMock(mock_data_fn.__name__, mock_data_fn))
        return fn


DATASET_MOCKS = DatasetMocks()
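As a rough usage sketch (not part of this commit; it assumes the mocks registered further down in this file have been imported), a DatasetMock can be pulled out of DATASET_MOCKS and loaded for one of its configs:

# Hedged sketch: look up the "mnist" mock, build one of its configs, and check
# that the returned datapipe yields as many samples as the mock info promises.
mnist_mock = next(mock for mock in DATASET_MOCKS if mock.name == "mnist")
config = mnist_mock.info.make_config(split="train")
datapipe, mock_info = mnist_mock.load(config)
assert sum(1 for _ in datapipe) == mock_info["num_samples"]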
class MNISTMockData:
    _DTYPES_ID = {
        torch.uint8: 8,
        torch.int8: 9,
@@ -206,12 +214,12 @@ class MNISTFakedata:
        return num_samples


@DATASET_MOCKS.append_named_callable
def mnist(info, root, config):
    train = config.split == "train"
    images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
    labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
    return MNISTMockData.generate(
        root,
        num_categories=len(info.categories),
        images_file=images_file,
@@ -219,60 +227,39 @@ def mnist(info, root, config):
    )
DATASET_MOCKS.extend([DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]])


@DATASET_MOCKS.append_named_callable
def emnist(info, root, _):
    # The image sets that merge some lower case letters into their respective upper case variant still use dense
    # labels in the data files. Thus, num_categories != len(categories) there.
    num_categories = defaultdict(
        lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
    )

    mock_infos = {}
    file_names = set()
    for config in info._configs:
        prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}"
        images_file = f"{prefix}-images-idx3-ubyte.gz"
        labels_file = f"{prefix}-labels-idx1-ubyte.gz"
        file_names.update({images_file, labels_file})
        mock_infos[config] = dict(
            num_samples=MNISTMockData.generate(
                root,
                num_categories=num_categories[config.image_set],
                images_file=images_file,
                labels_file=labels_file,
            )
        )

    make_zip(root, "emnist-gzip.zip", *file_names)

    return mock_infos
@DATASET_MOCKS.append_named_callable
def qmnist(info, root, config):
    num_categories = len(info.categories)
    if config.split == "train":
@@ -280,24 +267,27 @@ def qmnist(info, root, config):
        prefix = "qmnist-train"
        suffix = ".gz"
        compressor = gzip.open
        mock_infos = num_samples
    elif config.split.startswith("test"):
        # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
        # more than 10000 images for the dataset to not be empty.
        num_samples_gen = 10001
        prefix = "qmnist-test"
        suffix = ".gz"
        compressor = gzip.open
        mock_infos = {
            info.make_config(split="test"): num_samples_gen,
            info.make_config(split="test10k"): min(num_samples_gen, 10_000),
            info.make_config(split="test50k"): num_samples_gen - 10_000,
        }
    else:  # config.split == "nist"
        num_samples = num_samples_gen = num_categories + 3
        prefix = "xnist"
        suffix = ".xz"
        compressor = lzma.open
        mock_infos = num_samples

    MNISTMockData.generate(
        root,
        num_categories=num_categories,
        num_samples=num_samples_gen,
@@ -307,11 +297,10 @@ def qmnist(info, root, config):
        label_dtype=torch.int32,
        compressor=compressor,
    )
    return mock_infos
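A quick worked check of the bookkeeping above (illustrative only, not part of the file): generating 10001 test images is just enough to make every test split non-empty.

# Hedged arithmetic sketch for the qmnist test splits:
num_samples_gen = 10001
assert min(num_samples_gen, 10_000) == 10_000  # 'test10k' covers the first 10k images
assert num_samples_gen - 10_000 == 1           # 'test50k' starts at index 10000, so one image remains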
class CIFARMockData:
    NUM_PIXELS = 32 * 32 * 3

    @classmethod
@@ -349,12 +338,12 @@ class CIFARFakedata:
        make_tar(root, name, folder, compression="gz")


@DATASET_MOCKS.append_named_callable
def cifar10(info, root, config):
    train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
    test_files = ["test_batch"]

    CIFARMockData.generate(
        root=root,
        name="cifar-10-python.tar.gz",
        folder=pathlib.Path("cifar-10-batches-py"),
@@ -367,12 +356,12 @@ def cifar10(info, root, config):
    return len(train_files if config.split == "train" else test_files)


@DATASET_MOCKS.append_named_callable
def cifar100(info, root, config):
    train_files = ["train"]
    test_files = ["test"]

    CIFARMockData.generate(
        root=root,
        name="cifar-100-python.tar.gz",
        folder=pathlib.Path("cifar-100-python"),
@@ -385,7 +374,7 @@ def cifar100(info, root, config):
    return len(train_files if config.split == "train" else test_files)


@DATASET_MOCKS.append_named_callable
def caltech101(info, root, config):
    def create_ann_file(root, name):
        import scipy.io
@@ -435,7 +424,7 @@ def caltech101(info, root, config):
    return num_images_per_category * len(info.categories)


@DATASET_MOCKS.append_named_callable
def caltech256(info, root, config):
    dir = root / "256_ObjectCategories"
    num_images_per_category = 2
@@ -455,7 +444,7 @@ def caltech256(info, root, config):
    return num_images_per_category * len(info.categories)


@DATASET_MOCKS.append_named_callable
def imagenet(info, root, config):
    wnids = tuple(info.extra.wnid_to_category.keys())
    if config.split == "train":
@@ -610,6 +599,33 @@ class CocoMockData:
        return num_samples


@DATASET_MOCKS.append_named_callable
def coco(info, root, config):
    return dict(
        zip(
            [config_ for config_ in info._configs if config_.year == config.year],
            itertools.repeat(CocoMockData.generate(root, year=config.year, num_samples=5)),
        )
    )
def config_id(name, config):
    parts = [name]
    for name, value in config.items():
        if isinstance(value, bool):
            part = ("" if value else "no_") + name
        else:
            part = str(value)
        parts.append(part)
    return "-".join(parts)


def parametrize_dataset_mocks(datasets_mocks):
    return pytest.mark.parametrize(
        ("dataset_mock", "config"),
        [
            pytest.param(dataset_mock, config, id=config_id(dataset_mock.name, config))
            for dataset_mock in datasets_mocks
            for config in dataset_mock.configs
        ],
    )
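For illustration (a hedged sketch, not part of this file): config_id only relies on .items(), so plain dicts can stand in for DatasetConfig to show the test-ID scheme that parametrize_dataset_mocks produces.

# Hedged examples of the generated test IDs; "foo" and the "annotations" field are
# made up, and a boolean field contributes a "no_<name>" part when it is False.
assert config_id("mnist", {"split": "train"}) == "mnist-train"
assert config_id("foo", {"split": "test", "annotations": False}) == "foo-test-no_annotations"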
@@ -867,7 +867,7 @@ def _split_files_or_dirs(root, *files_or_dirs):
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
    archive = pathlib.Path(root) / name
    if not files_or_dirs:
        dir = archive.with_suffix("")
        if dir.exists() and dir.is_dir():
            files_or_dirs = (dir,)
        else:
...
import io

import pytest
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision.prototype import transforms
from torchvision.prototype.utils._internal import sequence_to_str


@parametrize_dataset_mocks(DATASET_MOCKS)
class TestCommon:
    def test_smoke(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
        if not isinstance(dataset, IterDataPipe):
            raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

    def test_sample(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
        try:
            sample = next(iter(dataset))
        except Exception as error:
@@ -72,16 +31,18 @@ class TestCommon:
        if not sample:
            raise AssertionError("Sample dictionary is empty.")

    def test_num_samples(self, dataset_mock, config):
        dataset, mock_info = dataset_mock.load(config)
        num_samples = 0
        for _ in dataset:
            num_samples += 1

        assert num_samples == mock_info["num_samples"]

    def test_decoding(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
        undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
        if undecoded_features:
            raise AssertionError(
@@ -89,8 +50,9 @@ class TestCommon:
                f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
            )

    def test_no_vanilla_tensors(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
        if vanilla_tensors:
            raise AssertionError(
@@ -98,22 +60,25 @@ class TestCommon:
                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
            )

    def test_transformable(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
        next(iter(dataset.map(transforms.Identity())))

    def test_traversable(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
        traverse(dataset)

    @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
    def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
        def scan(graph):
            for node, sub_graph in graph.items():
                yield node
                yield from scan(sub_graph)

        dataset, _ = dataset_mock.load(config)

        for dp in scan(traverse(dataset)):
            if type(dp) is annotation_dp_type:
                break
@@ -122,14 +87,10 @@ class TestCommon:
class TestQMNIST:
    @parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"])
    def test_extra_label(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
        sample = next(iter(dataset))
        for key, type in (
            ("nist_hsf_series", int),
...