Unverified commit 0cba9b78, authored by Philip Meier, committed by GitHub

enforce shuffling before sharding (#5680)

* enforce shuffling before sharding

* revert test changes and add comment
parent 96aecd2d
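Context for the change: `ShardingFilter` deterministically keeps every n-th sample for a given worker. If it runs before the `Shuffler`, each worker can only ever permute the same fixed subset, so samples never move between workers across epochs. Shuffling first produces a fresh global permutation that is then split. A minimal toy sketch of the difference (assumed setup, not part of this commit; it relies on torchdata's `IterableWrapper` and the `apply_sharding` method of `ShardingFilter`):

from torchdata.datapipes.iter import IterableWrapper

# shard -> shuffle: worker 0 of 2 is locked to the even indices {0, 2, 4, 6}
# and merely reorders them each epoch
bad = IterableWrapper(range(8)).sharding_filter()
bad.apply_sharding(2, 0)  # (num_of_instances, instance_id)
bad = bad.shuffle()

# shuffle -> shard: worker 0 keeps every 2nd element of a fresh global
# permutation, so the subset it sees changes from epoch to epoch
good = IterableWrapper(range(8)).shuffle().sharding_filter()
good.apply_sharding(2, 0)

Note that shuffling before sharding only keeps the shards disjoint if all workers shuffle with the same seed; coordinating that seed is up to the DataLoader / reading service, not the datasets touched here.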
@@ -119,6 +119,9 @@ class TestCommon:
         pickle.dumps(dataset)

+    # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
+    # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
+    # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
...
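The TODO above asks for a check that the `Shuffler` comes before the `ShardingFilter` in each dataset's datapipe graph. A hypothetical sketch of such a test is shown below. It assumes a linear pipeline and the common torchdata convention that a datapipe stores its predecessor under `datapipe` or `source_datapipe`; exactly this kind of fragility is why the authors preferred to wait for an official utility from torchdata.

from torchdata.datapipes.iter import IterableWrapper, ShardingFilter, Shuffler

def linearize(dp):
    # Walk a linear datapipe chain from sink to source. torchdata datapipes
    # commonly store their predecessor as `datapipe` or `source_datapipe`.
    chain = []
    while dp is not None:
        chain.append(dp)
        dp = getattr(dp, "datapipe", None) or getattr(dp, "source_datapipe", None)
    return chain[::-1]  # source first

def assert_shuffling_before_sharding(dp):
    chain = linearize(dp)
    shufflers = [i for i, d in enumerate(chain) if isinstance(d, Shuffler)]
    sharders = [i for i, d in enumerate(chain) if isinstance(d, ShardingFilter)]
    assert shufflers and sharders, "pipeline needs both a Shuffler and a ShardingFilter"
    assert max(shufflers) < min(sharders), "the Shuffler has to come before the ShardingFilter"

# passes for shuffle -> shard; would fail for .sharding_filter().shuffle()
assert_shuffling_before_sharding(IterableWrapper(range(10)).shuffle().sharding_filter())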
@@ -112,10 +112,10 @@ that is only needed at runtime.
 Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
 trying to zip already loaded images.

-There are two special datapipes that are not used through their class, but through the functions `hint_sharding` and
-`hint_shuffling`. As the name implies they only hint part in the datapipe graph where sharding and shuffling should take
+There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
+`hint_sharding`. As the name implies they only hint part in the datapipe graph where shuffling and sharding should take
 place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` and are
-required in each dataset.
+required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.

 Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
 names (yet!).
...
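For reference, the two hint functions are thin wrappers around the datapipes they name. Roughly, and reconstructed here as an assumption rather than quoted from `torchvision.prototype.datasets.utils._internal`: `hint_shuffling` inserts a `Shuffler` that is disabled by default, and `hint_sharding` inserts a plain `ShardingFilter`; both stay no-ops until the surrounding machinery enables shuffling or applies sharding.

from torchdata.datapipes.iter import IterDataPipe, ShardingFilter, Shuffler

INFINITE_BUFFER_SIZE = -1  # sentinel buffer size used by the prototype datasets

def hint_shuffling(datapipe: IterDataPipe) -> Shuffler:
    # a no-op until shuffling is explicitly enabled on the graph
    # (assumes torchdata's Shuffler.set_shuffle API)
    return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)

def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
    # a no-op until apply_sharding() is called with more than one instance
    return ShardingFilter(datapipe)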
@@ -107,8 +107,8 @@ class Caltech101(Dataset):
         images_dp, anns_dp = resource_dps

         images_dp = Filter(images_dp, self._is_not_background_image)
-        images_dp = hint_sharding(images_dp)
         images_dp = hint_shuffling(images_dp)
+        images_dp = hint_sharding(images_dp)

         anns_dp = Filter(anns_dp, self._is_ann)
...
@@ -167,8 +167,8 @@ class Caltech256(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
...
@@ -155,8 +155,8 @@ class CelebA(Dataset):
         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))

         splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
-        splits_dp = hint_sharding(splits_dp)
         splits_dp = hint_shuffling(splits_dp)
+        splits_dp = hint_sharding(splits_dp)

         anns_dp = Zipper(
             *[
...
@@ -85,8 +85,8 @@ class _CifarBase(Dataset):
         dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
...
@@ -77,8 +77,8 @@ class CLEVR(Dataset):
         )

         images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
-        images_dp = hint_sharding(images_dp)
         images_dp = hint_shuffling(images_dp)
+        images_dp = hint_sharding(images_dp)

         if config.split != "test":
             scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
...
@@ -176,8 +176,8 @@ class Coco(Dataset):
         images_dp, meta_dp = resource_dps

         if config.annotations is None:
-            dp = hint_sharding(images_dp)
-            dp = hint_shuffling(dp)
+            dp = hint_shuffling(images_dp)
+            dp = hint_sharding(dp)
             return Mapper(dp, self._prepare_image)

         meta_dp = Filter(
...
@@ -206,8 +206,8 @@ class Coco(Dataset):
         anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
         anns_meta_dp = UnBatcher(anns_meta_dp)
         anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
-        anns_meta_dp = hint_sharding(anns_meta_dp)
         anns_meta_dp = hint_shuffling(anns_meta_dp)
+        anns_meta_dp = hint_sharding(anns_meta_dp)

         anns_dp = IterKeyZipper(
             anns_meta_dp,
...
@@ -46,8 +46,8 @@ class Country211(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split]))
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
...
@@ -199,8 +199,8 @@ class CUB200(Dataset):
             prepare_ann_fn = self._2010_prepare_ann

-        split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
+        split_dp = hint_sharding(split_dp)

         dp = IterKeyZipper(
             split_dp,
...
@@ -2,16 +2,7 @@ import enum
 import pathlib
 from typing import Any, Dict, List, Optional, Tuple, BinaryIO

-from torchdata.datapipes.iter import (
-    IterDataPipe,
-    Mapper,
-    Shuffler,
-    Filter,
-    IterKeyZipper,
-    Demultiplexer,
-    LineReader,
-    CSVParser,
-)
+from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
@@ -24,6 +15,7 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     path_comparator,
     getitem,
+    hint_shuffling,
 )

 from torchvision.prototype.features import Label, EncodedImage
@@ -98,7 +90,7 @@ class DTD(Dataset):
         splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
         splits_dp = LineReader(splits_dp, decode=True, return_path=False)
-        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
+        splits_dp = hint_shuffling(splits_dp)
         splits_dp = hint_sharding(splits_dp)

         joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")
...
@@ -46,6 +46,6 @@ class EuroSAT(Dataset):
         self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -54,6 +54,6 @@ class FER2013(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVDictParser(dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -94,7 +94,7 @@ class GTSRB(Dataset):
         ann_dp = CSVDictParser(ann_dp, delimiter=";")
         dp = Zipper(images_dp, ann_dp)

-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -160,8 +160,8 @@ class ImageNet(Dataset):
         if config.split == "train":
             dp = TarArchiveLoader(dp)
-            dp = hint_sharding(dp)
             dp = hint_shuffling(dp)
+            dp = hint_sharding(dp)
             dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data)
         else:  # config.split == "val":
             images_dp, devkit_dp = resource_dps
@@ -176,8 +176,8 @@ class ImageNet(Dataset):
             label_dp = LineReader(label_dp, decode=True, return_path=False)
             label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
             label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
-            label_dp = hint_sharding(label_dp)
             label_dp = hint_shuffling(label_dp)
+            label_dp = hint_sharding(label_dp)

             dp = IterKeyZipper(
                 label_dp,
...
@@ -105,8 +105,8 @@ class _MNISTBase(Dataset):
         labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)

         dp = Zipper(images_dp, labels_dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, functools.partial(self._prepare_sample, config=config))
...
@@ -99,8 +99,8 @@ class OxfordIITPet(Dataset):
         split_and_classification_dp = CSVDictParser(
             split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
         )
-        split_and_classification_dp = hint_sharding(split_and_classification_dp)
         split_and_classification_dp = hint_shuffling(split_and_classification_dp)
+        split_and_classification_dp = hint_sharding(split_and_classification_dp)

         segmentations_dp = Filter(segmentations_dp, self._filter_segmentations)
...
@@ -113,6 +113,6 @@ class PCAM(Dataset):
         targets_dp = PCAMH5Reader(targets_dp, key="y")

         dp = Zipper(images_dp, targets_dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -106,8 +106,8 @@ class SBD(Dataset):
         split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
-        split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
+        split_dp = hint_sharding(split_dp)

         dp = split_dp
         for level, data_dp in enumerate((images_dp, anns_dp)):
...
@@ -48,6 +48,6 @@ class SEMEION(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -81,8 +81,8 @@ class StanfordCars(Dataset):
             targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
             targets_dp = StanfordCarsLabelReader(targets_dp)

         dp = Zipper(images_dp, targets_dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
...