Unverified commit ed84880b authored by Nicolas Hug, committed by GitHub

Minor simplifications to prototype dataset testings (#5268)

parent b21e0bfb
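The central change in the diff below swaps the `DatasetMocks(UserDict)` registry class for a plain dict plus a module-level `register_mock` decorator. A minimal, self-contained sketch of that decorator-registration pattern (the `DatasetMock` class here is a reduced stand-in for the real test helper):

DATASET_MOCKS = {}


class DatasetMock:  # reduced stand-in for the real test helper
    def __init__(self, name, mock_data_fn):
        self.name = name
        self.mock_data_fn = mock_data_fn


def register_mock(fn):
    # Dataset names use dashes where Python identifiers need underscores.
    name = fn.__name__.replace("_", "-")
    DATASET_MOCKS[name] = DatasetMock(name, fn)
    return fn  # hand the function back unchanged, as plain registration decorators do


@register_mock
def oxford_iiit_pet(info, root, config):
    ...


assert "oxford-iiit-pet" in DATASET_MOCKS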
@@ -12,7 +12,7 @@ import random
 import tempfile
 import unittest.mock
 import xml.etree.ElementTree as ET
-from collections import defaultdict, Counter, UserDict
+from collections import defaultdict, Counter

 import numpy as np
 import PIL.Image
@@ -34,35 +34,17 @@ TEST_HOME = pathlib.Path(tempfile.mkdtemp())
 __all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]


-class ResourceMock(datasets.utils.OnlineResource):
-    def __init__(self, *, dataset_name, dataset_config, **kwargs):
-        super().__init__(**kwargs)
-        self.dataset_name = dataset_name
-        self.dataset_config = dataset_config
-
-    def _download(self, _):
-        raise pytest.UsageError(
-            f"Dataset '{self.dataset_name}' requires the file '{self.file_name}' for {self.dataset_config}, "
-            f"but this file does not exist."
-        )
-
-
 class DatasetMock:
-    def __init__(self, name, mock_data_fn, *, configs=None):
+    def __init__(self, name, mock_data_fn):
         self.dataset = find(name)
+        self.info = self.dataset.info
+        self.name = self.info.name
         self.root = TEST_HOME / self.dataset.name
         self.mock_data_fn = mock_data_fn
-        self.configs = configs or self.info._configs
+        self.configs = self.info._configs
         self._cache = {}

-    @property
-    def info(self):
-        return self.dataset.info
-
-    @property
-    def name(self):
-        return self.info.name
-
     def _parse_mock_data(self, config, mock_infos):
         if mock_infos is None:
             raise pytest.UsageError(
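The removed `info`/`name` properties were constant after construction, so the new `__init__` resolves them once up front. A minimal sketch of the before/after shape, with a faked `find` lookup standing in for the real dataset registry:

class _FakeInfo:
    name = "mnist"


class _FakeDataset:
    info = _FakeInfo()


def find(name):  # stand-in for the real registry lookup
    return _FakeDataset()


class WithProperties:
    # Old shape: recomputed on every access, though the value never changes.
    def __init__(self, name):
        self.dataset = find(name)

    @property
    def info(self):
        return self.dataset.info

    @property
    def name(self):
        return self.info.name


class WithAttributes:
    # New shape: resolved once at construction time.
    def __init__(self, name):
        self.dataset = find(name)
        self.info = self.dataset.info
        self.name = self.info.name


assert WithProperties("mnist").name == WithAttributes("mnist").name == "mnist"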
@@ -79,7 +61,7 @@ class DatasetMock:
                f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
            )

-        for config_, mock_info in list(mock_infos.items()):
+        for config_, mock_info in mock_infos.items():
            if config_ in self._cache:
                raise pytest.UsageError(
                    f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
@@ -103,7 +85,7 @@ class DatasetMock:
        return mock_infos

    def _prepare_resources(self, config):
-        with contextlib.suppress(KeyError):
+        if config in self._cache:
            return self._cache[config]

        self.root.mkdir(exist_ok=True)
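Both cache lookups behave identically on a hit; the explicit membership test just stops `contextlib.suppress` from swallowing unrelated `KeyError`s raised inside the block. A small self-contained sketch of the two idioms, using an illustrative `_cache` dict:

import contextlib

_cache = {"train": ["resource-a"]}


def lookup_with_suppress(config):
    # Old style: any KeyError inside the block is silently swallowed,
    # including ones raised by bugs unrelated to the cache lookup.
    with contextlib.suppress(KeyError):
        return _cache[config]
    return None  # cache miss falls through here


def lookup_with_membership(config):
    # New style: only the cache-miss case falls through, explicitly.
    if config in _cache:
        return _cache[config]
    return None


assert lookup_with_suppress("train") == lookup_with_membership("train")
assert lookup_with_suppress("test") is None and lookup_with_membership("test") is None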
@@ -146,8 +128,6 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
    for mock in dataset_mocks:
        if isinstance(mock, DatasetMock):
            mocks[mock.name] = mock
-        elif isinstance(mock, collections.abc.Sequence):
-            mocks.update({mock_.name: mock_ for mock_ in mock})
        elif isinstance(mock, collections.abc.Mapping):
            mocks.update(mock)
        else:
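With the `Sequence` branch gone, `parametrize_dataset_mocks` accepts individual `DatasetMock` objects and mappings only. A reduced, self-contained sketch of that normalization step (stand-in type, hypothetical error message):

import collections.abc


class DatasetMock:  # stand-in carrying just the attribute the dispatch needs
    def __init__(self, name):
        self.name = name


def normalize_mocks(*dataset_mocks):
    mocks = {}
    for mock in dataset_mocks:
        if isinstance(mock, DatasetMock):
            mocks[mock.name] = mock
        elif isinstance(mock, collections.abc.Mapping):
            mocks.update(mock)
        else:
            # Sequences of mocks are no longer unpacked; pass a mapping instead.
            raise TypeError(f"Unsupported mock type {type(mock).__name__}")
    return mocks


single = DatasetMock("mnist")
mapping = {"cifar10": DatasetMock("cifar10")}
assert set(normalize_mocks(single, mapping)) == {"mnist", "cifar10"}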
@@ -173,14 +153,13 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
    )


-class DatasetMocks(UserDict):
-    def set_from_named_callable(self, fn):
-        name = fn.__name__.replace("_", "-")
-        self.data[name] = DatasetMock(name, fn)
-        return fn
+DATASET_MOCKS = {}


-DATASET_MOCKS = DatasetMocks()
+def register_mock(fn):
+    name = fn.__name__.replace("_", "-")
+    DATASET_MOCKS[name] = DatasetMock(name, fn)
+    return fn


class MNISTMockData:
@@ -258,7 +237,7 @@ class MNISTMockData:
        return num_samples


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def mnist(info, root, config):
    train = config.split == "train"
    images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
@@ -274,7 +253,7 @@ def mnist(info, root, config):
DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def emnist(info, root, _):
    # The image sets that merge some lower case letters in their respective upper case variant, still use dense
    # labels in the data files. Thus, num_categories != len(categories) there.
@@ -303,7 +282,7 @@ def emnist(info, root, _):
    return mock_infos


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def qmnist(info, root, config):
    num_categories = len(info.categories)
    if config.split == "train":
@@ -382,7 +361,7 @@ class CIFARMockData:
        make_tar(root, name, folder, compression="gz")


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def cifar10(info, root, config):
    train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
    test_files = ["test_batch"]
@@ -400,7 +379,7 @@ def cifar10(info, root, config):
    return len(train_files if config.split == "train" else test_files)


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def cifar100(info, root, config):
    train_files = ["train"]
    test_files = ["test"]
@@ -418,7 +397,7 @@ def cifar100(info, root, config):
    return len(train_files if config.split == "train" else test_files)


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def caltech101(info, root, config):
    def create_ann_file(root, name):
        import scipy.io
@@ -468,7 +447,7 @@ def caltech101(info, root, config):
    return num_images_per_category * len(info.categories)


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def caltech256(info, root, config):
    dir = root / "256_ObjectCategories"
    num_images_per_category = 2
@@ -488,7 +467,7 @@ def caltech256(info, root, config):
    return num_images_per_category * len(info.categories)


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def imagenet(info, root, config):
    wnids = tuple(info.extra.wnid_to_category.keys())
    if config.split == "train":
@@ -643,7 +622,7 @@ class CocoMockData:
        return num_samples


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def coco(info, root, config):
    return dict(
        zip(
@@ -722,13 +701,13 @@ class SBDMockData:
        return num_samples_map


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def sbd(info, root, _):
    num_samples_map = SBDMockData.generate(root)
    return {config: num_samples_map[config.split] for config in info._configs}


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def semeion(info, root, config):
    num_samples = 3
@@ -839,7 +818,7 @@ class VOCMockData:
        return num_samples_map


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def voc(info, root, config):
    trainval = config.split != "test"
    num_samples_map = VOCMockData.generate(root, year=config.year, trainval=trainval)
@@ -938,13 +917,13 @@ class CelebAMockData:
        return num_samples_map


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def celeba(info, root, _):
    num_samples_map = CelebAMockData.generate(root)
    return {config: num_samples_map[config.split] for config in info._configs}


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def dtd(info, root, _):
    data_folder = root / "dtd"
@@ -992,7 +971,7 @@ def dtd(info, root, _):
    return num_samples_map


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def fer2013(info, root, config):
    num_samples = 5 if config.split == "train" else 3
@@ -1017,7 +996,7 @@ def fer2013(info, root, config):
    return num_samples


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def gtsrb(info, root, config):
    num_examples_per_class = 5 if config.split == "train" else 3
    classes = ("00000", "00042", "00012")
@@ -1087,7 +1066,7 @@ def gtsrb(info, root, config):
    return num_examples


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def clevr(info, root, config):
    data_folder = root / "CLEVR_v1.0"
@@ -1193,7 +1172,7 @@ class OxfordIIITPetMockData:
        return num_samples_map


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def oxford_iiit_pet(info, root, config):
    num_samples_map = OxfordIIITPetMockData.generate(root)
    return {config_: num_samples_map[config_.split] for config_ in info._configs}
@@ -1360,13 +1339,13 @@ class CUB2002010MockData(_CUB200MockData):
        return num_samples_map


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def cub200(info, root, config):
    num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
    return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}


-@DATASET_MOCKS.set_from_named_callable
+@register_mock
def svhn(info, root, config):
    import scipy.io as sio
@@ -110,7 +110,7 @@ class TestCommon:
            )
        },
    )
-    @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
+    @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
    def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
        def scan(graph):
            for node, sub_graph in graph.items():
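Dropping the `ids=` argument falls back to pytest's automatically generated parameter IDs. A runnable sketch of the difference, with stand-in classes for the datapipe types:

import pytest


class Shuffler:  # stand-ins for the torchdata datapipe types
    pass


class ShardingFilter:
    pass


# With ids=..., the generated test IDs read e.g. test_ids[Shuffler];
# without it, pytest derives IDs from the parameter position,
# e.g. test_default[annotation_dp_type0].
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
def test_ids(annotation_dp_type):
    assert annotation_dp_type in (Shuffler, ShardingFilter)


@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
def test_default(annotation_dp_type):
    assert annotation_dp_type in (Shuffler, ShardingFilter)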
@@ -120,10 +120,7 @@ class TestCommon:
        with dataset_mock.prepare(config):
            dataset = datasets.load(dataset_mock.name, **config)

-        for dp in scan(traverse(dataset)):
-            if type(dp) is annotation_dp_type:
-                break
-        else:
+        if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
            raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
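The `for`/`else` loop and the `any()` expression are equivalent here, since `for`/`else` only reaches its `else` block when the loop finishes without `break`. A self-contained sketch of that equivalence:

def contains_type(items, wanted):
    # Old shape: the else block runs only if no break occurred.
    for item in items:
        if type(item) is wanted:
            break
    else:
        return False
    return True


def contains_type_any(items, wanted):
    # New shape: any() expresses the same search in one line.
    return any(type(item) is wanted for item in items)


values = [1, "a", 2.0]
assert contains_type(values, str) and contains_type_any(values, str)
assert not contains_type(values, bytes) and not contains_type_any(values, bytes)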