Commit 2a9b4ec2 authored by Spencer Poff, committed by Facebook GitHub Bot

More thorough support for iterable datasets

Summary: Use PyTorch's IterableDataset for streaming iterators, so that there is a clean interface distinction between datasets that stream data and those that support indexed access.

Reviewed By: myleott

Differential Revision: D18438694

fbshipit-source-id: 482857d8357091ea2a6bf819535b09ba7f1a5b7d
parent b31849aa
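
To make the interface split concrete, here is a minimal sketch (the class names and bodies below are illustrative, not part of this diff): indexed datasets expose `__getitem__`/`__len__`, while streaming datasets expose only `__iter__`.

```python
from fairseq.data import FairseqDataset, FairseqIterableDataset


class IndexedToyDataset(FairseqDataset):
    """Illustrative map-style dataset: supports random (indexed) access."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class StreamingToyDataset(FairseqIterableDataset):
    """Illustrative streaming dataset: can only be read sequentially."""

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        yield from self.source
```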
@@ -5,7 +5,7 @@
 from .dictionary import Dictionary, TruncatedDictionary
-from .fairseq_dataset import FairseqDataset
+from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
 from .base_wrapper_dataset import BaseWrapperDataset
@@ -65,6 +65,7 @@ __all__ = [
     'Dictionary',
     'EpochBatchIterator',
     'FairseqDataset',
+    'FairseqIterableDataset',
     'GroupedIterator',
     'IdDataset',
     'IndexedCachedDataset',
@@ -7,7 +7,15 @@ import numpy as np
 import torch.utils.data
 
 
-class FairseqDataset(torch.utils.data.Dataset):
+class EpochListening:
+    """Mixin for receiving updates whenever the epoch increments."""
+
+    def set_epoch(self, epoch):
+        """Will receive the updated epoch number at the beginning of the epoch.
+        """
+        pass
+
+
+class FairseqDataset(torch.utils.data.Dataset, EpochListening):
     """A dataset that provides helpers for batching."""
 
     def __getitem__(self, index):
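
Any dataset can now opt into per-epoch updates by overriding the `set_epoch` hook. A hedged sketch (the reshuffling dataset below is hypothetical, not part of this diff):

```python
import numpy as np

from fairseq.data import FairseqDataset


class EpochShuffledDataset(FairseqDataset):
    """Hypothetical dataset that re-permutes its order on every epoch."""

    def __init__(self, data, seed=1):
        self.data = data
        self.seed = seed
        self.order = np.arange(len(data))

    def set_epoch(self, epoch):
        # called by the epoch iterator at the start of each epoch
        rng = np.random.RandomState(self.seed + epoch)
        self.order = rng.permutation(len(self.data))

    def __getitem__(self, index):
        return self.data[self.order[index]]

    def __len__(self):
        return len(self.data)
```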
@@ -54,5 +62,11 @@ class FairseqDataset(torch.utils.data.Dataset):
         """Prefetch the data required for this epoch."""
         raise NotImplementedError
 
-    def set_epoch(self, epoch):
-        pass
+
+class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
+    """For datasets that need to be read sequentially, usually because the data
+    is being streamed or otherwise can't be manipulated on a single machine.
+    """
+
+    def __iter__(self):
+        raise NotImplementedError
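
A concrete subclass only needs `__iter__`; `set_epoch` comes along from `EpochListening`. A sketch, assuming one record per line in a text file (the class and path handling are hypothetical):

```python
from fairseq.data import FairseqIterableDataset


class LineStreamDataset(FairseqIterableDataset):
    """Hypothetical streaming dataset that reads records sequentially from disk."""

    def __init__(self, path):
        self.path = path
        self.epoch = 0

    def set_epoch(self, epoch):
        # inherited EpochListening hook; stored here, e.g. to vary the stream per epoch
        self.epoch = epoch

    def __iter__(self):
        with open(self.path) as f:
            for line in f:
                yield line.rstrip("\n")
```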
@@ -63,6 +63,15 @@ class EpochBatchIterating(object):
         raise NotImplementedError
 
     def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
+        """Return a new iterator over the dataset.
+
+        Args:
+            shuffle (bool, optional): shuffle batches before returning the
+                iterator (default: True).
+            fix_batches_to_gpus (bool, optional): ensure that batches are always
+                allocated to the same shards across epochs. Requires
+                that :attr:`dataset` supports prefetching (default: False).
+        """
         raise NotImplementedError
 
     def end_of_epoch(self) -> bool:
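
Typical consumption of `next_epoch_itr`, sketched under the assumption that `epoch_batch_itr` is any `EpochBatchIterating` implementation and `train_step` is supplied by the caller:

```python
def train_loop(epoch_batch_itr, train_step, num_epochs=10):
    # `epoch_batch_itr` is assumed to implement EpochBatchIterating
    for _ in range(num_epochs):
        itr = epoch_batch_itr.next_epoch_itr(shuffle=True)
        for batch in itr:
            train_step(batch)
```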
@@ -71,12 +80,15 @@ class EpochBatchIterating(object):
     @property
     def iterations_in_epoch(self) -> int:
+        """The number of consumed batches in the current epoch."""
         raise NotImplementedError
 
     def state_dict(self):
+        """Returns a dictionary containing a whole state of the iterator."""
         raise NotImplementedError
 
     def load_state_dict(self, state_dict):
+        """Copies the state of the iterator from the given *state_dict*."""
         raise NotImplementedError
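
These hooks let iterator progress ride along with a checkpoint so training can resume mid-epoch. A sketch (the helper names are hypothetical):

```python
import torch


def save_iterator(epoch_batch_itr, path):
    # persist epoch number and number of batches consumed so far
    torch.save(epoch_batch_itr.state_dict(), path)


def restore_iterator(epoch_batch_itr, path):
    epoch_batch_itr.load_state_dict(torch.load(path))
```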
@@ -84,7 +96,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
     def __init__(
         self, dataset, epoch=0, num_shards=1, shard_id=0,
     ):
-        # assert isinstance(dataset, torch.utils.data.Dataset)
+        assert isinstance(dataset, torch.utils.data.IterableDataset)
         self.dataset = dataset
         self.epoch = epoch
         self._current_epoch_iterator = None
@@ -93,6 +105,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
     def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
         self.epoch += 1
+        self.dataset.set_epoch(self.epoch)
         self._current_epoch_iterator = CountingIterator(
             iterable=ShardedIterator(
                 iterable=self.dataset,
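
With the assert tightened, StreamingEpochBatchIterator only accepts iterable datasets, and it now propagates the epoch via `set_epoch`. A sketch, assuming `stream_dataset` is a `FairseqIterableDataset` instance defined elsewhere:

```python
from fairseq.data.iterators import StreamingEpochBatchIterator

# `stream_dataset` is assumed to subclass FairseqIterableDataset (and therefore
# torch.utils.data.IterableDataset, which the new assert requires)
epoch_itr = StreamingEpochBatchIterator(dataset=stream_dataset, epoch=0)
itr = epoch_itr.next_epoch_itr()  # bumps epoch to 1 and calls stream_dataset.set_epoch(1)
```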
@@ -12,6 +12,10 @@ class ListDataset(BaseWrapperDataset):
         super().__init__(dataset)
         self._sizes = sizes
 
+    def __iter__(self):
+        for x in self.dataset:
+            yield x
+
     def collater(self, samples):
         return samples
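
With `__iter__` forwarding to the wrapped dataset, a ListDataset can also be consumed as a stream. A small usage sketch:

```python
from fairseq.data import ListDataset

ds = ListDataset([{"id": 0}, {"id": 1}], sizes=[1, 1])
for sample in ds:  # iterates the wrapped list via the new __iter__
    print(sample)
```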