Unverified Commit 0b02d420 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add sharding filters in prototype datasets to help the data loader (#5093)

* add sharding filters in prototype datasets to help the data loader

* hint rather than apply

* fix sharding for folders

* fix import

* fix coco

* appease mypy
parent f5dca445
......@@ -3,8 +3,9 @@ import io
import builtin_dataset_mocks
import pytest
import torch
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision.prototype import datasets, transforms
from torchvision.prototype.datasets._api import DEFAULT_DECODER
from torchvision.prototype.utils._internal import sequence_to_str
......@@ -105,6 +106,20 @@ class TestCommon:
def test_traversable(self, dataset, mock_info):
traverse(dataset)
@dataset_parametrization()
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
def test_has_annotations(self, dataset, mock_info, annotation_dp_type):
    """Check that the dataset's datapipe graph contains a datapipe of the given type."""

    def iter_nodes(graph):
        # Depth-first walk over the nested graph returned by ``traverse``:
        # each key is a datapipe node, each value is its sub-graph.
        for node, sub_graph in graph.items():
            yield node
            yield from iter_nodes(sub_graph)

    found = any(type(dp) is annotation_dp_type for dp in iter_nodes(traverse(dataset)))
    if not found:
        raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
class TestQMNIST:
@pytest.mark.parametrize(
......
......@@ -20,7 +20,7 @@ from torchvision.prototype.datasets.utils import (
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding
from torchvision.prototype.features import Label, BoundingBox, Feature
......@@ -120,6 +120,7 @@ class Caltech101(Dataset):
images_dp, anns_dp = resource_dps
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = hint_sharding(images_dp)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
anns_dp = Filter(anns_dp, self._is_ann)
......@@ -183,6 +184,7 @@ class Caltech256(Dataset):
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
......
......@@ -19,7 +19,7 @@ from torchvision.prototype.datasets.utils import (
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor, hint_sharding
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
......@@ -151,6 +151,7 @@ class CelebA(Dataset):
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = hint_sharding(splits_dp)
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
anns_dp = Zipper(
......
......@@ -26,6 +26,7 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
image_buffer_from_array,
path_comparator,
hint_sharding,
)
from torchvision.prototype.features import Label, Image
......@@ -87,6 +88,7 @@ class _CifarBase(Dataset):
dp = Filter(dp, functools.partial(self._is_data_file, config=config))
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
......
......@@ -30,6 +30,7 @@ from torchvision.prototype.datasets.utils._internal import (
BUILTIN_DIR,
getitem,
path_accessor,
hint_sharding,
)
from torchvision.prototype.features import BoundingBox, Label, Feature
from torchvision.prototype.utils._internal import FrozenMapping
......@@ -180,7 +181,8 @@ class Coco(Dataset):
images_dp, meta_dp = resource_dps
if config.annotations is None:
dp = Shuffler(images_dp)
dp = hint_sharding(images_dp)
dp = Shuffler(dp)
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
meta_dp = Filter(
......@@ -190,7 +192,7 @@ class Coco(Dataset):
)
meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1))
meta_dp = MappingIterator(meta_dp)
meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
images_meta_dp, anns_meta_dp = Demultiplexer(
meta_dp,
2,
......@@ -201,11 +203,12 @@ class Coco(Dataset):
images_meta_dp = Mapper(images_meta_dp, getitem(1))
images_meta_dp = UnBatcher(images_meta_dp)
images_meta_dp = Shuffler(images_meta_dp)
anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
anns_meta_dp = UnBatcher(anns_meta_dp)
anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
anns_meta_dp = hint_sharding(anns_meta_dp)
anns_meta_dp = Shuffler(anns_meta_dp)
anns_dp = IterKeyZipper(
anns_meta_dp,
......
......@@ -20,6 +20,7 @@ from torchvision.prototype.datasets.utils._internal import (
Enumerator,
getitem,
read_mat,
hint_sharding,
)
from torchvision.prototype.features import Label
from torchvision.prototype.utils._internal import FrozenMapping
......@@ -139,6 +140,7 @@ class ImageNet(Dataset):
if config.split == "train":
# the train archive is a tar of tars
dp = TarArchiveReader(images_dp)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_train_data)
elif config.split == "val":
......@@ -146,6 +148,7 @@ class ImageNet(Dataset):
devkit_dp = LineReader(devkit_dp, return_path=False)
devkit_dp = Mapper(devkit_dp, int)
devkit_dp = Enumerator(devkit_dp, 1)
devkit_dp = hint_sharding(devkit_dp)
devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = IterKeyZipper(
......@@ -157,7 +160,8 @@ class ImageNet(Dataset):
)
dp = Mapper(dp, self._collate_val_data)
else: # config.split == "test"
dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_sharding(images_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_test_data)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
......
......@@ -28,10 +28,10 @@ from torchvision.prototype.datasets.utils._internal import (
Decompressor,
INFINITE_BUFFER_SIZE,
fromfile,
hint_sharding,
)
from torchvision.prototype.features import Image, Label
__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
prod = functools.partial(functools.reduce, operator.mul)
......@@ -134,6 +134,7 @@ class _MNISTBase(Dataset):
labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)
dp = Zipper(images_dp, labels_dp)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
......
......@@ -28,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
getitem,
path_accessor,
path_comparator,
hint_sharding,
)
......@@ -139,6 +140,7 @@ class SBD(Dataset):
if config.split == "train_noval":
split_dp = extra_split_dp
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
split_dp = Shuffler(split_dp)
dp = split_dp
......
......@@ -17,7 +17,7 @@ from torchvision.prototype.datasets.utils import (
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array, hint_sharding
class SEMEION(Dataset):
......@@ -64,6 +64,7 @@ class SEMEION(Dataset):
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVParser(dp, delimiter=" ")
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return dp
......@@ -28,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
getitem,
INFINITE_BUFFER_SIZE,
path_comparator,
hint_sharding,
)
HERE = pathlib.Path(__file__).parent
......@@ -129,6 +130,7 @@ class VOC(Dataset):
split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = split_dp
......
......@@ -9,7 +9,7 @@ import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding
__all__ = ["from_data_folder", "from_image_folder"]
......@@ -51,6 +51,7 @@ def from_data_folder(
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp = FileLister(str(root), recursive=recursive, masks=masks)
dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = FileLoader(dp)
return (
......
......@@ -186,7 +186,7 @@ class Dataset(abc.ABC):
if use_sharded_dataset() and self.supports_sharded():
root = os.path.join(root, *config.values())
dataset_size = self.info.extra["sizes"][config]
return _make_sharded_datapipe(root, dataset_size)
return _make_sharded_datapipe(root, dataset_size) # type: ignore[no-any-return]
self.info.check_dependencies()
resource_dps = [
......
......@@ -30,8 +30,7 @@ import PIL.Image
import torch
import torch.distributed as dist
import torch.utils.data
from torch.utils.data import IterDataPipe
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter
from torchdata.datapipes.utils import StreamWrapper
......@@ -49,6 +48,7 @@ __all__ = [
"Decompressor",
"fromfile",
"read_flo",
"hint_sharding",
]
K = TypeVar("K")
......@@ -96,7 +96,7 @@ class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
for mapping in self.datapipe:
yield from iter(mapping.values() if self.drop_key else mapping.items()) # type: ignore[call-overload]
yield from iter(mapping.values() if self.drop_key else mapping.items())
class Enumerator(IterDataPipe[Tuple[int, D]]):
......@@ -250,7 +250,7 @@ class TakerDataPipe(IterDataPipe):
return num_take
def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[str, Any]]:
dp = IoPathFileLister(root=root)
dp = SharderDataPipe(dp)
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE)
......@@ -331,3 +331,7 @@ def read_flo(file: BinaryIO) -> torch.Tensor:
width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
return flow.reshape((height, width, 2)).permute((2, 0, 1))
def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
    """Insert a ``ShardingFilter`` into the pipeline.

    The filter performs no sharding on its own; it marks the point at which
    the data loader may later apply sharding across workers ("hint rather
    than apply" — see the commit message).
    """
    sharded_dp: IterDataPipe[D] = ShardingFilter(datapipe)
    return sharded_dp
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment