Unverified commit ac561bce authored by Philip Meier, committed by GitHub

hint shuffling in prototype datasets rather than actually applying it (#5111)

* hint shuffling rather than actually shuffling prototype datasets

* cleanup after merge
parent eac3dc7b
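
The diff below replaces every eagerly applied `Shuffler` datapipe with a `hint_shuffling` helper, so shuffling is only marked in the datapipe graph instead of being performed unconditionally. A minimal sketch of the recurring before/after pattern (the variable name `dp` is illustrative):

```python
# Before: a Shuffler is inserted and always shuffles while the pipeline is built.
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)

# After: shuffling is only hinted; the inserted Shuffler stays disabled
# (default=False) until the consumer of the datapipe opts in.
dp = hint_shuffling(dp)
```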
@@ -8,7 +8,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    Shuffler,
     Filter,
     IterKeyZipper,
 )
@@ -20,7 +19,7 @@ from torchvision.prototype.datasets.utils import (
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
 from torchvision.prototype.features import Label, BoundingBox, Feature
@@ -121,7 +120,7 @@ class Caltech101(Dataset):
         images_dp = Filter(images_dp, self._is_not_background_image)
         images_dp = hint_sharding(images_dp)
-        images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
+        images_dp = hint_shuffling(images_dp)
         anns_dp = Filter(anns_dp, self._is_ann)
@@ -185,7 +184,7 @@ class Caltech256(Dataset):
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
         dp = hint_sharding(dp)
-        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+        dp = hint_shuffling(dp)
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
@@ -6,7 +6,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    Shuffler,
     Filter,
     Zipper,
     IterKeyZipper,
@@ -19,7 +18,13 @@ from torchvision.prototype.datasets.utils import (
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor, hint_sharding
+from torchvision.prototype.datasets.utils._internal import (
+    INFINITE_BUFFER_SIZE,
+    getitem,
+    path_accessor,
+    hint_sharding,
+    hint_shuffling,
+)
 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
@@ -152,7 +157,7 @@ class CelebA(Dataset):
         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
         splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
         splits_dp = hint_sharding(splits_dp)
-        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
+        splits_dp = hint_shuffling(splits_dp)
         anns_dp = Zipper(
             *[
@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
     IterDataPipe,
     Filter,
     Mapper,
-    Shuffler,
 )
 from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
@@ -23,7 +22,7 @@ from torchvision.prototype.datasets.utils import (
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
-    INFINITE_BUFFER_SIZE,
+    hint_shuffling,
     image_buffer_from_array,
     path_comparator,
     hint_sharding,
@@ -89,7 +88,7 @@ class _CifarBase(Dataset):
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = hint_sharding(dp)
-        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+        dp = hint_shuffling(dp)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
@@ -8,7 +8,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    Shuffler,
     Filter,
     Demultiplexer,
     Grouper,
@@ -31,6 +30,7 @@ from torchvision.prototype.datasets.utils._internal import (
     getitem,
     path_accessor,
     hint_sharding,
+    hint_shuffling,
 )
 from torchvision.prototype.features import BoundingBox, Label, Feature
 from torchvision.prototype.utils._internal import FrozenMapping
@@ -182,7 +182,7 @@ class Coco(Dataset):
         if config.annotations is None:
             dp = hint_sharding(images_dp)
-            dp = Shuffler(dp)
+            dp = hint_shuffling(dp)
             return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
         meta_dp = Filter(
@@ -208,7 +208,7 @@
         anns_meta_dp = UnBatcher(anns_meta_dp)
         anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
         anns_meta_dp = hint_sharding(anns_meta_dp)
-        anns_meta_dp = Shuffler(anns_meta_dp)
+        anns_meta_dp = hint_shuffling(anns_meta_dp)
         anns_dp = IterKeyZipper(
             anns_meta_dp,
@@ -4,7 +4,7 @@ import re
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 import torch
-from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
+from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
@@ -21,6 +21,7 @@ from torchvision.prototype.datasets.utils._internal import (
     getitem,
     read_mat,
     hint_sharding,
+    hint_shuffling,
 )
 from torchvision.prototype.features import Label
 from torchvision.prototype.utils._internal import FrozenMapping
@@ -141,7 +142,7 @@ class ImageNet(Dataset):
             # the train archive is a tar of tars
             dp = TarArchiveReader(images_dp)
             dp = hint_sharding(dp)
-            dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+            dp = hint_shuffling(dp)
             dp = Mapper(dp, self._collate_train_data)
         elif config.split == "val":
             devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
@@ -149,7 +150,7 @@
             devkit_dp = Mapper(devkit_dp, int)
             devkit_dp = Enumerator(devkit_dp, 1)
             devkit_dp = hint_sharding(devkit_dp)
-            devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE)
+            devkit_dp = hint_shuffling(devkit_dp)
             dp = IterKeyZipper(
                 devkit_dp,
@@ -161,7 +162,7 @@
             dp = Mapper(dp, self._collate_val_data)
         else:  # config.split == "test"
             dp = hint_sharding(images_dp)
-            dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+            dp = hint_shuffling(dp)
             dp = Mapper(dp, self._collate_test_data)
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
@@ -12,7 +12,6 @@ from torchdata.datapipes.iter import (
     Demultiplexer,
     Mapper,
     Zipper,
-    Shuffler,
 )
 from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     fromfile,
     hint_sharding,
+    hint_shuffling,
 )
 from torchvision.prototype.features import Image, Label
@@ -135,7 +135,7 @@ class _MNISTBase(Dataset):
         dp = Zipper(images_dp, labels_dp)
         dp = hint_sharding(dp)
-        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+        dp = hint_shuffling(dp)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
@@ -8,7 +8,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    Shuffler,
     Demultiplexer,
     Filter,
     IterKeyZipper,
@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     path_comparator,
     hint_sharding,
+    hint_shuffling,
 )
@@ -141,7 +141,7 @@ class SBD(Dataset):
             split_dp = extra_split_dp
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_sharding(split_dp)
-        split_dp = Shuffler(split_dp)
+        split_dp = hint_shuffling(split_dp)
         dp = split_dp
         for level, data_dp in enumerate((images_dp, anns_dp)):
@@ -5,7 +5,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    Shuffler,
     CSVParser,
 )
 from torchvision.prototype.datasets.decoder import raw
@@ -17,7 +16,7 @@ from torchvision.prototype.datasets.utils import (
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array, hint_sharding
+from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
 class SEMEION(Dataset):
@@ -65,6 +64,6 @@ class SEMEION(Dataset):
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
         dp = hint_sharding(dp)
-        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+        dp = hint_shuffling(dp)
         dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
         return dp
@@ -8,7 +8,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    Shuffler,
     Filter,
     Demultiplexer,
     IterKeyZipper,
@@ -29,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     path_comparator,
     hint_sharding,
+    hint_shuffling,
 )
 HERE = pathlib.Path(__file__).parent
@@ -131,7 +131,7 @@ class VOC(Dataset):
         split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_sharding(split_dp)
-        split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE)
+        split_dp = hint_shuffling(split_dp)
         dp = split_dp
         for level, data_dp in enumerate((images_dp, anns_dp)):
@@ -30,7 +30,7 @@ import PIL.Image
 import torch
 import torch.distributed as dist
 import torch.utils.data
-from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter
+from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter, Shuffler
 from torchdata.datapipes.utils import StreamWrapper
@@ -335,3 +335,7 @@ def read_flo(file: BinaryIO) -> torch.Tensor:
 def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
     return ShardingFilter(datapipe)
+
+
+def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
+    return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE)
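
With `default=False`, the `Shuffler` created by `hint_shuffling` is inert until something downstream turns it on, mirroring how `hint_sharding` only marks where sharding should happen. A hedged usage sketch; how the hint is eventually enabled (for example by a datapipe-aware `DataLoader` asked to shuffle, or by PyTorch's graph-settings utilities) depends on the torchdata/PyTorch version, so that step is only described in the comments:

```python
from torchdata.datapipes.iter import IterableWrapper

from torchvision.prototype.datasets.utils._internal import hint_shuffling

dp = hint_shuffling(IterableWrapper(range(5)))

# The hinted Shuffler is disabled by default, so plain iteration keeps the
# original order.
print(list(dp))  # [0, 1, 2, 3, 4]

# A consumer that understands datapipe graphs can later flip this hint on
# (e.g. when shuffling is requested at loading time); no shuffling happens
# until then.
```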