"src/vscode:/vscode.git/clone" did not exist on "750bd7920622b3fe538d20035d3f03855c5d6621"
Unverified commit 067dc302, authored by Nicolas Hug, committed by GitHub

Use fixture for dataset root in tests and remove caching (#5271)



* use fixture for dataset root in tests

* fix home dir generation

* remove caching
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 59c723cb
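In short: the module-level `TEST_HOME` temp directory and the per-config cache in `DatasetMock` are removed. Each test instead requests a `test_home` fixture that patches the prototype datasets' home directory to pytest's `tmp_path`, and `DatasetMock.prepare(home, config)` generates the mock files for exactly the requested config. A minimal sketch of the pattern (the fixture and the `prepare(home, config)` call are taken from the diff below; the standalone test body is illustrative only, since the real tests are parametrized methods on `TestCommon`):

import pytest
from torchvision.prototype import datasets

@pytest.fixture
def test_home(mocker, tmp_path):
    # Redirect the prototype datasets' home directory to a per-test temporary
    # folder; `mocker` is provided by the pytest-mock plugin.
    mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
    yield tmp_path

def test_smoke(test_home, dataset_mock, config):
    # Generate the mock files for this one config into the fresh tmp_path ...
    dataset_mock.prepare(test_home, config)
    # ... then load the dataset against it; no state is shared between tests.
    dataset = datasets.load(dataset_mock.name, **config)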
 import collections.abc
-import contextlib
 import csv
 import functools
 import gzip
@@ -9,8 +8,6 @@ import lzma
 import pathlib
 import pickle
 import random
-import tempfile
-import unittest.mock
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter
@@ -21,15 +18,12 @@ import torch
 from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
-from torchvision.prototype import datasets
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.utils._internal import sequence_to_str
 
 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
 
-TEST_HOME = pathlib.Path(tempfile.mkdtemp())
 
 __all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]
@@ -40,76 +34,48 @@ class DatasetMock:
         self.info = self.dataset.info
         self.name = self.info.name
 
-        self.root = TEST_HOME / self.dataset.name
         self.mock_data_fn = mock_data_fn
         self.configs = self.info._configs
-        self._cache = {}
 
-    def _parse_mock_data(self, config, mock_infos):
-        if mock_infos is None:
+    def _parse_mock_info(self, mock_info):
+        if mock_info is None:
             raise pytest.UsageError(
                 f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
                 f"integer indicating the number of samples for the current `config`."
             )
-
-        key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
-        if datasets.utils.DatasetConfig not in key_types:
-            mock_infos = {config: mock_infos}
-        elif len(key_types) > 1:
-            raise pytest.UsageError(
-                f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
-                f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
-            )
-
-        for config_, mock_info in mock_infos.items():
-            if config_ in self._cache:
-                raise pytest.UsageError(
-                    f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
-                    f"already exists in the cache."
-                )
-
-            if isinstance(mock_info, int):
-                mock_infos[config_] = dict(num_samples=mock_info)
-            elif not isinstance(mock_info, dict):
-                raise pytest.UsageError(
-                    f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
-                    f"{config_}. The returned object should be a dictionary containing at least the number of "
-                    f"samples for the key `'num_samples'`. If no additional information is required for specific "
-                    f"tests, the number of samples can also be returned as an integer."
-                )
-            elif "num_samples" not in mock_info:
-                raise pytest.UsageError(
-                    f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
-                    f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
-                )
-
-        return mock_infos
-
-    def _prepare_resources(self, config):
-        if config in self._cache:
-            return self._cache[config]
-
-        self.root.mkdir(exist_ok=True)
-        mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
-
-        available_file_names = {path.name for path in self.root.glob("*")}
-        for config_, mock_info in mock_infos.items():
-            required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
-            missing_file_names = required_file_names - available_file_names
-            if missing_file_names:
-                raise pytest.UsageError(
-                    f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
-                    f"for {config_}, but they were not created by the mock data function."
-                )
-
-            self._cache[config_] = mock_info
-
-        return self._cache[config]
-
-    @contextlib.contextmanager
-    def prepare(self, config):
-        mock_info = self._prepare_resources(config)
-        with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
-            yield mock_info
+        elif isinstance(mock_info, int):
+            mock_info = dict(num_samples=mock_info)
+        elif not isinstance(mock_info, dict):
+            raise pytest.UsageError(
+                f"The mock data function for dataset '{self.name}' returned a {type(mock_info)}. The returned object "
+                f"should be a dictionary containing at least the number of samples for the key `'num_samples'`. If no "
+                f"additional information is required for specific tests, the number of samples can also be returned as "
+                f"an integer."
+            )
+        elif "num_samples" not in mock_info:
+            raise pytest.UsageError(
+                f"The dictionary returned by the mock data function for dataset '{self.name}' has to contain a "
+                f"`'num_samples'` entry indicating the number of samples."
+            )
+
+        return mock_info
+
+    def prepare(self, home, config):
+        root = home / self.name
+        root.mkdir(exist_ok=True)
+
+        mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))
+
+        available_file_names = {path.name for path in root.glob("*")}
+        required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
+        missing_file_names = required_file_names - available_file_names
+        if missing_file_names:
+            raise pytest.UsageError(
+                f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
+                f"for {config}, but they were not created by the mock data function."
+            )
+
+        return mock_info
 def config_id(name, config):
@@ -254,32 +220,30 @@ DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist"
 @register_mock
-def emnist(info, root, _):
+def emnist(info, root, config):
     # The image sets that merge some lower case letters in their respective upper case variant, still use dense
     # labels in the data files. Thus, num_categories != len(categories) there.
     num_categories = defaultdict(
         lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
     )
 
-    mock_infos = {}
+    num_samples_map = {}
     file_names = set()
-    for config in info._configs:
-        prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}"
+    for config_ in info._configs:
+        prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
         file_names.update({images_file, labels_file})
-        mock_infos[config] = dict(
-            num_samples=MNISTMockData.generate(
-                root,
-                num_categories=num_categories[config.image_set],
-                images_file=images_file,
-                labels_file=labels_file,
-            )
-        )
+        num_samples_map[config_] = MNISTMockData.generate(
+            root,
+            num_categories=num_categories[config_.image_set],
+            images_file=images_file,
+            labels_file=labels_file,
+        )
 
     make_zip(root, "emnist-gzip.zip", *file_names)
 
-    return mock_infos
+    return num_samples_map[config]
 @register_mock
@@ -290,25 +254,23 @@ def qmnist(info, root, config):
         prefix = "qmnist-train"
         suffix = ".gz"
         compressor = gzip.open
-        mock_infos = num_samples
     elif config.split.startswith("test"):
         # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
         # more than 10000 images for the dataset to not be empty.
         num_samples_gen = 10001
+        num_samples = {
+            "test": num_samples_gen,
+            "test10k": min(num_samples_gen, 10_000),
+            "test50k": num_samples_gen - 10_000,
+        }[config.split]
         prefix = "qmnist-test"
        suffix = ".gz"
         compressor = gzip.open
-        mock_infos = {
-            info.make_config(split="test"): num_samples_gen,
-            info.make_config(split="test10k"): min(num_samples_gen, 10_000),
-            info.make_config(split="test50k"): num_samples_gen - 10_000,
-        }
     else:  # config.split == "nist"
         num_samples = num_samples_gen = num_categories + 3
         prefix = "xnist"
         suffix = ".xz"
         compressor = lzma.open
-        mock_infos = num_samples
 
     MNISTMockData.generate(
         root,
@@ -320,7 +282,7 @@ def qmnist(info, root, config):
         label_dtype=torch.int32,
         compressor=compressor,
     )
-    return mock_infos
+    return num_samples
 class CIFARMockData:
@@ -624,12 +586,7 @@ class CocoMockData:
 @register_mock
 def coco(info, root, config):
-    return dict(
-        zip(
-            [config_ for config_ in info._configs if config_.year == config.year],
-            itertools.repeat(CocoMockData.generate(root, year=config.year, num_samples=5)),
-        )
-    )
+    return CocoMockData.generate(root, year=config.year, num_samples=5)
 class SBDMockData:
@@ -702,9 +659,8 @@ class SBDMockData:
 @register_mock
-def sbd(info, root, _):
-    num_samples_map = SBDMockData.generate(root)
-    return {config: num_samples_map[config.split] for config in info._configs}
+def sbd(info, root, config):
+    return SBDMockData.generate(root)[config.split]
 @register_mock
@@ -821,12 +777,7 @@ class VOCMockData:
 @register_mock
 def voc(info, root, config):
     trainval = config.split != "test"
-    num_samples_map = VOCMockData.generate(root, year=config.year, trainval=trainval)
-    return {
-        config_: num_samples_map[config_.split]
-        for config_ in info._configs
-        if config_.year == config.year and ((config_.split == "test") ^ trainval)
-    }
+    return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]
 class CelebAMockData:
@@ -918,13 +869,12 @@ class CelebAMockData:
 @register_mock
-def celeba(info, root, _):
-    num_samples_map = CelebAMockData.generate(root)
-    return {config: num_samples_map[config.split] for config in info._configs}
+def celeba(info, root, config):
+    return CelebAMockData.generate(root)[config.split]
 @register_mock
-def dtd(info, root, _):
+def dtd(info, root, config):
     data_folder = root / "dtd"
     num_images_per_class = 3
@@ -968,7 +918,7 @@ def dtd(info, root, _):
     make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz")
 
-    return num_samples_map
+    return num_samples_map[config]
 @register_mock
@@ -1108,7 +1058,7 @@ def clevr(info, root, config):
     make_zip(root, f"{data_folder.name}.zip", data_folder)
 
-    return {config_: num_samples_map[config_.split] for config_ in info._configs}
+    return num_samples_map[config.split]
 class OxfordIIITPetMockData:
@@ -1174,8 +1124,7 @@ class OxfordIIITPetMockData:
 @register_mock
 def oxford_iiit_pet(info, root, config):
-    num_samples_map = OxfordIIITPetMockData.generate(root)
-    return {config_: num_samples_map[config_.split] for config_ in info._configs}
+    return OxfordIIITPetMockData.generate(root)[config.split]
 class _CUB200MockData:
@@ -1342,7 +1291,7 @@ class CUB2002010MockData(_CUB200MockData):
 @register_mock
 def cub200(info, root, config):
     num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
-    return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
+    return num_samples_map[config.split]
 
 @register_mock
...
@@ -11,6 +11,12 @@ from torchvision.prototype import transforms, datasets
 from torchvision.prototype.utils._internal import sequence_to_str
 
 
+@pytest.fixture
+def test_home(mocker, tmp_path):
+    mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
+    yield tmp_path
+
+
 def test_coverage():
     untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
     if untested_datasets:
@@ -23,16 +29,18 @@ def test_coverage():
 class TestCommon:
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_smoke(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_smoke(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         if not isinstance(dataset, IterDataPipe):
             raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_sample(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_sample(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         try:
@@ -47,8 +55,9 @@ class TestCommon:
             raise AssertionError("Sample dictionary is empty.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_num_samples(self, dataset_mock, config):
-        with dataset_mock.prepare(config) as mock_info:
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_num_samples(self, test_home, dataset_mock, config):
+        mock_info = dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         num_samples = 0
@@ -58,8 +67,9 @@ class TestCommon:
         assert num_samples == mock_info["num_samples"]
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_decoding(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_decoding(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
@@ -70,8 +80,9 @@ class TestCommon:
         )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
@@ -82,8 +93,9 @@ class TestCommon:
         )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_transformable(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_transformable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         next(iter(dataset.map(transforms.Identity())))
@@ -96,8 +108,9 @@ class TestCommon:
             )
         },
     )
-    def test_traversable(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_traversable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         traverse(dataset)
@@ -111,13 +124,14 @@ class TestCommon:
         },
     )
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
-    def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
+    def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
         def scan(graph):
             for node, sub_graph in graph.items():
                 yield node
                 yield from scan(sub_graph)
 
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
@@ -126,8 +140,9 @@ class TestCommon:
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
-    def test_extra_label(self, dataset_mock, config):
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+    def test_extra_label(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         sample = next(iter(dataset))
@@ -145,13 +160,14 @@ class TestQMNIST:
 @parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
 class TestGTSRB:
-    def test_label_matches_path(self, dataset_mock, config):
+    def test_label_matches_path(self, test_home, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
         # This test makes sure that they're both the same
         if config.split != "train":
             return
 
-        with dataset_mock.prepare(config):
-            dataset = datasets.load(dataset_mock.name, **config)
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
 
         for sample in dataset:
...
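For mock authors, the practical consequence is that a mock data function now only has to cover the single `config` it is called with and report its sample count directly, instead of returning a mapping keyed by `DatasetConfig`. A hedged sketch of the new contract follows: the `(info, root, config)` signature and the int-or-dict return value come from `_parse_mock_info` above, while the dataset name and file layout are purely illustrative and the helper calls are only sketched from the imports in the diff rather than checked against their exact signatures.

@register_mock
def my_dataset(info, root, config):  # hypothetical dataset, for illustration only
    # Write the files that `dataset.resources(config)` expects into `root` ...
    num_samples = 3
    create_image_folder(root, name="images", file_name_fn=lambda idx: f"{idx}.png", num_examples=num_samples)
    make_zip(root, "my_dataset.zip", root / "images")
    # ... and report the sample count: either a plain int, or a dict that must
    # contain at least "num_samples" when a test needs extra information.
    return dict(num_samples=num_samples)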