Unverified Commit 067dc302 authored by Nicolas Hug, committed by GitHub

Use fixture for dataset root in tests and remove caching (#5271)



* use fixture for dataset root in tests

* fix home dir generation

* remove caching
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 59c723cb
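The gist of the change: the module-level `TEST_HOME = pathlib.Path(tempfile.mkdtemp())`, the `unittest.mock.patch` context manager inside `DatasetMock.prepare`, and the per-config `_cache` are all dropped in favor of a `test_home` fixture that gives every test a fresh dataset root. A condensed sketch of the new pattern (assuming pytest-mock provides the `mocker` fixture; `dataset_mock` and `config` come from `parametrize_dataset_mocks`, and the test body below is a simplified stand-in for the real tests in this diff):

```python
import pytest
from torchvision.prototype import datasets


@pytest.fixture
def test_home(mocker, tmp_path):
    # Point the prototype datasets home at a per-test temporary directory.
    mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
    yield tmp_path


def test_smoke(test_home, dataset_mock, config):
    # prepare() writes the fake archives under test_home / dataset_mock.name and
    # returns the mock info (at least {"num_samples": ...}) for this config.
    mock_info = dataset_mock.prepare(test_home, config)

    dataset = datasets.load(dataset_mock.name, **config)
    assert sum(1 for _ in dataset) == mock_info["num_samples"]
```

Because `tmp_path` is unique per test, there is nothing left to cache: each test regenerates exactly the mock data it needs.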
 import collections.abc
-import contextlib
 import csv
 import functools
 import gzip
@@ -9,8 +8,6 @@ import lzma
 import pathlib
 import pickle
 import random
-import tempfile
-import unittest.mock
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter
@@ -21,15 +18,12 @@ import torch
 from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
 from torchvision.prototype import datasets
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.utils._internal import sequence_to_str

 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())

-TEST_HOME = pathlib.Path(tempfile.mkdtemp())

 __all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]
@@ -40,76 +34,48 @@ class DatasetMock:
         self.info = self.dataset.info
         self.name = self.info.name

-        self.root = TEST_HOME / self.dataset.name
         self.mock_data_fn = mock_data_fn
         self.configs = self.info._configs
-        self._cache = {}

-    def _parse_mock_data(self, config, mock_infos):
-        if mock_infos is None:
+    def _parse_mock_info(self, mock_info):
+        if mock_info is None:
            raise pytest.UsageError(
                f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
                f"integer indicating the number of samples for the current `config`."
            )
-        key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
-        if datasets.utils.DatasetConfig not in key_types:
-            mock_infos = {config: mock_infos}
-        elif len(key_types) > 1:
-            raise pytest.UsageError(
-                f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
-                f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
-            )
-        for config_, mock_info in mock_infos.items():
-            if config_ in self._cache:
-                raise pytest.UsageError(
-                    f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
-                    f"already exists in the cache."
-                )
-            if isinstance(mock_info, int):
-                mock_infos[config_] = dict(num_samples=mock_info)
+        elif isinstance(mock_info, int):
+            mock_info = dict(num_samples=mock_info)
        elif not isinstance(mock_info, dict):
            raise pytest.UsageError(
-                f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
-                f"{config_}. The returned object should be a dictionary containing at least the number of "
-                f"samples for the key `'num_samples'`. If no additional information is required for specific "
-                f"tests, the number of samples can also be returned as an integer."
+                f"The mock data function for dataset '{self.name}' returned a {type(mock_info)}. The returned object "
+                f"should be a dictionary containing at least the number of samples for the key `'num_samples'`. If no "
+                f"additional information is required for specific tests, the number of samples can also be returned as "
+                f"an integer."
            )
        elif "num_samples" not in mock_info:
            raise pytest.UsageError(
-                f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
-                f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
+                f"The dictionary returned by the mock data function for dataset '{self.name}' has to contain a "
+                f"`'num_samples'` entry indicating the number of samples."
            )

-        return mock_infos
+        return mock_info

-    def _prepare_resources(self, config):
-        if config in self._cache:
-            return self._cache[config]
+    def prepare(self, home, config):
+        root = home / self.name
+        root.mkdir(exist_ok=True)
-        self.root.mkdir(exist_ok=True)

-        mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
+        mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))

-        available_file_names = {path.name for path in self.root.glob("*")}
-        for config_, mock_info in mock_infos.items():
-            required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
+        available_file_names = {path.name for path in root.glob("*")}
+        required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
        missing_file_names = required_file_names - available_file_names
        if missing_file_names:
            raise pytest.UsageError(
                f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
-                f"for {config_}, but they were not created by the mock data function."
+                f"for {config}, but they were not created by the mock data function."
            )

-            self._cache[config_] = mock_info
-        return self._cache[config]

-    @contextlib.contextmanager
-    def prepare(self, config):
-        mock_info = self._prepare_resources(config)
-        with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
-            yield mock_info
+        return mock_info

 def config_id(name, config):
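For mock authors, the contract `_parse_mock_info` enforces is unchanged in spirit but now applies to a single config at a time: the mock data function must create the files that `resources(config)` expects under `root` and report the sample count, either as a plain integer or as a dict with at least a `'num_samples'` key. A hypothetical example (the `fakedata` name is made up for illustration; `register_mock` and the `info, root, config` signature are the ones used throughout this file):

```python
@register_mock
def fakedata(info, root, config):
    num_samples = 3
    # ... write the fake archives that `resources(config)` requires into `root` ...

    return num_samples                      # a bare int is normalized to dict(num_samples=3)
    # return dict(num_samples=num_samples)  # the dict form may carry extra keys needed by
    #                                       # dataset-specific tests
```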
@@ -254,32 +220,30 @@ DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist"

 @register_mock
-def emnist(info, root, _):
+def emnist(info, root, config):
     # The image sets that merge some lower case letters in their respective upper case variant, still use dense
     # labels in the data files. Thus, num_categories != len(categories) there.
     num_categories = defaultdict(
         lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
     )
-    mock_infos = {}
+    num_samples_map = {}
     file_names = set()
-    for config in info._configs:
-        prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}"
+    for config_ in info._configs:
+        prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
         file_names.update({images_file, labels_file})
-        mock_infos[config] = dict(
-            num_samples=MNISTMockData.generate(
+        num_samples_map[config_] = MNISTMockData.generate(
            root,
-                num_categories=num_categories[config.image_set],
+            num_categories=num_categories[config_.image_set],
            images_file=images_file,
            labels_file=labels_file,
        )
-        )

     make_zip(root, "emnist-gzip.zip", *file_names)

-    return mock_infos
+    return num_samples_map[config]


 @register_mock
@@ -290,25 +254,23 @@ def qmnist(info, root, config):
         prefix = "qmnist-train"
         suffix = ".gz"
         compressor = gzip.open
-        mock_infos = num_samples
     elif config.split.startswith("test"):
         # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
         # more than 10000 images for the dataset to not be empty.
         num_samples_gen = 10001
+        num_samples = {
+            "test": num_samples_gen,
+            "test10k": min(num_samples_gen, 10_000),
+            "test50k": num_samples_gen - 10_000,
+        }[config.split]
         prefix = "qmnist-test"
         suffix = ".gz"
         compressor = gzip.open
-        mock_infos = {
-            info.make_config(split="test"): num_samples_gen,
-            info.make_config(split="test10k"): min(num_samples_gen, 10_000),
-            info.make_config(split="test50k"): num_samples_gen - 10_000,
-        }
     else:  # config.split == "nist"
         num_samples = num_samples_gen = num_categories + 3
         prefix = "xnist"
         suffix = ".xz"
         compressor = lzma.open
-        mock_infos = num_samples

     MNISTMockData.generate(
         root,
@@ -320,7 +282,7 @@ def qmnist(info, root, config):
         label_dtype=torch.int32,
         compressor=compressor,
     )
-    return mock_infos
+    return num_samples


 class CIFARMockData:
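To make the QMNIST comment above concrete: with `num_samples_gen = 10001` the mock generates just enough test images for every split to be non-empty, and the per-split counts it reports work out to:

```python
num_samples_gen = 10001
num_samples = {
    "test": num_samples_gen,                  # 10001: all generated images
    "test10k": min(num_samples_gen, 10_000),  # 10000: the first 10k images
    "test50k": num_samples_gen - 10_000,      # 1: everything from index 10000 on
}
```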
@@ -624,12 +586,7 @@ class CocoMockData:

 @register_mock
 def coco(info, root, config):
-    return dict(
-        zip(
-            [config_ for config_ in info._configs if config_.year == config.year],
-            itertools.repeat(CocoMockData.generate(root, year=config.year, num_samples=5)),
-        )
-    )
+    return CocoMockData.generate(root, year=config.year, num_samples=5)


 class SBDMockData:
@@ -702,9 +659,8 @@ class SBDMockData:

 @register_mock
-def sbd(info, root, _):
-    num_samples_map = SBDMockData.generate(root)
-    return {config: num_samples_map[config.split] for config in info._configs}
+def sbd(info, root, config):
+    return SBDMockData.generate(root)[config.split]


 @register_mock
@@ -821,12 +777,7 @@ class VOCMockData:

 @register_mock
 def voc(info, root, config):
     trainval = config.split != "test"
-    num_samples_map = VOCMockData.generate(root, year=config.year, trainval=trainval)
-    return {
-        config_: num_samples_map[config_.split]
-        for config_ in info._configs
-        if config_.year == config.year and ((config_.split == "test") ^ trainval)
-    }
+    return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]


 class CelebAMockData:
@@ -918,13 +869,12 @@ class CelebAMockData:

 @register_mock
-def celeba(info, root, _):
-    num_samples_map = CelebAMockData.generate(root)
-    return {config: num_samples_map[config.split] for config in info._configs}
+def celeba(info, root, config):
+    return CelebAMockData.generate(root)[config.split]


 @register_mock
-def dtd(info, root, _):
+def dtd(info, root, config):
     data_folder = root / "dtd"

     num_images_per_class = 3
@@ -968,7 +918,7 @@ def dtd(info, root, _):
     make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz")

-    return num_samples_map
+    return num_samples_map[config]


 @register_mock
@@ -1108,7 +1058,7 @@ def clevr(info, root, config):
     make_zip(root, f"{data_folder.name}.zip", data_folder)

-    return {config_: num_samples_map[config_.split] for config_ in info._configs}
+    return num_samples_map[config.split]


 class OxfordIIITPetMockData:
@@ -1174,8 +1124,7 @@ class OxfordIIITPetMockData:

 @register_mock
 def oxford_iiit_pet(info, root, config):
-    num_samples_map = OxfordIIITPetMockData.generate(root)
-    return {config_: num_samples_map[config_.split] for config_ in info._configs}
+    return OxfordIIITPetMockData.generate(root)[config.split]


 class _CUB200MockData:
@@ -1342,7 +1291,7 @@ class CUB2002010MockData(_CUB200MockData):

 @register_mock
 def cub200(info, root, config):
     num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
-    return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
+    return num_samples_map[config.split]
@register_mock
......
@@ -11,6 +11,12 @@ from torchvision.prototype import transforms, datasets
 from torchvision.prototype.utils._internal import sequence_to_str


+@pytest.fixture
+def test_home(mocker, tmp_path):
+    mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
+    yield tmp_path
+
+
 def test_coverage():
     untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
     if untested_datasets:
@@ -23,16 +29,18 @@ def test_coverage():

 class TestCommon:
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_smoke(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
+    def test_smoke(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         if not isinstance(dataset, IterDataPipe):
             raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_sample(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
+    def test_sample(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         try:
@@ -47,8 +55,9 @@ class TestCommon:
             raise AssertionError("Sample dictionary is empty.")

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_num_samples(self, dataset_mock, config):
-        with dataset_mock.prepare(config) as mock_info:
+    def test_num_samples(self, test_home, dataset_mock, config):
+        mock_info = dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         num_samples = 0
@@ -58,8 +67,9 @@
         assert num_samples == mock_info["num_samples"]

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_decoding(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
+    def test_decoding(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
@@ -70,8 +80,9 @@
             )

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
+    def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
@@ -82,8 +93,9 @@
             )

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_transformable(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
+    def test_transformable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         next(iter(dataset.map(transforms.Identity())))
@@ -96,8 +108,9 @@
            )
        },
    )
-    def test_traversable(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
+    def test_traversable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         traverse(dataset)
@@ -111,13 +124,14 @@
        },
    )
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
-    def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
+    def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
         def scan(graph):
             for node, sub_graph in graph.items():
                 yield node
                 yield from scan(sub_graph)

-        with dataset_mock.prepare(config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
@@ -126,8 +140,9 @@

 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
-    def test_extra_label(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
+    def test_extra_label(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         sample = next(iter(dataset))
@@ -145,13 +160,14 @@

 @parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
 class TestGTSRB:
-    def test_label_matches_path(self, dataset_mock, config):
+    def test_label_matches_path(self, test_home, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
         # This test makes sure that they're both the same
         if config.split != "train":
             return

-        with dataset_mock.prepare(config):
+        dataset_mock.prepare(test_home, config)

         dataset = datasets.load(dataset_mock.name, **config)

         for sample in dataset:
......