Unverified Commit 0cba9b78 authored by Philip Meier, committed by GitHub

enforce shuffling before sharding (#5680)

* enforce shuffling before sharding

* revert test changes and add comment
parent 96aecd2d
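
Why the order matters: a `ShardingFilter` splits the sample stream across workers, so a `Shuffler` placed after it can only permute samples within one worker's fixed shard. With the `Shuffler` first, the whole stream is shuffled before it is split, and the subset each worker sees changes from epoch to epoch. A toy illustration (not part of this commit) using torchdata's functional API; exact outputs vary per run:

```python
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(8))

# Sharding before shuffling: worker 0 of 2 is pinned to the fixed subset
# [0, 2, 4, 6]; shuffling afterwards only permutes within that subset.
sharded = dp.sharding_filter()
sharded.apply_sharding(2, 0)  # 2 instances, this is instance 0
print(list(sharded.shuffle()))  # always a permutation of [0, 2, 4, 6]

# Shuffling before sharding: the full stream is reordered first, so the
# samples that reach worker 0 differ between epochs.
shuffled_then_sharded = dp.shuffle().sharding_filter()
shuffled_then_sharded.apply_sharding(2, 0)
print(list(shuffled_then_sharded))  # e.g. [5, 0, 7, 2] -- varies per run
```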
@@ -119,6 +119,9 @@ class TestCommon:
         pickle.dumps(dataset)

+    # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
+    # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
+    # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
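The check the TODO alludes to could look roughly like the sketch below. It is hypothetical (the version from the early commits of #5680 may differ) and assumes `torch.utils.data.graph.traverse`, which returns the datapipe graph as a nested dict mapping each datapipe to the dict of its source datapipes:

```python
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import ShardingFilter, Shuffler


def _contains_shuffler(graph) -> bool:
    # True if a Shuffler occurs anywhere upstream in the (sub)graph.
    return any(
        isinstance(dp, Shuffler) or _contains_shuffler(sources)
        for dp, sources in graph.items()
    )


def shuffling_before_sharding(datapipe) -> bool:
    # Every ShardingFilter must have a Shuffler among its ancestors, i.e.
    # the stream is shuffled before it is split across workers.
    def check(graph) -> bool:
        for dp, sources in graph.items():
            if isinstance(dp, ShardingFilter) and not _contains_shuffler(sources):
                return False
            if not check(sources):
                return False
        return True

    return check(traverse(datapipe))
```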
@@ -112,10 +112,10 @@ that is only needed at runtime.
 Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
 trying to zip already loaded images.

-There are two special datapipes that are not used through their class, but through the functions `hint_sharding` and
-`hint_shuffling`. As the name implies, they only hint at the part in the datapipe graph where sharding and shuffling should take
+There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
+`hint_sharding`. As the name implies, they only hint at the part in the datapipe graph where shuffling and sharding should take
 place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` and are
-required in each dataset.
+required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.

 Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
 names (yet!).
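
For orientation, the two hints could plausibly be implemented along the following lines. This is a sketch under stated assumptions, not the actual code from `torchvision.prototype.datasets.utils._internal`; in particular, the `INFINITE_BUFFER_SIZE` value and the `set_shuffle` mechanism are guesses:

```python
from torchdata.datapipes.iter import IterDataPipe, ShardingFilter, Shuffler

INFINITE_BUFFER_SIZE = -1  # assumed sentinel for "buffer everything"


def hint_shuffling(datapipe: IterDataPipe) -> Shuffler:
    # Insert a Shuffler that is disabled by default, i.e. a no-op until
    # shuffling is enabled for the whole graph (e.g. by the DataLoader).
    return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)


def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
    # Insert a ShardingFilter, which drops nothing until apply_sharding is
    # called with the number of workers and the worker rank.
    return ShardingFilter(datapipe)
```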
@@ -107,8 +107,8 @@ class Caltech101(Dataset):
         images_dp, anns_dp = resource_dps

         images_dp = Filter(images_dp, self._is_not_background_image)
-        images_dp = hint_sharding(images_dp)
         images_dp = hint_shuffling(images_dp)
+        images_dp = hint_sharding(images_dp)

         anns_dp = Filter(anns_dp, self._is_ann)
@@ -167,8 +167,8 @@ class Caltech256(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
@@ -155,8 +155,8 @@ class CelebA(Dataset):
         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
         splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
-        splits_dp = hint_sharding(splits_dp)
         splits_dp = hint_shuffling(splits_dp)
+        splits_dp = hint_sharding(splits_dp)

         anns_dp = Zipper(
             *[
@@ -85,8 +85,8 @@ class _CifarBase(Dataset):
         dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
@@ -77,8 +77,8 @@ class CLEVR(Dataset):
         )
         images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
-        images_dp = hint_sharding(images_dp)
         images_dp = hint_shuffling(images_dp)
+        images_dp = hint_sharding(images_dp)
         if config.split != "test":
             scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
@@ -176,8 +176,8 @@ class Coco(Dataset):
         images_dp, meta_dp = resource_dps
         if config.annotations is None:
-            dp = hint_sharding(images_dp)
-            dp = hint_shuffling(dp)
+            dp = hint_shuffling(images_dp)
+            dp = hint_sharding(dp)
             return Mapper(dp, self._prepare_image)
         meta_dp = Filter(
@@ -206,8 +206,8 @@ class Coco(Dataset):
         anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
         anns_meta_dp = UnBatcher(anns_meta_dp)
         anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
-        anns_meta_dp = hint_sharding(anns_meta_dp)
         anns_meta_dp = hint_shuffling(anns_meta_dp)
+        anns_meta_dp = hint_sharding(anns_meta_dp)
         anns_dp = IterKeyZipper(
             anns_meta_dp,
@@ -46,8 +46,8 @@ class Country211(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split]))
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
@@ -199,8 +199,8 @@ class CUB200(Dataset):
             prepare_ann_fn = self._2010_prepare_ann
-        split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
+        split_dp = hint_sharding(split_dp)
         dp = IterKeyZipper(
             split_dp,
@@ -2,16 +2,7 @@ import enum
 import pathlib
 from typing import Any, Dict, List, Optional, Tuple, BinaryIO

-from torchdata.datapipes.iter import (
-    IterDataPipe,
-    Mapper,
-    Shuffler,
-    Filter,
-    IterKeyZipper,
-    Demultiplexer,
-    LineReader,
-    CSVParser,
-)
+from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
@@ -24,6 +15,7 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     path_comparator,
     getitem,
+    hint_shuffling,
 )
 from torchvision.prototype.features import Label, EncodedImage
@@ -98,7 +90,7 @@ class DTD(Dataset):
         splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
         splits_dp = LineReader(splits_dp, decode=True, return_path=False)
-        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
+        splits_dp = hint_shuffling(splits_dp)
         splits_dp = hint_sharding(splits_dp)

         joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")
@@ -46,6 +46,6 @@ class EuroSAT(Dataset):
         self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -54,6 +54,6 @@ class FER2013(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVDictParser(dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -94,7 +94,7 @@ class GTSRB(Dataset):
         ann_dp = CSVDictParser(ann_dp, delimiter=";")
         dp = Zipper(images_dp, ann_dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -160,8 +160,8 @@ class ImageNet(Dataset):
         if config.split == "train":
             dp = TarArchiveLoader(dp)
-            dp = hint_sharding(dp)
             dp = hint_shuffling(dp)
+            dp = hint_sharding(dp)
             dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data)
         else:  # config.split == "val":
             images_dp, devkit_dp = resource_dps
@@ -176,8 +176,8 @@ class ImageNet(Dataset):
             label_dp = LineReader(label_dp, decode=True, return_path=False)
             label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
             label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
-            label_dp = hint_sharding(label_dp)
             label_dp = hint_shuffling(label_dp)
+            label_dp = hint_sharding(label_dp)
             dp = IterKeyZipper(
                 label_dp,
@@ -105,8 +105,8 @@ class _MNISTBase(Dataset):
         labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)
         dp = Zipper(images_dp, labels_dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, functools.partial(self._prepare_sample, config=config))
@@ -99,8 +99,8 @@ class OxfordIITPet(Dataset):
         split_and_classification_dp = CSVDictParser(
             split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
         )
-        split_and_classification_dp = hint_sharding(split_and_classification_dp)
         split_and_classification_dp = hint_shuffling(split_and_classification_dp)
+        split_and_classification_dp = hint_sharding(split_and_classification_dp)

         segmentations_dp = Filter(segmentations_dp, self._filter_segmentations)
@@ -113,6 +113,6 @@ class PCAM(Dataset):
         targets_dp = PCAMH5Reader(targets_dp, key="y")
         dp = Zipper(images_dp, targets_dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -106,8 +106,8 @@ class SBD(Dataset):
         split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
-        split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
+        split_dp = hint_sharding(split_dp)
         dp = split_dp
         for level, data_dp in enumerate((images_dp, anns_dp)):
@@ -48,6 +48,6 @@ class SEMEION(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
@@ -81,8 +81,8 @@ class StanfordCars(Dataset):
         targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
         targets_dp = StanfordCarsLabelReader(targets_dp)
         dp = Zipper(images_dp, targets_dp)
-        dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
+        dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]: