Unverified Commit 90a729a1 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

rely on patched datasets home rather than passing it around (#5998)

* rely on patched datasets home rather than passing it around

* add comment
parent 18b39e36
......@@ -62,8 +62,10 @@ class DatasetMock:
return mock_info
def prepare(self, home, config):
root = home / self.name
def prepare(self, config):
# `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in
# test/test_prototype_builtin_datasets.py
root = pathlib.Path(datasets.home()) / self.name
root.mkdir(exist_ok=True)
mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
......
......@@ -25,9 +25,10 @@ def extract_datapipes(dp):
return get_all_graph_pipes(traverse(dp, only_datapipe=True))
@pytest.fixture
@pytest.fixture(autouse=True)
def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path))
yield tmp_path
......@@ -54,8 +55,8 @@ class TestCommon:
raise AssertionError("Info should be a dictionary with string keys.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_smoke(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
......@@ -63,8 +64,8 @@ class TestCommon:
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_sample(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
......@@ -82,16 +83,16 @@ class TestCommon:
raise AssertionError("Sample dictionary is empty.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, test_home, dataset_mock, config):
mock_info = dataset_mock.prepare(test_home, config)
def test_num_samples(self, dataset_mock, config):
mock_info = dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
assert len(list(dataset)) == mock_info["num_samples"]
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_no_vanilla_tensors(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
......@@ -103,8 +104,8 @@ class TestCommon:
)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_transformable(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
......@@ -112,15 +113,15 @@ class TestCommon:
@pytest.mark.parametrize("only_datapipe", [False, True])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
dataset_mock.prepare(test_home, config)
def test_traversable(self, dataset_mock, config, only_datapipe):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
traverse(dataset, only_datapipe=only_datapipe)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_serializable(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
pickle.dumps(dataset)
......@@ -133,8 +134,8 @@ class TestCommon:
@pytest.mark.parametrize("num_workers", [0, 1])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_data_loader(self, test_home, dataset_mock, config, num_workers):
dataset_mock.prepare(test_home, config)
def test_data_loader(self, dataset_mock, config, num_workers):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
dl = DataLoader(
......@@ -151,17 +152,17 @@ class TestCommon:
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@parametrize_dataset_mocks(DATASET_MOCKS)
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
dataset_mock.prepare(test_home, config)
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_save_load(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_save_load(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
sample = next(iter(dataset))
......@@ -171,8 +172,8 @@ class TestCommon:
assert_samples_equal(torch.load(buffer), sample)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_infinite_buffer_size(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_infinite_buffer_size(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
for dp in extract_datapipes(dataset):
......@@ -182,8 +183,8 @@ class TestCommon:
assert dp.buffer_size == INFINITE_BUFFER_SIZE
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_has_length(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_has_length(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
assert len(dataset) > 0
......@@ -191,8 +192,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
def test_extra_label(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_extra_label(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
......@@ -211,13 +212,13 @@ class TestQMNIST:
@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
class TestGTSRB:
def test_label_matches_path(self, test_home, dataset_mock, config):
def test_label_matches_path(self, dataset_mock, config):
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
# This test makes sure that they're both the same
if config["split"] != "train":
return
dataset_mock.prepare(test_home, config)
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
......@@ -228,8 +229,8 @@ class TestGTSRB:
@parametrize_dataset_mocks(DATASET_MOCKS["usps"])
class TestUSPS:
def test_sample_content(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
def test_sample_content(self, dataset_mock, config):
dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment