Unverified Commit 4d08a673 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Use public API for loading in prototype datasets tests (#5212)

* refactor prototype dataset tests to use public API for loading

* add explanation

* use loop alternative
parent 6512146e
......@@ -10,6 +10,7 @@ import pathlib
import pickle
import random
import tempfile
import unittest.mock
import xml.etree.ElementTree as ET
from collections import defaultdict, Counter, UserDict
......@@ -21,7 +22,8 @@ from datasets_utils import make_zip, make_tar, create_image_folder, create_image
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ())
......@@ -49,7 +51,7 @@ class DatasetMock:
def __init__(self, name, mock_data_fn, *, configs=None):
self.dataset = find(name)
self.root = TEST_HOME / self.dataset.name
self.mock_data_fn = self._parse_mock_data(mock_data_fn)
self.mock_data_fn = mock_data_fn
self.configs = configs or self.info._configs
self._cache = {}
......@@ -61,10 +63,7 @@ class DatasetMock:
def name(self):
return self.info.name
def _parse_mock_data(self, mock_data_fn):
def wrapper(info, root, config):
mock_infos = mock_data_fn(info, root, config)
def _parse_mock_data(self, config, mock_infos):
if mock_infos is None:
raise pytest.UsageError(
f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
......@@ -103,35 +102,32 @@ class DatasetMock:
return mock_infos
return wrapper
def _load_mock(self, config):
def _prepare_resources(self, config):
with contextlib.suppress(KeyError):
return self._cache[config]
self.root.mkdir(exist_ok=True)
for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
mock_resources = [
ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
for resource in self.dataset.resources(config_)
]
self._cache[config_] = (mock_resources, mock_info)
mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
available_file_names = {path.name for path in self.root.glob("*")}
for config_, mock_info in mock_infos.items():
required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
missing_file_names = required_file_names - available_file_names
if missing_file_names:
raise pytest.UsageError(
f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
f"for {config_}, but they were not created by the mock data function."
)
self._cache[config_] = mock_info
return self._cache[config]
def load(self, config, *, decoder=DEFAULT_DECODER):
try:
self.info.check_dependencies()
except ModuleNotFoundError as error:
pytest.skip(str(error))
mock_resources, mock_info = self._load_mock(config)
datapipe = self.dataset._make_datapipe(
[resource.load(self.root) for resource in mock_resources],
config=config,
decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
)
return datapipe, mock_info
@contextlib.contextmanager
def prepare(self, config):
    """Materialize the mock data for *config* and point the datasets API at it.

    First creates (or fetches from the cache) the mock resource files for
    ``config`` via ``self._prepare_resources``, then patches
    ``torchvision.prototype.datasets._api.home`` so that, inside the ``with``
    block, ``datasets.load(...)`` resolves the dataset root to ``TEST_HOME``
    and thus picks up the mock files instead of real downloads.

    Yields:
        The ``mock_info`` returned by ``_prepare_resources`` for ``config``
        (e.g. used by tests to check the expected number of samples).
    """
    mock_info = self._prepare_resources(config)
    # Redirect the prototype datasets' home directory to the test location for
    # the duration of the block; the patch is undone automatically on exit.
    with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
        yield mock_info
def config_id(name, config):
......@@ -1000,7 +996,7 @@ def dtd(info, root, _):
def fer2013(info, root, config):
num_samples = 5 if config.split == "train" else 3
path = root / f"{config.split}.txt"
path = root / f"{config.split}.csv"
with open(path, "w", newline="") as file:
field_names = ["emotion"] if config.split == "train" else []
field_names.append("pixels")
......@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
file,
)
make_zip(root, f"{data_folder.name}.zip")
make_zip(root, f"{data_folder.name}.zip", data_folder)
return {config_: num_samples_map[config_.split] for config_ in info._configs}
......@@ -1121,8 +1117,8 @@ class OxfordIIITPetMockData:
for path in segmentation_files:
path.with_name(f".{path.name}").touch()
make_tar(root, "images.tar")
make_tar(root, anns_folder.with_suffix(".tar").name)
make_tar(root, "images.tar.gz", compression="gz")
make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz")
return num_samples_map
......@@ -1211,7 +1207,7 @@ class CUB2002011MockData(_CUB200MockData):
size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
)
make_tar(root, segmentations_folder.with_suffix(".tgz").name)
make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz")
@classmethod
def generate(cls, root):
......
......@@ -868,9 +868,13 @@ def _split_files_or_dirs(root, *files_or_dirs):
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
archive = pathlib.Path(root) / name
if not files_or_dirs:
dir = archive.with_suffix("")
if dir.exists() and dir.is_dir():
files_or_dirs = (dir,)
# We need to invoke `Path.with_suffix("")`, since call only applies to the last suffix if multiple suffixes are
# present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
file_or_dir = archive
for _ in range(len(archive.suffixes)):
file_or_dir = file_or_dir.with_suffix("")
if file_or_dir.exists():
files_or_dirs = (file_or_dir,)
else:
raise ValueError("No file or dir provided.")
......
......@@ -23,13 +23,16 @@ def test_coverage():
class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
try:
sample = next(iter(dataset))
......@@ -44,7 +47,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config)
with dataset_mock.prepare(config) as mock_info:
dataset = datasets.load(dataset_mock.name, **config)
num_samples = 0
for _ in dataset:
......@@ -54,7 +58,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features:
......@@ -65,7 +70,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
......@@ -76,7 +82,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
next(iter(dataset.map(transforms.Identity())))
......@@ -89,7 +96,8 @@ class TestCommon:
},
)
def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
traverse(dataset)
......@@ -108,7 +116,8 @@ class TestCommon:
yield node
yield from scan(sub_graph)
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
for dp in scan(traverse(dataset)):
if type(dp) is annotation_dp_type:
......@@ -120,7 +129,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)
sample = next(iter(dataset))
for key, type in (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment