Unverified Commit 93104c16 authored by Philip Meier, committed by GitHub

add test to enforce infinite buffer size for all applicable datapipes (#5707)

* add test to enforce infinite buffer size for all applicable datapipes

* use utility function to extract datapipes

* check for buffer_size attr rather than type

* simplify
parent aa211974
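
For context, below is a minimal, self-contained sketch of the invariant the new test enforces. It is an illustration under stated assumptions, not the test itself: it assumes torchdata's IterableWrapper and demux APIs, and that INFINITE_BUFFER_SIZE is the -1 ("unlimited buffer") sentinel defined in torchvision.prototype.datasets.utils._internal.

from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import IterableWrapper

# Assumed value of the sentinel from torchvision.prototype.datasets.utils._internal;
# buffered datapipes such as demux interpret -1 as "unlimited buffer".
INFINITE_BUFFER_SIZE = -1

# Build a tiny pipeline containing a buffered datapipe: demux keeps an internal
# buffer while routing elements to its children. (torchdata may warn about the
# unlimited buffer here.)
dp = IterableWrapper(range(10))
even, odd = dp.demux(num_instances=2, classifier_fn=lambda x: x % 2, buffer_size=INFINITE_BUFFER_SIZE)

# Flatten the traversal graph into a list of all datapipes, mirroring the
# extract_datapipes helper introduced by this commit.
pipes = get_all_graph_pipes(traverse(even, only_datapipe=True))

# The invariant test_infinite_buffer_size enforces: every datapipe that
# exposes a buffer_size attribute must use the infinite-buffer sentinel
# rather than an arbitrary finite default.
for pipe in pipes:
    if hasattr(pipe, "buffer_size"):
        assert pipe.buffer_size == INFINITE_BUFFER_SIZE

Checking for a buffer_size attribute instead of specific datapipe types (the third commit message above) keeps the test future-proof: any newly introduced buffered datapipe in a dataset graph is covered automatically.
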
@@ -7,11 +7,12 @@ import pytest
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
 from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
-from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
-from torchdata.datapipes.iter import IterDataPipe, Shuffler
+from torch.utils.data.graph_settings import get_all_graph_pipes
+from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
 from torchvision.prototype.features import Image, Label
 
 assert_samples_equal = functools.partial(
@@ -19,6 +20,10 @@ assert_samples_equal = functools.partial(
 )
 
 
+def extract_datapipes(dp):
+    return get_all_graph_pipes(traverse(dp, only_datapipe=True))
+
+
 @pytest.fixture
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -125,16 +130,12 @@ class TestCommon:
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
-        def scan(graph):
-            for node, sub_graph in graph.items():
-                yield node
-                yield from scan(sub_graph)
-
         dataset_mock.prepare(test_home, config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
+        if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
             raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
@@ -148,6 +149,17 @@ class TestCommon:
         buffer.seek(0)
 
         assert_samples_equal(torch.load(buffer), sample)
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_infinite_buffer_size(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        for dp in extract_datapipes(dataset):
+            if hasattr(dp, "buffer_size"):
+                # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is
+                #  resolved
+                assert dp.buffer_size == INFINITE_BUFFER_SIZE
+
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST: