"docs/git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "040526310ed1b502647510648464d2673de8ad63"
Unverified Commit 732af96b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use torchdata as single source of truth for everthing datapipe (#6068)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent c05ad81b
...@@ -10,9 +10,14 @@ import torch ...@@ -10,9 +10,14 @@ import torch
import torchvision.prototype.transforms.utils import torchvision.prototype.transforms.utils
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.graph import traverse_dps
# TODO: replace with torchdata equivalent as soon as it is available
from torch.utils.data.graph_settings import get_all_graph_pipes from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.graph.utils import traverse_dps
from torchdata.datapipes.iter import ShardingFilter, Shuffler from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper from torchdata.datapipes.utils import StreamWrapper
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
......
...@@ -3,7 +3,7 @@ import importlib ...@@ -3,7 +3,7 @@ import importlib
import pathlib import pathlib
from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union
from torch.utils.data import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from torchvision.datasets.utils import verify_str_arg from torchvision.datasets.utils import verify_str_arg
from ._resource import OnlineResource from ._resource import OnlineResource
......
...@@ -104,7 +104,7 @@ class PicklerDataPipe(IterDataPipe): ...@@ -104,7 +104,7 @@ class PicklerDataPipe(IterDataPipe):
yield d yield d
class SharderDataPipe(torch.utils.data.datapipes.iter.grouping.ShardingFilterIterDataPipe): class SharderDataPipe(ShardingFilter):
def __init__(self, source_datapipe: IterDataPipe) -> None: def __init__(self, source_datapipe: IterDataPipe) -> None:
super().__init__(source_datapipe) super().__init__(source_datapipe)
self.rank = 0 self.rank = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment