Unverified Commit 0b02d420 authored by Philip Meier, committed by GitHub

add sharding filters in prototype datasets to help the data loader (#5093)

* add sharding filters in prototype datasets to help the data loader

* hint rather than apply

* fix sharding for folders

* fix import

* fix coco

* appease mypy
parent f5dca445
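For context, every dataset touched below applies the same recipe: filter the raw file stream, insert a sharding hint, shuffle, then decode. A minimal sketch of that recipe, assuming torchdata's re-exported datapipes; the predicate and collate function are hypothetical placeholders, and hint_sharding is the helper defined at the bottom of this diff:

    from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
    from torchdata.datapipes.iter import Filter, Mapper, Shuffler

    INFINITE_BUFFER_SIZE = -1  # sentinel: let the Shuffler buffer the entire stream

    def hint_sharding(datapipe):
        # Mark the point where the data loader may shard the pipeline;
        # the filter is a no-op until apply_sharding() is called on it.
        return ShardingFilter(datapipe)

    def build_pipeline(dp, is_sample, collate_and_decode):  # hypothetical hooks
        dp = Filter(dp, is_sample)  # keep only actual sample files
        dp = hint_sharding(dp)  # shard before the expensive steps
        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)  # shuffle within the shard
        return Mapper(dp, collate_and_decode)  # decode only this shard's samples

Sharding is hinted before shuffling, so each loader instance shuffles and decodes only its own shard instead of the full dataset.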
@@ -3,8 +3,9 @@ import io
 import builtin_dataset_mocks
 import pytest
 import torch
+from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
-from torchdata.datapipes.iter import IterDataPipe
+from torchdata.datapipes.iter import IterDataPipe, Shuffler
 from torchvision.prototype import datasets, transforms
 from torchvision.prototype.datasets._api import DEFAULT_DECODER
 from torchvision.prototype.utils._internal import sequence_to_str
@@ -105,6 +106,20 @@ class TestCommon:
     def test_traversable(self, dataset, mock_info):
         traverse(dataset)
 
+    @dataset_parametrization()
+    @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
+    def test_has_annotations(self, dataset, mock_info, annotation_dp_type):
+        def scan(graph):
+            for node, sub_graph in graph.items():
+                yield node
+                yield from scan(sub_graph)
+
+        for dp in scan(traverse(dataset)):
+            if type(dp) is annotation_dp_type:
+                break
+        else:
+            raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
+
 
 class TestQMNIST:
     @pytest.mark.parametrize(
...
@@ -20,7 +20,7 @@ from torchvision.prototype.datasets.utils import (
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding
 from torchvision.prototype.features import Label, BoundingBox, Feature
@@ -120,6 +120,7 @@ class Caltech101(Dataset):
         images_dp, anns_dp = resource_dps
 
         images_dp = Filter(images_dp, self._is_not_background_image)
+        images_dp = hint_sharding(images_dp)
         images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
 
         anns_dp = Filter(anns_dp, self._is_ann)
@@ -183,6 +184,7 @@ class Caltech256(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
+        dp = hint_sharding(dp)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
...
@@ -19,7 +19,7 @@ from torchvision.prototype.datasets.utils import (
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor, hint_sharding
 
 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
@@ -151,6 +151,7 @@ class CelebA(Dataset):
         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
         splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
+        splits_dp = hint_sharding(splits_dp)
         splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
 
         anns_dp = Zipper(
...
@@ -26,6 +26,7 @@ from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     image_buffer_from_array,
     path_comparator,
+    hint_sharding,
 )
 from torchvision.prototype.features import Label, Image
@@ -87,6 +88,7 @@ class _CifarBase(Dataset):
         dp = Filter(dp, functools.partial(self._is_data_file, config=config))
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
+        dp = hint_sharding(dp)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
...
@@ -30,6 +30,7 @@ from torchvision.prototype.datasets.utils._internal import (
     BUILTIN_DIR,
     getitem,
     path_accessor,
+    hint_sharding,
 )
 from torchvision.prototype.features import BoundingBox, Label, Feature
 from torchvision.prototype.utils._internal import FrozenMapping
@@ -180,7 +181,8 @@ class Coco(Dataset):
         images_dp, meta_dp = resource_dps
 
         if config.annotations is None:
-            dp = Shuffler(images_dp)
+            dp = hint_sharding(images_dp)
+            dp = Shuffler(dp)
             return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
 
         meta_dp = Filter(
@@ -190,7 +192,7 @@
         )
         meta_dp = JsonParser(meta_dp)
         meta_dp = Mapper(meta_dp, getitem(1))
-        meta_dp = MappingIterator(meta_dp)
+        meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
         images_meta_dp, anns_meta_dp = Demultiplexer(
             meta_dp,
             2,
@@ -201,11 +203,12 @@
         images_meta_dp = Mapper(images_meta_dp, getitem(1))
         images_meta_dp = UnBatcher(images_meta_dp)
-        images_meta_dp = Shuffler(images_meta_dp)
 
         anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
         anns_meta_dp = UnBatcher(anns_meta_dp)
         anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
+        anns_meta_dp = hint_sharding(anns_meta_dp)
+        anns_meta_dp = Shuffler(anns_meta_dp)
 
         anns_dp = IterKeyZipper(
             anns_meta_dp,
...
@@ -20,6 +20,7 @@ from torchvision.prototype.datasets.utils._internal import (
     Enumerator,
     getitem,
     read_mat,
+    hint_sharding,
 )
 from torchvision.prototype.features import Label
 from torchvision.prototype.utils._internal import FrozenMapping
@@ -139,6 +140,7 @@ class ImageNet(Dataset):
         if config.split == "train":
             # the train archive is a tar of tars
             dp = TarArchiveReader(images_dp)
+            dp = hint_sharding(dp)
             dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
             dp = Mapper(dp, self._collate_train_data)
         elif config.split == "val":
@@ -146,6 +148,7 @@
             devkit_dp = LineReader(devkit_dp, return_path=False)
             devkit_dp = Mapper(devkit_dp, int)
             devkit_dp = Enumerator(devkit_dp, 1)
+            devkit_dp = hint_sharding(devkit_dp)
             devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE)
 
             dp = IterKeyZipper(
@@ -157,7 +160,8 @@
             )
             dp = Mapper(dp, self._collate_val_data)
         else:  # config.split == "test"
-            dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
+            dp = hint_sharding(images_dp)
+            dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
             dp = Mapper(dp, self._collate_test_data)
 
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
...
@@ -28,10 +28,10 @@ from torchvision.prototype.datasets.utils._internal import (
     Decompressor,
     INFINITE_BUFFER_SIZE,
     fromfile,
+    hint_sharding,
 )
 from torchvision.prototype.features import Image, Label
 
 __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
 
 prod = functools.partial(functools.reduce, operator.mul)
@@ -134,6 +134,7 @@ class _MNISTBase(Dataset):
         labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)
 
         dp = Zipper(images_dp, labels_dp)
+        dp = hint_sharding(dp)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
...
@@ -28,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
     getitem,
     path_accessor,
     path_comparator,
+    hint_sharding,
 )
@@ -139,6 +140,7 @@ class SBD(Dataset):
         if config.split == "train_noval":
             split_dp = extra_split_dp
 
         split_dp = LineReader(split_dp, decode=True)
+        split_dp = hint_sharding(split_dp)
         split_dp = Shuffler(split_dp)
 
         dp = split_dp
...
@@ -17,7 +17,7 @@ from torchvision.prototype.datasets.utils import (
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array, hint_sharding
 
 
 class SEMEION(Dataset):
@@ -64,6 +64,7 @@ class SEMEION(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
+        dp = hint_sharding(dp)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
         return dp
...
@@ -28,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
     getitem,
     INFINITE_BUFFER_SIZE,
     path_comparator,
+    hint_sharding,
 )
 
 HERE = pathlib.Path(__file__).parent
@@ -129,6 +130,7 @@ class VOC(Dataset):
         split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
         split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
+        split_dp = hint_sharding(split_dp)
         split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE)
 
         dp = split_dp
...
@@ -9,7 +9,7 @@ import torch
 from torch.utils.data import IterDataPipe
 from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
 from torchvision.prototype.datasets.decoder import pil
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding
 
 __all__ = ["from_data_folder", "from_image_folder"]
@@ -51,6 +51,7 @@ def from_data_folder(
     masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
     dp = FileLister(str(root), recursive=recursive, masks=masks)
     dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
+    dp = hint_sharding(dp)
     dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
     dp = FileLoader(dp)
     return (
...
@@ -186,7 +186,7 @@ class Dataset(abc.ABC):
         if use_sharded_dataset() and self.supports_sharded():
             root = os.path.join(root, *config.values())
             dataset_size = self.info.extra["sizes"][config]
-            return _make_sharded_datapipe(root, dataset_size)
+            return _make_sharded_datapipe(root, dataset_size)  # type: ignore[no-any-return]
 
         self.info.check_dependencies()
 
         resource_dps = [
...
@@ -30,8 +30,7 @@ import PIL.Image
 import torch
 import torch.distributed as dist
 import torch.utils.data
-from torch.utils.data import IterDataPipe
-from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader
+from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter
 from torchdata.datapipes.utils import StreamWrapper
 
@@ -49,6 +48,7 @@ __all__ = [
     "Decompressor",
     "fromfile",
     "read_flo",
+    "hint_sharding",
 ]
 
 K = TypeVar("K")
@@ -96,7 +96,7 @@ class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
     def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
         for mapping in self.datapipe:
-            yield from iter(mapping.values() if self.drop_key else mapping.items())  # type: ignore[call-overload]
+            yield from iter(mapping.values() if self.drop_key else mapping.items())
 
 
 class Enumerator(IterDataPipe[Tuple[int, D]]):
@@ -250,7 +250,7 @@ class TakerDataPipe(IterDataPipe):
         return num_take
 
 
-def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
+def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[str, Any]]:
     dp = IoPathFileLister(root=root)
     dp = SharderDataPipe(dp)
     dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE)
@@ -331,3 +331,7 @@ def read_flo(file: BinaryIO) -> torch.Tensor:
     width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
     flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
     return flow.reshape((height, width, 2)).permute((2, 0, 1))
+
+
+def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
+    return ShardingFilter(datapipe)
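Downstream, the data loader (or whatever drives the datapipe graph) is expected to locate these ShardingFilter nodes and configure them, rather than each dataset sharding eagerly. A minimal sketch of that step, reusing the scan() traversal from the test above and assuming the filter's apply_sharding(num_of_instances, instance_id) method:

    import torch.utils.data.graph
    from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter

    def apply_sharding_hints(datapipe, num_instances, instance_id):
        # Walk the datapipe graph and configure every sharding hint so that
        # this loader instance only sees its own 1/num_instances slice.
        def scan(graph):
            for node, sub_graph in graph.items():
                yield node
                yield from scan(sub_graph)

        for dp in scan(torch.utils.data.graph.traverse(datapipe)):
            if isinstance(dp, ShardingFilter):
                dp.apply_sharding(num_instances, instance_id)

    # e.g. the fourth of eight workers:
    # apply_sharding_hints(dp, num_instances=8, instance_id=3)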