Unverified commit 4d08a673, authored by Philip Meier and committed by GitHub
Browse files

Use public API for loading in prototype datasets tests (#5212)

* refactor prototype dataset tests to use public API for loading

* add explanation

* use loop alternative
parent 6512146e
...@@ -10,6 +10,7 @@ import pathlib ...@@ -10,6 +10,7 @@ import pathlib
import pickle import pickle
import random import random
import tempfile import tempfile
import unittest.mock
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from collections import defaultdict, Counter, UserDict from collections import defaultdict, Counter, UserDict
...@@ -21,7 +22,8 @@ from datasets_utils import make_zip, make_tar, create_image_folder, create_image ...@@ -21,7 +22,8 @@ from datasets_utils import make_zip, make_tar, create_image_folder, create_image
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor from torch.testing import make_tensor as _make_tensor
from torchvision.prototype import datasets from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu") make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ()) make_scalar = functools.partial(make_tensor, ())
...@@ -49,7 +51,7 @@ class DatasetMock: ...@@ -49,7 +51,7 @@ class DatasetMock:
def __init__(self, name, mock_data_fn, *, configs=None): def __init__(self, name, mock_data_fn, *, configs=None):
self.dataset = find(name) self.dataset = find(name)
self.root = TEST_HOME / self.dataset.name self.root = TEST_HOME / self.dataset.name
self.mock_data_fn = self._parse_mock_data(mock_data_fn) self.mock_data_fn = mock_data_fn
self.configs = configs or self.info._configs self.configs = configs or self.info._configs
self._cache = {} self._cache = {}
...@@ -61,77 +63,71 @@ class DatasetMock: ...@@ -61,77 +63,71 @@ class DatasetMock:
def name(self): def name(self):
return self.info.name return self.info.name
def _parse_mock_data(self, mock_data_fn): def _parse_mock_data(self, config, mock_infos):
def wrapper(info, root, config): if mock_infos is None:
mock_infos = mock_data_fn(info, root, config) raise pytest.UsageError(
f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
f"integer indicating the number of samples for the current `config`."
)
key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
if datasets.utils.DatasetConfig not in key_types:
mock_infos = {config: mock_infos}
elif len(key_types) > 1:
raise pytest.UsageError(
f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
)
if mock_infos is None: for config_, mock_info in list(mock_infos.items()):
if config_ in self._cache:
raise pytest.UsageError( raise pytest.UsageError(
f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an " f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
f"integer indicating the number of samples for the current `config`." f"already exists in the cache."
) )
if isinstance(mock_info, int):
key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {} mock_infos[config_] = dict(num_samples=mock_info)
if datasets.utils.DatasetConfig not in key_types: elif not isinstance(mock_info, dict):
mock_infos = {config: mock_infos}
elif len(key_types) > 1:
raise pytest.UsageError( raise pytest.UsageError(
f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If " f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type." f"{config_}. The returned object should be a dictionary containing at least the number of "
f"samples for the key `'num_samples'`. If no additional information is required for specific "
f"tests, the number of samples can also be returned as an integer."
)
elif "num_samples" not in mock_info:
raise pytest.UsageError(
f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
) )
for config_, mock_info in list(mock_infos.items()): return mock_infos
if config_ in self._cache:
raise pytest.UsageError(
f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
f"already exists in the cache."
)
if isinstance(mock_info, int):
mock_infos[config_] = dict(num_samples=mock_info)
elif not isinstance(mock_info, dict):
raise pytest.UsageError(
f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
f"{config_}. The returned object should be a dictionary containing at least the number of "
f"samples for the key `'num_samples'`. If no additional information is required for specific "
f"tests, the number of samples can also be returned as an integer."
)
elif "num_samples" not in mock_info:
raise pytest.UsageError(
f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
)
return mock_infos
return wrapper
def _load_mock(self, config): def _prepare_resources(self, config):
with contextlib.suppress(KeyError): with contextlib.suppress(KeyError):
return self._cache[config] return self._cache[config]
self.root.mkdir(exist_ok=True) self.root.mkdir(exist_ok=True)
for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items(): mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
mock_resources = [
ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name) available_file_names = {path.name for path in self.root.glob("*")}
for resource in self.dataset.resources(config_) for config_, mock_info in mock_infos.items():
] required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
self._cache[config_] = (mock_resources, mock_info) missing_file_names = required_file_names - available_file_names
if missing_file_names:
raise pytest.UsageError(
f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
f"for {config_}, but they were not created by the mock data function."
)
self._cache[config_] = mock_info
return self._cache[config] return self._cache[config]
def load(self, config, *, decoder=DEFAULT_DECODER): @contextlib.contextmanager
try: def prepare(self, config):
self.info.check_dependencies() mock_info = self._prepare_resources(config)
except ModuleNotFoundError as error: with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
pytest.skip(str(error)) yield mock_info
mock_resources, mock_info = self._load_mock(config)
datapipe = self.dataset._make_datapipe(
[resource.load(self.root) for resource in mock_resources],
config=config,
decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
)
return datapipe, mock_info
def config_id(name, config): def config_id(name, config):
...@@ -1000,7 +996,7 @@ def dtd(info, root, _): ...@@ -1000,7 +996,7 @@ def dtd(info, root, _):
def fer2013(info, root, config): def fer2013(info, root, config):
num_samples = 5 if config.split == "train" else 3 num_samples = 5 if config.split == "train" else 3
path = root / f"{config.split}.txt" path = root / f"{config.split}.csv"
with open(path, "w", newline="") as file: with open(path, "w", newline="") as file:
field_names = ["emotion"] if config.split == "train" else [] field_names = ["emotion"] if config.split == "train" else []
field_names.append("pixels") field_names.append("pixels")
...@@ -1061,7 +1057,7 @@ def clevr(info, root, config): ...@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
file, file,
) )
make_zip(root, f"{data_folder.name}.zip") make_zip(root, f"{data_folder.name}.zip", data_folder)
return {config_: num_samples_map[config_.split] for config_ in info._configs} return {config_: num_samples_map[config_.split] for config_ in info._configs}
...@@ -1121,8 +1117,8 @@ class OxfordIIITPetMockData: ...@@ -1121,8 +1117,8 @@ class OxfordIIITPetMockData:
for path in segmentation_files: for path in segmentation_files:
path.with_name(f".{path.name}").touch() path.with_name(f".{path.name}").touch()
make_tar(root, "images.tar") make_tar(root, "images.tar.gz", compression="gz")
make_tar(root, anns_folder.with_suffix(".tar").name) make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz")
return num_samples_map return num_samples_map
...@@ -1211,7 +1207,7 @@ class CUB2002011MockData(_CUB200MockData): ...@@ -1211,7 +1207,7 @@ class CUB2002011MockData(_CUB200MockData):
size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()], size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
) )
make_tar(root, segmentations_folder.with_suffix(".tgz").name) make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz")
@classmethod @classmethod
def generate(cls, root): def generate(cls, root):
......
...@@ -868,9 +868,13 @@ def _split_files_or_dirs(root, *files_or_dirs): ...@@ -868,9 +868,13 @@ def _split_files_or_dirs(root, *files_or_dirs):
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True): def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
archive = pathlib.Path(root) / name archive = pathlib.Path(root) / name
if not files_or_dirs: if not files_or_dirs:
dir = archive.with_suffix("") # We need to invoke `Path.with_suffix("")` once per suffix, since a single call only strips the last suffix if
if dir.exists() and dir.is_dir(): # multiple suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
files_or_dirs = (dir,) file_or_dir = archive
for _ in range(len(archive.suffixes)):
file_or_dir = file_or_dir.with_suffix("")
if file_or_dir.exists():
files_or_dirs = (file_or_dir,)
else: else:
raise ValueError("No file or dir provided.") raise ValueError("No file or dir provided.")
......
...@@ -23,13 +23,16 @@ def test_coverage(): ...@@ -23,13 +23,16 @@ def test_coverage():
class TestCommon: class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, dataset_mock, config): def test_smoke(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
if not isinstance(dataset, IterDataPipe): if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.") raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, dataset_mock, config): def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
try: try:
sample = next(iter(dataset)) sample = next(iter(dataset))
...@@ -44,7 +47,8 @@ class TestCommon: ...@@ -44,7 +47,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, dataset_mock, config): def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config) with dataset_mock.prepare(config) as mock_info:
dataset = datasets.load(dataset_mock.name, **config)
num_samples = 0 num_samples = 0
for _ in dataset: for _ in dataset:
...@@ -54,7 +58,8 @@ class TestCommon: ...@@ -54,7 +58,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, dataset_mock, config): def test_decoding(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)} undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features: if undecoded_features:
...@@ -65,7 +70,8 @@ class TestCommon: ...@@ -65,7 +70,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config): def test_no_vanilla_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors: if vanilla_tensors:
...@@ -76,7 +82,8 @@ class TestCommon: ...@@ -76,7 +82,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config): def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
next(iter(dataset.map(transforms.Identity()))) next(iter(dataset.map(transforms.Identity())))
...@@ -89,7 +96,8 @@ class TestCommon: ...@@ -89,7 +96,8 @@ class TestCommon:
}, },
) )
def test_traversable(self, dataset_mock, config): def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
traverse(dataset) traverse(dataset)
...@@ -108,7 +116,8 @@ class TestCommon: ...@@ -108,7 +116,8 @@ class TestCommon:
yield node yield node
yield from scan(sub_graph) yield from scan(sub_graph)
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
for dp in scan(traverse(dataset)): for dp in scan(traverse(dataset)):
if type(dp) is annotation_dp_type: if type(dp) is annotation_dp_type:
...@@ -120,7 +129,8 @@ class TestCommon: ...@@ -120,7 +129,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST: class TestQMNIST:
def test_extra_label(self, dataset_mock, config): def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
sample = next(iter(dataset)) sample = next(iter(dataset))
for key, type in ( for key, type in (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment