Commit 5c241c8c authored by Spencer Poff's avatar Spencer Poff Committed by Facebook Github Bot
Browse files

support streaming iterator

Summary:
For tasks that involve streaming data directly from an API, we need a simpler epoch iterator.

Also included in this change is support for initializing a dictionary with an arbitrary list of special symbols.

Reviewed By: myleott

Differential Revision: D16110603

fbshipit-source-id: be6d9f680292dec1512614871f9269c95ac84861
parent bccfddbb
......@@ -19,7 +19,14 @@ from fairseq.data import data_utils
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'):
def __init__(
self,
pad='<pad>',
eos='</s>',
unk='<unk>',
bos='<s>',
extra_special_symbols=None,
):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
......@@ -28,6 +35,9 @@ class Dictionary(object):
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __eq__(self, other):
......@@ -44,6 +54,7 @@ class Dictionary(object):
def index(self, sym):
"""Returns the index of the specified symbol"""
assert isinstance(sym, str)
if sym in self.indices:
return self.indices[sym]
return self.unk_index
......@@ -169,33 +180,41 @@ class Dictionary(object):
...
```
"""
d = cls()
d.add_from_file(f, ignore_utf_errors)
return d
def add_from_file(self, f, ignore_utf_errors=False):
"""
Loads a pre-existing dictionary from a text file and adds its symbols
to this instance.
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd)
self.add_from_file(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
return cls.load(fd)
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f))
return
d = cls()
lines = f.readlines()
indices_start_line = d._load_meta(lines)
indices_start_line = self._load_meta(lines)
for line in lines[indices_start_line:]:
idx = line.rfind(' ')
if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx]
count = int(line[idx + 1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d
self.indices[word] = len(self.symbols)
self.symbols.append(word)
self.count.append(count)
def _save(self, f, kv_iterator):
if isinstance(f, str):
......
......@@ -51,7 +51,69 @@ class CountingIterator(object):
return self
class EpochBatchIterator(object):
class EpochBatchIterating(object):
def __len__(self) -> int:
raise NotImplementedError
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
raise NotImplementedError
def end_of_epoch(self) -> bool:
"""Returns whether the most recent epoch iterator has been exhausted"""
raise NotImplementedError
@property
def iterations_in_epoch(self) -> int:
raise NotImplementedError
def state_dict(self):
raise NotImplementedError
def load_state_dict(self, state_dict):
raise NotImplementedError
class StreamingEpochBatchIterator(EpochBatchIterating):
def __init__(
self, dataset, epoch=0, num_shards=1, shard_id=0,
):
# assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.epoch = epoch
self._current_epoch_iterator = None
self.num_shards = num_shards
self.shard_id = shard_id
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
self.epoch += 1
self._current_epoch_iterator = CountingIterator(
iterable=ShardedIterator(
iterable=self.dataset,
num_shards=self.num_shards,
shard_id=self.shard_id,
),
)
return self._current_epoch_iterator
def end_of_epoch(self) -> bool:
return not self._current_epoch_iterator.has_next()
@property
def iterations_in_epoch(self) -> int:
if self._current_epoch_iterator is not None:
return self._current_epoch_iterator.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
class EpochBatchIterator(EpochBatchIterating):
"""A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.
Compared to :class:`torch.utils.data.DataLoader`, this iterator:
......@@ -121,7 +183,7 @@ class EpochBatchIterator(object):
)
return self._cur_epoch_itr
def end_of_epoch(self):
def end_of_epoch(self) -> bool:
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment