Commit e4edf27a authored by Myle Ott, committed by Facebook Github Bot

Improve init speed of TokenBlockDataset and EpochBatchIterator

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/704

Differential Revision: D15221549

Pulled By: myleott

fbshipit-source-id: b0021acdc2d7792ce51421f1432e1f2bd8218f7b
parent 8d9063fe
@@ -26,11 +26,11 @@ class CountingIterator(object):
         count (int): number of elements consumed from this iterator
     """

-    def __init__(self, iterable):
+    def __init__(self, iterable, start=0):
         self.iterable = iterable
-        self.count = 0
+        self.count = start
         self.itr = iter(self)
-        self.len = len(iterable)
+        self.len = start + len(iterable)

     def __len__(self):
         return self.len
@@ -50,7 +50,6 @@ class CountingIterator(object):
     def skip(self, num_to_skip):
         """Fast-forward the iterator by skipping *num_to_skip* elements."""
         next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
-        self.len -= num_to_skip
         return self
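To illustrate the new start argument, here is a minimal standalone sketch (not fairseq's class) of a counting iterator whose count and reported length both account for elements consumed before the resume point:

class SimpleCountingIterator(object):
    """Standalone sketch: wraps an iterable and tracks how many elements were consumed."""

    def __init__(self, iterable, start=0):
        self.iterable = iterable
        self.count = start                # counting resumes from *start*
        self.itr = iter(self)
        self.len = start + len(iterable)  # total includes the already-consumed prefix

    def __len__(self):
        return self.len

    def __iter__(self):
        for x in self.iterable:
            self.count += 1
            yield x

    def __next__(self):
        return next(self.itr)

    def has_next(self):
        return self.count < len(self)

# resuming after 2 of 5 items were consumed in a previous run:
itr = SimpleCountingIterator(['c', 'd', 'e'], start=2)
assert len(itr) == 5 and itr.count == 2 and itr.has_next()

Because the total length already accounts for start, skip() no longer needs to shrink self.len, which is why the self.len -= num_to_skip line is removed in the hunk above.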
@@ -149,11 +148,13 @@ class EpochBatchIterator(object):
         itr_pos = state_dict.get('iterations_in_epoch', 0)
         if itr_pos > 0:
             # fast-forward epoch iterator
-            itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
-            if itr_pos < len(itr):
-                self._next_epoch_itr = itr.skip(itr_pos)
+            self._next_epoch_itr = self._get_iterator_for_epoch(
+                self.epoch,
+                shuffle=state_dict.get('shuffle', True),
+                offset=itr_pos,
+            )

-    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False):
+    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):

         def shuffle_batches(batches, seed):
             # set seed based on the seed and epoch number so that we get
@@ -169,25 +170,33 @@ class EpochBatchIterator(object):
                 batches = shuffle_batches(list(batches), self.seed + epoch)
             batches = list(ShardedIterator(
-                batches, self.num_shards, self.shard_id, fill_value=[]))
+                batches, self.num_shards, self.shard_id, fill_value=[]
+            ))
             self.dataset.prefetch([i for s in batches for i in s])
             if shuffle and fix_batches_to_gpus:
                 batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
         else:
             if shuffle:
                 batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
             else:
                 batches = self.frozen_batches
-            batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
+            batches = list(ShardedIterator(
+                batches, self.num_shards, self.shard_id, fill_value=[]
+            ))

-        return CountingIterator(torch.utils.data.DataLoader(
-            self.dataset,
-            collate_fn=self.collate_fn,
-            batch_sampler=batches,
-            num_workers=self.num_workers,
-        ))
+        if offset > 0 and offset >= len(batches):
+            return None
+
+        return CountingIterator(
+            torch.utils.data.DataLoader(
+                self.dataset,
+                collate_fn=self.collate_fn,
+                batch_sampler=batches[offset:],
+                num_workers=self.num_workers,
+            ),
+            start=offset,
+        )


 class GroupedIterator(object):
...
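The offset parameter replaces the old skip()-based fast-forward: the precomputed batch list is sliced at offset before it reaches the DataLoader, so restoring a checkpoint no longer iterates and collates the batches that were already consumed. A toy sketch of that idea in plain Python, with made-up names (make_epoch_iterator, collater) and no torch dependency:

# Toy sketch: resume an epoch at batch `offset` by slicing the batch list
# up front instead of stepping past the consumed batches one by one.
def make_epoch_iterator(batches, dataset, collater, offset=0):
    if offset > 0 and offset >= len(batches):
        return None  # nothing left in this epoch
    # only the remaining batches are ever materialized
    return (collater([dataset[i] for i in batch]) for batch in batches[offset:])

dataset = list(range(10))
batches = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
itr = make_epoch_iterator(batches, dataset, collater=sum, offset=3)
assert list(itr) == [6 + 7, 8 + 9]
assert make_epoch_iterator(batches, dataset, collater=sum, offset=5) is None

The None return covers a checkpoint taken exactly at the end of an epoch, where there are no remaining batches to resume.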
@@ -67,38 +67,50 @@ class TokenBlockDataset(FairseqDataset):
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
             self.slice_indices = np.empty((len(sizes), 2), dtype=int)
-            curr = 0
-            for i, sz in enumerate(sizes):
-                self.slice_indices[i] = (curr, curr + sz)
-                curr += sz
+            if not torch.is_tensor(sizes):
+                sizes = torch.tensor(sizes)
+            cumsum = torch.cumsum(sizes, dim=0)
+            self.slice_indices[0, 1] = sizes[0]
+            self.slice_indices[1:] = cumsum.unfold(0, 2, 1)
         else:
             raise ValueError('Invalid break_mode: ' + break_mode)

-        self.sizes = np.array([e - s for s, e in self.slice_indices])
         self.slice_indices = np.array(self.slice_indices, dtype=int)
+        self.sizes = self.slice_indices[:, 1] - self.slice_indices[:, 0]

         # build index mapping block indices to the underlying dataset indices
-        self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
-        ds_idx, ds_remaining = -1, 0
-        for i, (s, e) in enumerate(self.slice_indices):
-            to_consume = e - s
-            if ds_remaining == 0:
-                ds_idx += 1
-                ds_remaining = sizes[ds_idx]
-            start_ds_idx = ds_idx
-            start_offset = sizes[ds_idx] - ds_remaining
-            while to_consume > ds_remaining:
-                to_consume -= ds_remaining
-                ds_idx += 1
-                ds_remaining = sizes[ds_idx]
-            ds_remaining -= to_consume
-            self.block_to_dataset_index[i] = (
-                start_ds_idx,  # starting index in dataset
-                start_offset,  # starting offset within starting index
-                ds_idx,  # ending index in dataset
-            )
-        assert ds_remaining == 0
-        assert ds_idx == len(self.dataset) - 1
+        if break_mode == 'eos':
+            # much faster version for eos break mode
+            self.block_to_dataset_index = np.stack(
+                [
+                    np.arange(len(sizes)),  # starting index in dataset
+                    np.zeros(len(sizes), dtype=np.long),  # starting offset within starting index
+                    np.arange(len(sizes)),  # ending index in dataset
+                ],
+                1,
+            )
+        else:
+            self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
+            ds_idx, ds_remaining = -1, 0
+            for i, (s, e) in enumerate(self.slice_indices):
+                to_consume = e - s
+                if ds_remaining == 0:
+                    ds_idx += 1
+                    ds_remaining = sizes[ds_idx]
+                start_ds_idx = ds_idx
+                start_offset = sizes[ds_idx] - ds_remaining
+                while to_consume > ds_remaining:
+                    to_consume -= ds_remaining
+                    ds_idx += 1
+                    ds_remaining = sizes[ds_idx]
+                ds_remaining -= to_consume
+                self.block_to_dataset_index[i] = (
+                    start_ds_idx,  # starting index in dataset
+                    start_offset,  # starting offset within starting index
+                    ds_idx,  # ending index in dataset
+                )
+            assert ds_remaining == 0
+            assert ds_idx == len(self.dataset) - 1

     def __getitem__(self, index):
         start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
...
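A worked example of the vectorized eos-mode slice computation above, using a made-up sizes list (the full first row is set here for clarity); torch.cumsum plus unfold(0, 2, 1) produces the same (start, end) pairs the removed Python loop built one sentence at a time:

import numpy as np
import torch

sizes = torch.tensor([5, 1, 4])             # sentence lengths, made up for illustration
cumsum = torch.cumsum(sizes, dim=0)         # tensor([5, 6, 10])

slice_indices = np.empty((len(sizes), 2), dtype=int)
slice_indices[0] = (0, int(sizes[0]))       # first block starts at token 0
slice_indices[1:] = cumsum.unfold(0, 2, 1)  # consecutive (start, end) pairs

assert slice_indices.tolist() == [[0, 5], [5, 6], [6, 10]]

# the removed loop computed the same intervals incrementally:
curr, expected = 0, []
for sz in sizes.tolist():
    expected.append([curr, curr + sz])
    curr += sz
assert slice_indices.tolist() == expected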
@@ -23,9 +23,9 @@ class TestTokenBlockDataset(unittest.TestCase):
     def test_eos_break_mode(self):
         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([1]),  # this should be filtered
-            torch.LongTensor([8, 7, 6, 1]),
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
         self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
@@ -33,9 +33,9 @@ class TestTokenBlockDataset(unittest.TestCase):
         self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])

         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([8, 7, 6, 1]),
-            torch.LongTensor([1]),  # this should be filtered
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
+            torch.tensor([1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
         self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
@@ -44,9 +44,9 @@ class TestTokenBlockDataset(unittest.TestCase):
     def test_block_break_mode(self):
         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([8, 7, 6, 1]),
-            torch.LongTensor([9, 1]),
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
+            torch.tensor([9, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
         self.assertEqual(ds[0].tolist(), [5, 4, 3])
@@ -56,19 +56,19 @@ class TestTokenBlockDataset(unittest.TestCase):
     def test_complete_break_mode(self):
         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([8, 7, 6, 1]),
-            torch.LongTensor([9, 1]),
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
+            torch.tensor([9, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
         self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
         self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])

         data = [
-            torch.LongTensor([4, 3, 2, 1]),
-            torch.LongTensor([5, 1]),
-            torch.LongTensor([1]),
-            torch.LongTensor([6, 1]),
+            torch.tensor([4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([5, 1], dtype=torch.long),
+            torch.tensor([1], dtype=torch.long),
+            torch.tensor([6, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
         self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
...
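For the eos break mode exercised by these tests, every block is exactly one dataset item, which is why the diff above can build block_to_dataset_index with np.stack instead of the per-block consume loop. A small sketch with a made-up item count:

import numpy as np

n = 3  # number of sentences, made up for illustration
block_to_dataset_index = np.stack(
    [
        np.arange(n),            # block i starts in dataset item i
        np.zeros(n, dtype=int),  # at offset 0 within that item
        np.arange(n),            # and ends in the same item i
    ],
    1,
)
assert block_to_dataset_index.tolist() == [[0, 0, 0], [1, 0, 1], [2, 0, 2]]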
@@ -85,6 +85,18 @@ class TestLoadCheckpoint(unittest.TestCase):
         self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
         self.assertEqual(epoch_itr.iterations_in_epoch, 51)

+        for _ in range(150 - 52):
+            next(itr)
+        self.assertEqual(epoch_itr.iterations_in_epoch, 149)
+        self.assertTrue(itr.has_next())
+        next(itr)
+        self.assertFalse(itr.has_next())
+
+        itr = epoch_itr.next_epoch_itr(shuffle=False)
+        self.assertTrue(itr.has_next())
+        self.assertEqual(epoch_itr.epoch, 3)
+        self.assertEqual(epoch_itr.iterations_in_epoch, 0)
+
     def test_load_full_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
...