Unverified Commit ac561bce authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

hint shuffling in prototype datasets rather than actually applying it (#5111)

* hint shuffling rather than actually shuffling prototype datasets

* cleanup after merge
parent eac3dc7b
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
Mapper, Mapper,
Shuffler,
Filter, Filter,
IterKeyZipper, IterKeyZipper,
) )
...@@ -20,7 +19,7 @@ from torchvision.prototype.datasets.utils import ( ...@@ -20,7 +19,7 @@ from torchvision.prototype.datasets.utils import (
OnlineResource, OnlineResource,
DatasetType, DatasetType,
) )
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
from torchvision.prototype.features import Label, BoundingBox, Feature from torchvision.prototype.features import Label, BoundingBox, Feature
...@@ -121,7 +120,7 @@ class Caltech101(Dataset): ...@@ -121,7 +120,7 @@ class Caltech101(Dataset):
images_dp = Filter(images_dp, self._is_not_background_image) images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = hint_sharding(images_dp) images_dp = hint_sharding(images_dp)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE) images_dp = hint_shuffling(images_dp)
anns_dp = Filter(anns_dp, self._is_ann) anns_dp = Filter(anns_dp, self._is_ann)
...@@ -185,7 +184,7 @@ class Caltech256(Dataset): ...@@ -185,7 +184,7 @@ class Caltech256(Dataset):
dp = resource_dps[0] dp = resource_dps[0]
dp = Filter(dp, self._is_not_rogue_file) dp = Filter(dp, self._is_not_rogue_file)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]: def _generate_categories(self, root: pathlib.Path) -> List[str]:
......
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
Mapper, Mapper,
Shuffler,
Filter, Filter,
Zipper, Zipper,
IterKeyZipper, IterKeyZipper,
...@@ -19,7 +18,13 @@ from torchvision.prototype.datasets.utils import ( ...@@ -19,7 +18,13 @@ from torchvision.prototype.datasets.utils import (
OnlineResource, OnlineResource,
DatasetType, DatasetType,
) )
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor, hint_sharding from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
getitem,
path_accessor,
hint_sharding,
hint_shuffling,
)
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
...@@ -152,7 +157,7 @@ class CelebA(Dataset): ...@@ -152,7 +157,7 @@ class CelebA(Dataset):
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split)) splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = hint_sharding(splits_dp) splits_dp = hint_sharding(splits_dp)
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE) splits_dp = hint_shuffling(splits_dp)
anns_dp = Zipper( anns_dp = Zipper(
*[ *[
......
...@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import ( ...@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
Filter, Filter,
Mapper, Mapper,
Shuffler,
) )
from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import ( from torchvision.prototype.datasets.utils import (
...@@ -23,7 +22,7 @@ from torchvision.prototype.datasets.utils import ( ...@@ -23,7 +22,7 @@ from torchvision.prototype.datasets.utils import (
DatasetType, DatasetType,
) )
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, hint_shuffling,
image_buffer_from_array, image_buffer_from_array,
path_comparator, path_comparator,
hint_sharding, hint_sharding,
...@@ -89,7 +88,7 @@ class _CifarBase(Dataset): ...@@ -89,7 +88,7 @@ class _CifarBase(Dataset):
dp = Mapper(dp, self._unpickle) dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]: def _generate_categories(self, root: pathlib.Path) -> List[str]:
......
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
Mapper, Mapper,
Shuffler,
Filter, Filter,
Demultiplexer, Demultiplexer,
Grouper, Grouper,
...@@ -31,6 +30,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -31,6 +30,7 @@ from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
path_accessor, path_accessor,
hint_sharding, hint_sharding,
hint_shuffling,
) )
from torchvision.prototype.features import BoundingBox, Label, Feature from torchvision.prototype.features import BoundingBox, Label, Feature
from torchvision.prototype.utils._internal import FrozenMapping from torchvision.prototype.utils._internal import FrozenMapping
...@@ -182,7 +182,7 @@ class Coco(Dataset): ...@@ -182,7 +182,7 @@ class Coco(Dataset):
if config.annotations is None: if config.annotations is None:
dp = hint_sharding(images_dp) dp = hint_sharding(images_dp)
dp = Shuffler(dp) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
meta_dp = Filter( meta_dp = Filter(
...@@ -208,7 +208,7 @@ class Coco(Dataset): ...@@ -208,7 +208,7 @@ class Coco(Dataset):
anns_meta_dp = UnBatcher(anns_meta_dp) anns_meta_dp = UnBatcher(anns_meta_dp)
anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE) anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
anns_meta_dp = hint_sharding(anns_meta_dp) anns_meta_dp = hint_sharding(anns_meta_dp)
anns_meta_dp = Shuffler(anns_meta_dp) anns_meta_dp = hint_shuffling(anns_meta_dp)
anns_dp = IterKeyZipper( anns_dp = IterKeyZipper(
anns_meta_dp, anns_meta_dp,
......
...@@ -4,7 +4,7 @@ import re ...@@ -4,7 +4,7 @@ import re
from typing import Any, Callable, Dict, List, Optional, Tuple, cast from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import torch import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter, Shuffler from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter
from torchvision.prototype.datasets.utils import ( from torchvision.prototype.datasets.utils import (
Dataset, Dataset,
DatasetConfig, DatasetConfig,
...@@ -21,6 +21,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -21,6 +21,7 @@ from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
read_mat, read_mat,
hint_sharding, hint_sharding,
hint_shuffling,
) )
from torchvision.prototype.features import Label from torchvision.prototype.features import Label
from torchvision.prototype.utils._internal import FrozenMapping from torchvision.prototype.utils._internal import FrozenMapping
...@@ -141,7 +142,7 @@ class ImageNet(Dataset): ...@@ -141,7 +142,7 @@ class ImageNet(Dataset):
# the train archive is a tar of tars # the train archive is a tar of tars
dp = TarArchiveReader(images_dp) dp = TarArchiveReader(images_dp)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_train_data) dp = Mapper(dp, self._collate_train_data)
elif config.split == "val": elif config.split == "val":
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt")) devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
...@@ -149,7 +150,7 @@ class ImageNet(Dataset): ...@@ -149,7 +150,7 @@ class ImageNet(Dataset):
devkit_dp = Mapper(devkit_dp, int) devkit_dp = Mapper(devkit_dp, int)
devkit_dp = Enumerator(devkit_dp, 1) devkit_dp = Enumerator(devkit_dp, 1)
devkit_dp = hint_sharding(devkit_dp) devkit_dp = hint_sharding(devkit_dp)
devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE) devkit_dp = hint_shuffling(devkit_dp)
dp = IterKeyZipper( dp = IterKeyZipper(
devkit_dp, devkit_dp,
...@@ -161,7 +162,7 @@ class ImageNet(Dataset): ...@@ -161,7 +162,7 @@ class ImageNet(Dataset):
dp = Mapper(dp, self._collate_val_data) dp = Mapper(dp, self._collate_val_data)
else: # config.split == "test" else: # config.split == "test"
dp = hint_sharding(images_dp) dp = hint_sharding(images_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_test_data) dp = Mapper(dp, self._collate_test_data)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
......
...@@ -12,7 +12,6 @@ from torchdata.datapipes.iter import ( ...@@ -12,7 +12,6 @@ from torchdata.datapipes.iter import (
Demultiplexer, Demultiplexer,
Mapper, Mapper,
Zipper, Zipper,
Shuffler,
) )
from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import ( from torchvision.prototype.datasets.utils import (
...@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
fromfile, fromfile,
hint_sharding, hint_sharding,
hint_shuffling,
) )
from torchvision.prototype.features import Image, Label from torchvision.prototype.features import Image, Label
...@@ -135,7 +135,7 @@ class _MNISTBase(Dataset): ...@@ -135,7 +135,7 @@ class _MNISTBase(Dataset):
dp = Zipper(images_dp, labels_dp) dp = Zipper(images_dp, labels_dp)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder)) return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
......
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
Mapper, Mapper,
Shuffler,
Demultiplexer, Demultiplexer,
Filter, Filter,
IterKeyZipper, IterKeyZipper,
...@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_accessor, path_accessor,
path_comparator, path_comparator,
hint_sharding, hint_sharding,
hint_shuffling,
) )
...@@ -141,7 +141,7 @@ class SBD(Dataset): ...@@ -141,7 +141,7 @@ class SBD(Dataset):
split_dp = extra_split_dp split_dp = extra_split_dp
split_dp = LineReader(split_dp, decode=True) split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp) split_dp = hint_sharding(split_dp)
split_dp = Shuffler(split_dp) split_dp = hint_shuffling(split_dp)
dp = split_dp dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)): for level, data_dp in enumerate((images_dp, anns_dp)):
......
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
Mapper, Mapper,
Shuffler,
CSVParser, CSVParser,
) )
from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.decoder import raw
...@@ -17,7 +16,7 @@ from torchvision.prototype.datasets.utils import ( ...@@ -17,7 +16,7 @@ from torchvision.prototype.datasets.utils import (
OnlineResource, OnlineResource,
DatasetType, DatasetType,
) )
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array, hint_sharding from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
class SEMEION(Dataset): class SEMEION(Dataset):
...@@ -65,6 +64,6 @@ class SEMEION(Dataset): ...@@ -65,6 +64,6 @@ class SEMEION(Dataset):
dp = resource_dps[0] dp = resource_dps[0]
dp = CSVParser(dp, delimiter=" ") dp = CSVParser(dp, delimiter=" ")
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return dp return dp
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
Mapper, Mapper,
Shuffler,
Filter, Filter,
Demultiplexer, Demultiplexer,
IterKeyZipper, IterKeyZipper,
...@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
path_comparator, path_comparator,
hint_sharding, hint_sharding,
hint_shuffling,
) )
HERE = pathlib.Path(__file__).parent HERE = pathlib.Path(__file__).parent
...@@ -131,7 +131,7 @@ class VOC(Dataset): ...@@ -131,7 +131,7 @@ class VOC(Dataset):
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True) split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp) split_dp = hint_sharding(split_dp)
split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE) split_dp = hint_shuffling(split_dp)
dp = split_dp dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)): for level, data_dp in enumerate((images_dp, anns_dp)):
......
...@@ -30,7 +30,7 @@ import PIL.Image ...@@ -30,7 +30,7 @@ import PIL.Image
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.utils.data import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper from torchdata.datapipes.utils import StreamWrapper
...@@ -335,3 +335,7 @@ def read_flo(file: BinaryIO) -> torch.Tensor: ...@@ -335,3 +335,7 @@ def read_flo(file: BinaryIO) -> torch.Tensor:
def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]: def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
return ShardingFilter(datapipe) return ShardingFilter(datapipe)
def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment