Commit 24d7de44 authored by Myle Ott's avatar Myle Ott
Browse files

Unify various sharding into ShardedIterator

parent 76b5ecab
......@@ -44,11 +44,7 @@ def main(args):
max_positions=args.max_target_positions or 1024,
descending=True,
)
if args.num_shards > 1:
if args.shard_id < 0 or args.shard_id >= args.num_shards:
raise ValueError('--shard-id must be between 0 and num_shards')
itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)
itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(models)
......
......@@ -7,6 +7,7 @@
import contextlib
import glob
import itertools
import math
import numbers
import numpy as np
......@@ -50,21 +51,31 @@ def fmt_path(path, fmt, *args):
return os.path.join(path, fmt.format(*args))
class ShardedIterator(object):
    """A sharded wrapper around an iterable (padded to length).

    Each shard yields every ``num_shards``-th element of ``iterable``,
    starting at offset ``shard_id``. All shards are padded with
    ``fill_value`` up to the same length,
    ``ceil(len(iterable) / num_shards)``, so that parallel consumers
    (e.g. distributed workers) perform the same number of iterations.

    Args:
        iterable: a sized iterable to shard (must support ``len()``)
        num_shards (int): total number of shards
        shard_id (int): 0-based index of the shard this iterator yields
        fill_value: value yielded as padding by shards that run short

    Raises:
        ValueError: if ``shard_id`` is not in ``[0, num_shards)``
    """

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError('shard_id must be between 0 and num_shards')

        # Longest shard length: ceil(len(iterable) / num_shards).
        self._sharded_len = len(iterable) // num_shards
        if len(iterable) % num_shards > 0:
            self._sharded_len += 1

        # islice picks out this shard's elements; zip_longest against a
        # range of the target length pads short shards with fill_value.
        self.itr = itertools.zip_longest(
            range(self._sharded_len),
            itertools.islice(iterable, shard_id, len(iterable), num_shards),
            fillvalue=fill_value,
        )

    def __len__(self):
        return self._sharded_len

    def __iter__(self):
        return self

    def __next__(self):
        # Drop the pacing index from zip_longest; keep only the element.
        return next(self.itr)[1]
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
......@@ -195,18 +206,6 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
return batches
def mask_batches(batch_sampler, shard_id, num_shards):
    """Return the batches of ``batch_sampler`` belonging to one shard.

    Keeps every ``num_shards``-th batch starting at offset ``shard_id``,
    then pads the result with empty batches so that every shard has the
    same length, ``ceil(len(batch_sampler) / num_shards)``. With a
    single shard the sampler is returned unchanged.
    """
    if num_shards == 1:
        return batch_sampler
    selected = []
    for idx, batch in enumerate(batch_sampler):
        if idx % num_shards == shard_id:
            selected.append(batch)
    target_len = int(math.ceil(len(batch_sampler) / num_shards))
    padding = [[]] * (target_len - len(selected))
    return selected + padding
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
......
......@@ -10,7 +10,7 @@ import itertools
import numpy as np
import torch
from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, mask_batches, batches_by_size
from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, ShardedIterator, batches_by_size
class LanguageDatasets(object):
......@@ -41,7 +41,7 @@ class LanguageDatasets(object):
frozen_batches = tuple(batches) # freeze
def dataloader(b):
b = mask_batches(b, shard_id=shard_id, num_shards=num_shards) # shard dataset
b = ShardedIterator(b, num_shards, shard_id, fill_value=[])
return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
for epoch in itertools.count(1):
......@@ -74,7 +74,7 @@ class LanguageDatasets(object):
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending,
allow_different_src_lens=True)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
batch_sampler = ShardedIterator(batch_sampler, num_shards, shard_id, fill_value=[])
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
......@@ -58,10 +58,7 @@ def main(args):
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
)
if args.num_shards > 1:
if args.shard_id < 0 or args.shard_id >= args.num_shards:
raise ValueError('--shard-id must be between 0 and num_shards')
itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)
itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id)
# Initialize generator
gen_timer = StopwatchMeter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment