"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e46ec5f88fec23870538df782258c59271b010fd"
Unverified Commit 90a729a1 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

rely on patched datasets home rather than passing it around (#5998)

* rely on patched datasets home rather than passing it around

* add comment
parent 18b39e36
...@@ -62,8 +62,10 @@ class DatasetMock: ...@@ -62,8 +62,10 @@ class DatasetMock:
return mock_info return mock_info
def prepare(self, home, config): def prepare(self, config):
root = home / self.name # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in
# test/test_prototype_builtin_datasets.py
root = pathlib.Path(datasets.home()) / self.name
root.mkdir(exist_ok=True) root.mkdir(exist_ok=True)
mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
......
...@@ -25,9 +25,10 @@ def extract_datapipes(dp): ...@@ -25,9 +25,10 @@ def extract_datapipes(dp):
return get_all_graph_pipes(traverse(dp, only_datapipe=True)) return get_all_graph_pipes(traverse(dp, only_datapipe=True))
@pytest.fixture @pytest.fixture(autouse=True)
def test_home(mocker, tmp_path): def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path)) mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path))
yield tmp_path yield tmp_path
...@@ -54,8 +55,8 @@ class TestCommon: ...@@ -54,8 +55,8 @@ class TestCommon:
raise AssertionError("Info should be a dictionary with string keys.") raise AssertionError("Info should be a dictionary with string keys.")
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, test_home, dataset_mock, config): def test_smoke(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
...@@ -63,8 +64,8 @@ class TestCommon: ...@@ -63,8 +64,8 @@ class TestCommon:
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, test_home, dataset_mock, config): def test_sample(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
...@@ -82,16 +83,16 @@ class TestCommon: ...@@ -82,16 +83,16 @@ class TestCommon:
raise AssertionError("Sample dictionary is empty.") raise AssertionError("Sample dictionary is empty.")
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, test_home, dataset_mock, config): def test_num_samples(self, dataset_mock, config):
mock_info = dataset_mock.prepare(test_home, config) mock_info = dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
assert len(list(dataset)) == mock_info["num_samples"] assert len(list(dataset)) == mock_info["num_samples"]
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, test_home, dataset_mock, config): def test_no_vanilla_tensors(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
...@@ -103,8 +104,8 @@ class TestCommon: ...@@ -103,8 +104,8 @@ class TestCommon:
) )
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, test_home, dataset_mock, config): def test_transformable(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
...@@ -112,15 +113,15 @@ class TestCommon: ...@@ -112,15 +113,15 @@ class TestCommon:
@pytest.mark.parametrize("only_datapipe", [False, True]) @pytest.mark.parametrize("only_datapipe", [False, True])
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_traversable(self, test_home, dataset_mock, config, only_datapipe): def test_traversable(self, dataset_mock, config, only_datapipe):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
traverse(dataset, only_datapipe=only_datapipe) traverse(dataset, only_datapipe=only_datapipe)
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config): def test_serializable(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
pickle.dumps(dataset) pickle.dumps(dataset)
...@@ -133,8 +134,8 @@ class TestCommon: ...@@ -133,8 +134,8 @@ class TestCommon:
@pytest.mark.parametrize("num_workers", [0, 1]) @pytest.mark.parametrize("num_workers", [0, 1])
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_data_loader(self, test_home, dataset_mock, config, num_workers): def test_data_loader(self, dataset_mock, config, num_workers):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
dl = DataLoader( dl = DataLoader(
...@@ -151,17 +152,17 @@ class TestCommon: ...@@ -151,17 +152,17 @@ class TestCommon:
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now. # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type): def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)): if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_save_load(self, test_home, dataset_mock, config): def test_save_load(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
sample = next(iter(dataset)) sample = next(iter(dataset))
...@@ -171,8 +172,8 @@ class TestCommon: ...@@ -171,8 +172,8 @@ class TestCommon:
assert_samples_equal(torch.load(buffer), sample) assert_samples_equal(torch.load(buffer), sample)
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_infinite_buffer_size(self, test_home, dataset_mock, config): def test_infinite_buffer_size(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
for dp in extract_datapipes(dataset): for dp in extract_datapipes(dataset):
...@@ -182,8 +183,8 @@ class TestCommon: ...@@ -182,8 +183,8 @@ class TestCommon:
assert dp.buffer_size == INFINITE_BUFFER_SIZE assert dp.buffer_size == INFINITE_BUFFER_SIZE
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_has_length(self, test_home, dataset_mock, config): def test_has_length(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
assert len(dataset) > 0 assert len(dataset) > 0
...@@ -191,8 +192,8 @@ class TestCommon: ...@@ -191,8 +192,8 @@ class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST: class TestQMNIST:
def test_extra_label(self, test_home, dataset_mock, config): def test_extra_label(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
...@@ -211,13 +212,13 @@ class TestQMNIST: ...@@ -211,13 +212,13 @@ class TestQMNIST:
@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"]) @parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
class TestGTSRB: class TestGTSRB:
def test_label_matches_path(self, test_home, dataset_mock, config): def test_label_matches_path(self, dataset_mock, config):
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
# This test makes sure that they're both the same # This test makes sure that they're both the same
if config["split"] != "train": if config["split"] != "train":
return return
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
...@@ -228,8 +229,8 @@ class TestGTSRB: ...@@ -228,8 +229,8 @@ class TestGTSRB:
@parametrize_dataset_mocks(DATASET_MOCKS["usps"]) @parametrize_dataset_mocks(DATASET_MOCKS["usps"])
class TestUSPS: class TestUSPS:
def test_sample_content(self, test_home, dataset_mock, config): def test_sample_content(self, dataset_mock, config):
dataset_mock.prepare(test_home, config) dataset_mock.prepare(config)
dataset = datasets.load(dataset_mock.name, **config) dataset = datasets.load(dataset_mock.name, **config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment