Unverified Commit 0fd4736e authored by Kevin Tse's avatar Kevin Tse Committed by GitHub
Browse files

Replace `torch.utils.data.graph.traverse` with `traverse_dps` (#6657)

* Replace torch.utils.data.graph.traverse with traverse_dps

[ghstack-poisoned]

* Update on "Replace `torch.utils.data.graph.traverse` with `traverse_dps`"


CI is expected to fail for now. This should be merged only after https://github.com/pytorch/pytorch/pull/85667

 has been merged into nightly and internal.



[ghstack-poisoned]
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 29b0831c
......@@ -8,7 +8,7 @@ import torch
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair
from torch.utils.data import DataLoader
from torch.utils.data.graph import traverse
from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchvision._utils import sequence_to_str
......@@ -22,7 +22,7 @@ assert_samples_equal = functools.partial(
def extract_datapipes(dp):
return get_all_graph_pipes(traverse(dp))
return get_all_graph_pipes(traverse_dps(dp))
@pytest.fixture(autouse=True)
......@@ -42,12 +42,6 @@ def test_coverage():
)
# FIXME: This decorator only applies to `test_data_loader`, but we can't put it there because the class-wide fail on
# warnings would take higher priority.
# Although we are not using `traverse(..., only_datapipe=...)` in `test_data_loader` directly, the `DataLoader` does.
# This will emit the warning, which in turn will fail the test if we don't ignore it. There is a push to fix this in
# https://github.com/pytorch/pytorch/pull/85667.
@pytest.mark.filterwarnings("ignore:`only_datapipe` is deprecated:FutureWarning")
@pytest.mark.filterwarnings("error")
class TestCommon:
@pytest.mark.parametrize("name", datasets.list_datasets())
......@@ -111,7 +105,7 @@ class TestCommon:
def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
traverse(dataset)
traverse_dps(dataset)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, dataset_mock, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment