Commit 24d7de44 authored by Myle Ott

Unify various sharding into ShardedIterator

parent 76b5ecab
@@ -44,11 +44,7 @@ def main(args):
         max_positions=args.max_target_positions or 1024,
         descending=True,
     )
-    if args.num_shards > 1:
-        if args.shard_id < 0 or args.shard_id >= args.num_shards:
-            raise ValueError('--shard-id must be between 0 and num_shards')
-        itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)
+    itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id)
 
     gen_timer = StopwatchMeter()
     scorer = SequenceScorer(models)
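Note on the call-site simplification above: the shard-id range check now lives inside ShardedIterator (which raises the same ValueError as the removed lines), and the wrapper is a transparent pass-through when num_shards == 1, so the old num_shards > 1 guard is no longer needed. A minimal sketch of both behaviors, assuming fairseq at this commit is importable:

    from fairseq.data import data_utils

    itr = list(range(10))
    # single shard: the wrapper yields every element unchanged
    assert list(data_utils.ShardedIterator(itr, num_shards=1, shard_id=0)) == itr
    # invalid shard_id: validation happens in the constructor
    try:
        data_utils.ShardedIterator(itr, num_shards=4, shard_id=7)
    except ValueError as e:
        print(e)  # shard_id must be between 0 and num_shards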
...
@@ -7,6 +7,7 @@
 import contextlib
 import glob
+import itertools
 import math
 import numbers
 import numpy as np
@@ -50,21 +51,31 @@ def fmt_path(path, fmt, *args):
     return os.path.join(path, fmt.format(*args))
 
 
-class sharded_iterator(object):
+class ShardedIterator(object):
+    """A sharded wrapper around an iterable (padded to length)."""
 
-    def __init__(self, itr, num_shards, shard_id):
-        assert shard_id >= 0 and shard_id < num_shards
-        self.itr = itr
-        self.num_shards = num_shards
-        self.shard_id = shard_id
+    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
+        if shard_id < 0 or shard_id >= num_shards:
+            raise ValueError('shard_id must be between 0 and num_shards')
+
+        self._sharded_len = len(iterable) // num_shards
+        if len(iterable) % num_shards > 0:
+            self._sharded_len += 1
+
+        self.itr = itertools.zip_longest(
+            range(self._sharded_len),
+            itertools.islice(iterable, shard_id, len(iterable), num_shards),
+            fillvalue=fill_value,
+        )
 
     def __len__(self):
-        return len(self.itr)
+        return self._sharded_len
 
     def __iter__(self):
-        for i, v in enumerate(self.itr):
-            if i % self.num_shards == self.shard_id:
-                yield v
+        return self
+
+    def __next__(self):
+        return next(self.itr)[1]
 
 
 def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
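For readers skimming the diff, the mechanics of the new class: shard i takes every num_shards-th element starting at offset i (itertools.islice), and zipping against a range of the ceiling length pads the shorter shards with fill_value, so every shard reports the same length. The same recipe as a standalone sketch in plain itertools (illustrative only, not part of the commit):

    import itertools

    def shard(iterable, num_shards, shard_id, fill_value=None):
        sharded_len = -(-len(iterable) // num_shards)  # ceil division
        selected = itertools.islice(iterable, shard_id, len(iterable), num_shards)
        return [v for _, v in itertools.zip_longest(
            range(sharded_len), selected, fillvalue=fill_value)]

    data = list('abcdefg')                        # 7 items across 3 shards
    assert shard(data, 3, 0) == ['a', 'd', 'g']
    assert shard(data, 3, 1) == ['b', 'e', None]  # padded to the common length
    assert shard(data, 3, 2) == ['c', 'f', None]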
@@ -195,18 +206,6 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
     return batches
 
 
-def mask_batches(batch_sampler, shard_id, num_shards):
-    if num_shards == 1:
-        return batch_sampler
-    res = [
-        batch
-        for i, batch in enumerate(batch_sampler)
-        if i % num_shards == shard_id
-    ]
-    expected_length = int(math.ceil(len(batch_sampler) / num_shards))
-    return res + [[]] * (expected_length - len(res))
-
-
 @contextlib.contextmanager
 def numpy_seed(seed):
     """Context manager which seeds the NumPy PRNG with the specified seed and
...
@@ -10,7 +10,7 @@ import itertools
 import numpy as np
 import torch
 
-from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, mask_batches, batches_by_size
+from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, ShardedIterator, batches_by_size
 
 
 class LanguageDatasets(object):
@@ -41,7 +41,7 @@ class LanguageDatasets(object):
         frozen_batches = tuple(batches)  # freeze
 
         def dataloader(b):
-            b = mask_batches(b, shard_id=shard_id, num_shards=num_shards)  # shard dataset
+            b = ShardedIterator(b, num_shards, shard_id, fill_value=[])
             return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
 
         for epoch in itertools.count(1):
@@ -74,7 +74,7 @@ class LanguageDatasets(object):
             ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
             descending=descending,
             allow_different_src_lens=True)
-        batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
+        batch_sampler = ShardedIterator(batch_sampler, num_shards, shard_id, fill_value=[])
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
             batch_sampler=batch_sampler)
\ No newline at end of file
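Why fill_value=[] at these two call sites: b and batch_sampler are batch samplers, i.e. lists of index lists, so short shards are padded with an empty index list rather than None. Each shard's DataLoader then runs the same number of steps, with trailing empty batches on the shorter shards, which keeps distributed workers in lock-step. A toy illustration (hypothetical sampler contents, assuming fairseq at this commit is importable):

    from fairseq.data import data_utils

    batch_sampler = [[0, 1], [2, 3], [4]]  # 3 batches of dataset indices
    sharded = data_utils.ShardedIterator(batch_sampler, num_shards=2, shard_id=1, fill_value=[])
    assert list(sharded) == [[2, 3], []]   # padded with an empty batch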
@@ -58,10 +58,7 @@ def main(args):
         max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
     )
-    if args.num_shards > 1:
-        if args.shard_id < 0 or args.shard_id >= args.num_shards:
-            raise ValueError('--shard-id must be between 0 and num_shards')
-        itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)
+    itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id)
 
     # Initialize generator
     gen_timer = StopwatchMeter()
...