Unverified commit ed84880b, authored by Nicolas Hug and committed by GitHub
Browse files

Minor simplifications to prototype dataset testings (#5268)

parent b21e0bfb
...@@ -12,7 +12,7 @@ import random ...@@ -12,7 +12,7 @@ import random
import tempfile import tempfile
import unittest.mock import unittest.mock
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from collections import defaultdict, Counter, UserDict from collections import defaultdict, Counter
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -34,35 +34,17 @@ TEST_HOME = pathlib.Path(tempfile.mkdtemp()) ...@@ -34,35 +34,17 @@ TEST_HOME = pathlib.Path(tempfile.mkdtemp())
__all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"] __all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]
class ResourceMock(datasets.utils.OnlineResource):
def __init__(self, *, dataset_name, dataset_config, **kwargs):
super().__init__(**kwargs)
self.dataset_name = dataset_name
self.dataset_config = dataset_config
def _download(self, _):
raise pytest.UsageError(
f"Dataset '{self.dataset_name}' requires the file '{self.file_name}' for {self.dataset_config}, "
f"but this file does not exist."
)
class DatasetMock: class DatasetMock:
def __init__(self, name, mock_data_fn, *, configs=None): def __init__(self, name, mock_data_fn):
self.dataset = find(name) self.dataset = find(name)
self.info = self.dataset.info
self.name = self.info.name
self.root = TEST_HOME / self.dataset.name self.root = TEST_HOME / self.dataset.name
self.mock_data_fn = mock_data_fn self.mock_data_fn = mock_data_fn
self.configs = configs or self.info._configs self.configs = self.info._configs
self._cache = {} self._cache = {}
@property
def info(self):
return self.dataset.info
@property
def name(self):
return self.info.name
def _parse_mock_data(self, config, mock_infos): def _parse_mock_data(self, config, mock_infos):
if mock_infos is None: if mock_infos is None:
raise pytest.UsageError( raise pytest.UsageError(
...@@ -79,7 +61,7 @@ class DatasetMock: ...@@ -79,7 +61,7 @@ class DatasetMock:
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type." f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
) )
for config_, mock_info in list(mock_infos.items()): for config_, mock_info in mock_infos.items():
if config_ in self._cache: if config_ in self._cache:
raise pytest.UsageError( raise pytest.UsageError(
f"The mock info for config {config_} of dataset {self.name} generated for config {config} " f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
...@@ -103,7 +85,7 @@ class DatasetMock: ...@@ -103,7 +85,7 @@ class DatasetMock:
return mock_infos return mock_infos
def _prepare_resources(self, config): def _prepare_resources(self, config):
with contextlib.suppress(KeyError): if config in self._cache:
return self._cache[config] return self._cache[config]
self.root.mkdir(exist_ok=True) self.root.mkdir(exist_ok=True)
...@@ -146,8 +128,6 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): ...@@ -146,8 +128,6 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
for mock in dataset_mocks: for mock in dataset_mocks:
if isinstance(mock, DatasetMock): if isinstance(mock, DatasetMock):
mocks[mock.name] = mock mocks[mock.name] = mock
elif isinstance(mock, collections.abc.Sequence):
mocks.update({mock_.name: mock_ for mock_ in mock})
elif isinstance(mock, collections.abc.Mapping): elif isinstance(mock, collections.abc.Mapping):
mocks.update(mock) mocks.update(mock)
else: else:
...@@ -173,14 +153,13 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): ...@@ -173,14 +153,13 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
) )
class DatasetMocks(UserDict): DATASET_MOCKS = {}
def set_from_named_callable(self, fn):
name = fn.__name__.replace("_", "-")
self.data[name] = DatasetMock(name, fn)
return fn
DATASET_MOCKS = DatasetMocks() def register_mock(fn):
name = fn.__name__.replace("_", "-")
DATASET_MOCKS[name] = DatasetMock(name, fn)
return fn
class MNISTMockData: class MNISTMockData:
...@@ -258,7 +237,7 @@ class MNISTMockData: ...@@ -258,7 +237,7 @@ class MNISTMockData:
return num_samples return num_samples
@DATASET_MOCKS.set_from_named_callable @register_mock
def mnist(info, root, config): def mnist(info, root, config):
train = config.split == "train" train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
...@@ -274,7 +253,7 @@ def mnist(info, root, config): ...@@ -274,7 +253,7 @@ def mnist(info, root, config):
DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
@DATASET_MOCKS.set_from_named_callable @register_mock
def emnist(info, root, _): def emnist(info, root, _):
# The image sets that merge some lower case letters in their respective upper case variant, still use dense # The image sets that merge some lower case letters in their respective upper case variant, still use dense
# labels in the data files. Thus, num_categories != len(categories) there. # labels in the data files. Thus, num_categories != len(categories) there.
...@@ -303,7 +282,7 @@ def emnist(info, root, _): ...@@ -303,7 +282,7 @@ def emnist(info, root, _):
return mock_infos return mock_infos
@DATASET_MOCKS.set_from_named_callable @register_mock
def qmnist(info, root, config): def qmnist(info, root, config):
num_categories = len(info.categories) num_categories = len(info.categories)
if config.split == "train": if config.split == "train":
...@@ -382,7 +361,7 @@ class CIFARMockData: ...@@ -382,7 +361,7 @@ class CIFARMockData:
make_tar(root, name, folder, compression="gz") make_tar(root, name, folder, compression="gz")
@DATASET_MOCKS.set_from_named_callable @register_mock
def cifar10(info, root, config): def cifar10(info, root, config):
train_files = [f"data_batch_{idx}" for idx in range(1, 6)] train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
test_files = ["test_batch"] test_files = ["test_batch"]
...@@ -400,7 +379,7 @@ def cifar10(info, root, config): ...@@ -400,7 +379,7 @@ def cifar10(info, root, config):
return len(train_files if config.split == "train" else test_files) return len(train_files if config.split == "train" else test_files)
@DATASET_MOCKS.set_from_named_callable @register_mock
def cifar100(info, root, config): def cifar100(info, root, config):
train_files = ["train"] train_files = ["train"]
test_files = ["test"] test_files = ["test"]
...@@ -418,7 +397,7 @@ def cifar100(info, root, config): ...@@ -418,7 +397,7 @@ def cifar100(info, root, config):
return len(train_files if config.split == "train" else test_files) return len(train_files if config.split == "train" else test_files)
@DATASET_MOCKS.set_from_named_callable @register_mock
def caltech101(info, root, config): def caltech101(info, root, config):
def create_ann_file(root, name): def create_ann_file(root, name):
import scipy.io import scipy.io
...@@ -468,7 +447,7 @@ def caltech101(info, root, config): ...@@ -468,7 +447,7 @@ def caltech101(info, root, config):
return num_images_per_category * len(info.categories) return num_images_per_category * len(info.categories)
@DATASET_MOCKS.set_from_named_callable @register_mock
def caltech256(info, root, config): def caltech256(info, root, config):
dir = root / "256_ObjectCategories" dir = root / "256_ObjectCategories"
num_images_per_category = 2 num_images_per_category = 2
...@@ -488,7 +467,7 @@ def caltech256(info, root, config): ...@@ -488,7 +467,7 @@ def caltech256(info, root, config):
return num_images_per_category * len(info.categories) return num_images_per_category * len(info.categories)
@DATASET_MOCKS.set_from_named_callable @register_mock
def imagenet(info, root, config): def imagenet(info, root, config):
wnids = tuple(info.extra.wnid_to_category.keys()) wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train": if config.split == "train":
...@@ -643,7 +622,7 @@ class CocoMockData: ...@@ -643,7 +622,7 @@ class CocoMockData:
return num_samples return num_samples
@DATASET_MOCKS.set_from_named_callable @register_mock
def coco(info, root, config): def coco(info, root, config):
return dict( return dict(
zip( zip(
...@@ -722,13 +701,13 @@ class SBDMockData: ...@@ -722,13 +701,13 @@ class SBDMockData:
return num_samples_map return num_samples_map
@DATASET_MOCKS.set_from_named_callable @register_mock
def sbd(info, root, _): def sbd(info, root, _):
num_samples_map = SBDMockData.generate(root) num_samples_map = SBDMockData.generate(root)
return {config: num_samples_map[config.split] for config in info._configs} return {config: num_samples_map[config.split] for config in info._configs}
@DATASET_MOCKS.set_from_named_callable @register_mock
def semeion(info, root, config): def semeion(info, root, config):
num_samples = 3 num_samples = 3
...@@ -839,7 +818,7 @@ class VOCMockData: ...@@ -839,7 +818,7 @@ class VOCMockData:
return num_samples_map return num_samples_map
@DATASET_MOCKS.set_from_named_callable @register_mock
def voc(info, root, config): def voc(info, root, config):
trainval = config.split != "test" trainval = config.split != "test"
num_samples_map = VOCMockData.generate(root, year=config.year, trainval=trainval) num_samples_map = VOCMockData.generate(root, year=config.year, trainval=trainval)
...@@ -938,13 +917,13 @@ class CelebAMockData: ...@@ -938,13 +917,13 @@ class CelebAMockData:
return num_samples_map return num_samples_map
@DATASET_MOCKS.set_from_named_callable @register_mock
def celeba(info, root, _): def celeba(info, root, _):
num_samples_map = CelebAMockData.generate(root) num_samples_map = CelebAMockData.generate(root)
return {config: num_samples_map[config.split] for config in info._configs} return {config: num_samples_map[config.split] for config in info._configs}
@DATASET_MOCKS.set_from_named_callable @register_mock
def dtd(info, root, _): def dtd(info, root, _):
data_folder = root / "dtd" data_folder = root / "dtd"
...@@ -992,7 +971,7 @@ def dtd(info, root, _): ...@@ -992,7 +971,7 @@ def dtd(info, root, _):
return num_samples_map return num_samples_map
@DATASET_MOCKS.set_from_named_callable @register_mock
def fer2013(info, root, config): def fer2013(info, root, config):
num_samples = 5 if config.split == "train" else 3 num_samples = 5 if config.split == "train" else 3
...@@ -1017,7 +996,7 @@ def fer2013(info, root, config): ...@@ -1017,7 +996,7 @@ def fer2013(info, root, config):
return num_samples return num_samples
@DATASET_MOCKS.set_from_named_callable @register_mock
def gtsrb(info, root, config): def gtsrb(info, root, config):
num_examples_per_class = 5 if config.split == "train" else 3 num_examples_per_class = 5 if config.split == "train" else 3
classes = ("00000", "00042", "00012") classes = ("00000", "00042", "00012")
...@@ -1087,7 +1066,7 @@ def gtsrb(info, root, config): ...@@ -1087,7 +1066,7 @@ def gtsrb(info, root, config):
return num_examples return num_examples
@DATASET_MOCKS.set_from_named_callable @register_mock
def clevr(info, root, config): def clevr(info, root, config):
data_folder = root / "CLEVR_v1.0" data_folder = root / "CLEVR_v1.0"
...@@ -1193,7 +1172,7 @@ class OxfordIIITPetMockData: ...@@ -1193,7 +1172,7 @@ class OxfordIIITPetMockData:
return num_samples_map return num_samples_map
@DATASET_MOCKS.set_from_named_callable @register_mock
def oxford_iiit_pet(info, root, config): def oxford_iiit_pet(info, root, config):
num_samples_map = OxfordIIITPetMockData.generate(root) num_samples_map = OxfordIIITPetMockData.generate(root)
return {config_: num_samples_map[config_.split] for config_ in info._configs} return {config_: num_samples_map[config_.split] for config_ in info._configs}
...@@ -1360,13 +1339,13 @@ class CUB2002010MockData(_CUB200MockData): ...@@ -1360,13 +1339,13 @@ class CUB2002010MockData(_CUB200MockData):
return num_samples_map return num_samples_map
@DATASET_MOCKS.set_from_named_callable @register_mock
def cub200(info, root, config): def cub200(info, root, config):
num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root) num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year} return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
@DATASET_MOCKS.set_from_named_callable @register_mock
def svhn(info, root, config): def svhn(info, root, config):
import scipy.io as sio import scipy.io as sio
......
...@@ -110,7 +110,7 @@ class TestCommon: ...@@ -110,7 +110,7 @@ class TestCommon:
) )
}, },
) )
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
def test_has_annotations(self, dataset_mock, config, annotation_dp_type): def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
def scan(graph): def scan(graph):
for node, sub_graph in graph.items(): for node, sub_graph in graph.items():
...@@ -120,10 +120,7 @@ class TestCommon: ...@@ -120,10 +120,7 @@ class TestCommon:
with dataset_mock.prepare(config): with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
for dp in scan(traverse(dataset)): if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
if type(dp) is annotation_dp_type:
break
else:
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment