Commit 5c241c8c authored by Spencer Poff's avatar Spencer Poff Committed by Facebook Github Bot
Browse files

support streaming iterator

Summary:
For tasks that involve streaming data directly from an API, we need a simpler epoch iterator.

Also included in this change is support for initializing a dictionary with an arbitrary list of special symbols.

Reviewed By: myleott

Differential Revision: D16110603

fbshipit-source-id: be6d9f680292dec1512614871f9269c95ac84861
parent bccfddbb
...@@ -19,7 +19,14 @@ from fairseq.data import data_utils ...@@ -19,7 +19,14 @@ from fairseq.data import data_utils
class Dictionary(object): class Dictionary(object):
"""A mapping from symbols to consecutive integers""" """A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'): def __init__(
self,
pad='<pad>',
eos='</s>',
unk='<unk>',
bos='<s>',
extra_special_symbols=None,
):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = [] self.symbols = []
self.count = [] self.count = []
...@@ -28,6 +35,9 @@ class Dictionary(object): ...@@ -28,6 +35,9 @@ class Dictionary(object):
self.pad_index = self.add_symbol(pad) self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos) self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk) self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols) self.nspecial = len(self.symbols)
def __eq__(self, other): def __eq__(self, other):
...@@ -44,6 +54,7 @@ class Dictionary(object): ...@@ -44,6 +54,7 @@ class Dictionary(object):
def index(self, sym): def index(self, sym):
"""Returns the index of the specified symbol""" """Returns the index of the specified symbol"""
assert isinstance(sym, str)
if sym in self.indices: if sym in self.indices:
return self.indices[sym] return self.indices[sym]
return self.unk_index return self.unk_index
...@@ -169,33 +180,41 @@ class Dictionary(object): ...@@ -169,33 +180,41 @@ class Dictionary(object):
... ...
``` ```
""" """
d = cls()
d.add_from_file(f, ignore_utf_errors)
return d
def add_from_file(self, f, ignore_utf_errors=False):
"""
Loads a pre-existing dictionary from a text file and adds its symbols
to this instance.
"""
if isinstance(f, str): if isinstance(f, str):
try: try:
if not ignore_utf_errors: if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd: with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd) self.add_from_file(fd)
else: else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd: with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
return cls.load(fd) self.add_from_file(fd)
except FileNotFoundError as fnfe: except FileNotFoundError as fnfe:
raise fnfe raise fnfe
except UnicodeError: except UnicodeError:
raise Exception("Incorrect encoding detected in {}, please " raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f)) "rebuild the dataset".format(f))
return
d = cls()
lines = f.readlines() lines = f.readlines()
indices_start_line = d._load_meta(lines) indices_start_line = self._load_meta(lines)
for line in lines[indices_start_line:]: for line in lines[indices_start_line:]:
idx = line.rfind(' ') idx = line.rfind(' ')
if idx == -1: if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'") raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx] word = line[:idx]
count = int(line[idx + 1:]) count = int(line[idx + 1:])
d.indices[word] = len(d.symbols) self.indices[word] = len(self.symbols)
d.symbols.append(word) self.symbols.append(word)
d.count.append(count) self.count.append(count)
return d
def _save(self, f, kv_iterator): def _save(self, f, kv_iterator):
if isinstance(f, str): if isinstance(f, str):
......
...@@ -51,7 +51,69 @@ class CountingIterator(object): ...@@ -51,7 +51,69 @@ class CountingIterator(object):
return self return self
class EpochBatchIterator(object): class EpochBatchIterating(object):
def __len__(self) -> int:
raise NotImplementedError
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
raise NotImplementedError
def end_of_epoch(self) -> bool:
"""Returns whether the most recent epoch iterator has been exhausted"""
raise NotImplementedError
@property
def iterations_in_epoch(self) -> int:
raise NotImplementedError
def state_dict(self):
raise NotImplementedError
def load_state_dict(self, state_dict):
raise NotImplementedError
class StreamingEpochBatchIterator(EpochBatchIterating):
def __init__(
self, dataset, epoch=0, num_shards=1, shard_id=0,
):
# assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.epoch = epoch
self._current_epoch_iterator = None
self.num_shards = num_shards
self.shard_id = shard_id
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
self.epoch += 1
self._current_epoch_iterator = CountingIterator(
iterable=ShardedIterator(
iterable=self.dataset,
num_shards=self.num_shards,
shard_id=self.shard_id,
),
)
return self._current_epoch_iterator
def end_of_epoch(self) -> bool:
return not self._current_epoch_iterator.has_next()
@property
def iterations_in_epoch(self) -> int:
if self._current_epoch_iterator is not None:
return self._current_epoch_iterator.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
class EpochBatchIterator(EpochBatchIterating):
"""A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.
Compared to :class:`torch.utils.data.DataLoader`, this iterator: Compared to :class:`torch.utils.data.DataLoader`, this iterator:
...@@ -121,7 +183,7 @@ class EpochBatchIterator(object): ...@@ -121,7 +183,7 @@ class EpochBatchIterator(object):
) )
return self._cur_epoch_itr return self._cur_epoch_itr
def end_of_epoch(self): def end_of_epoch(self) -> bool:
"""Returns whether the most recent epoch iterator has been exhausted""" """Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next() return not self._cur_epoch_itr.has_next()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment