Commit 9f7c3ec6 authored by Myle Ott

Add support for sharded generation

parent cc7705d3
@@ -175,6 +175,23 @@ def skip_group_enumerator(it, ngpus, offset=0):
        yield (idx, res)


class sharded_iterator(object):

    def __init__(self, itr, num_shards, shard_id):
        assert shard_id >= 0 and shard_id < num_shards
        self.itr = itr
        self.num_shards = num_shards
        self.shard_id = shard_id

    def __len__(self):
        return len(self.itr)

    def __iter__(self):
        for i, v in enumerate(self.itr):
            if i % self.num_shards == self.shard_id:
                yield v


class LanguagePairDataset(object):

    # padding constants
......
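A quick sanity check of the sharding semantics added above: element i of the wrapped iterator goes to shard i % num_shards, so the shards are disjoint and together cover every element. The sketch below is not part of the commit; it assumes the class is exposed as fairseq.data.sharded_iterator, which the generation-script hunk further down (data.sharded_iterator(...)) suggests, and the demo data is made up.

# Illustration only: verifies the round-robin split produced by sharded_iterator.
from fairseq import data

batches = list(range(10))   # stand-in for batches coming out of a dataloader
shards = [list(data.sharded_iterator(batches, num_shards=3, shard_id=i))
          for i in range(3)]
print(shards)               # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
assert sorted(sum(shards, [])) == batches   # disjoint and together exhaustive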
@@ -15,7 +15,7 @@ import math
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from fairseq import meters, nccl, utils
from fairseq import nccl, utils
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG
@@ -116,6 +116,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
    def _build_lr_scheduler(self):
        if len(self.args.lr) > 1 or self.args.force_anneal > 0:
            lrs = self.args.lr

            def anneal(e):
                if e < self.args.force_anneal:
                    # use fixed LR schedule
@@ -123,6 +124,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                else:
                    next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
                return next_lr / lrs[0]  # correct for scaling from LambdaLR

            lr_scheduler = LambdaLR(self.optimizer, anneal)
            lr_scheduler.best = None
        else:
......
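A side note on the trainer code shown above: LambdaLR multiplies the optimizer's base learning rate (lrs[0]) by whatever the lambda returns, which is why anneal returns next_lr / lrs[0] rather than next_lr itself. Below is a hedged sketch of that schedule in isolation; the hyperparameter values are invented, and the body of the fixed-schedule branch (collapsed in the diff) is filled in with a plausible step-through of the listed rates.

# Illustrative only: made-up hyperparameters, assumed fixed-schedule branch body.
import torch
from torch.optim.lr_scheduler import LambdaLR

lrs = [0.25, 0.1]   # stand-in for args.lr
force_anneal = 3    # stand-in for args.force_anneal
lrshrink = 0.5      # stand-in for args.lrshrink


def anneal(e):
    if e < force_anneal:
        # fixed schedule: step through the listed learning rates (assumed)
        next_lr = lrs[min(e, len(lrs) - 1)]
    else:
        # afterwards decay the last listed rate by lrshrink every epoch
        next_lr = lrs[-1] * lrshrink ** (e + 1 - force_anneal)
    # LambdaLR scales the base lr (lrs[0]) by this factor, so dividing by
    # lrs[0] makes the effective learning rate exactly next_lr
    return next_lr / lrs[0]


optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=lrs[0])
scheduler = LambdaLR(optimizer, anneal)
for epoch in range(6):
    print(epoch, optimizer.param_groups[0]['lr'])   # 0.25, 0.1, 0.1, 0.05, 0.025, 0.0125
    optimizer.step()
    scheduler.step()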
@@ -23,6 +23,10 @@ def main():
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                              help='data subset to generate (train, valid, test)')
    dataset_args.add_argument('--num-shards', default=1, type=int, metavar='N',
                              help='shard generation over N shards')
    dataset_args.add_argument('--shard-id', default=0, type=int, metavar='ID',
                              help='id of the shard to generate (id < num_shards)')
    options.add_generation_args(parser)

    args = parser.parse_args()
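The two new options default to the old single-process behaviour (one shard, id 0). A small standalone argparse sketch of just these flags; the real parser is built through fairseq.options, so this is illustrative only.

# Illustrative parser containing only the new options; not the real fairseq parser.
import argparse

parser = argparse.ArgumentParser('sharded generation sketch')
parser.add_argument('--num-shards', default=1, type=int, metavar='N',
                    help='shard generation over N shards')
parser.add_argument('--shard-id', default=0, type=int, metavar='ID',
                    help='id of the shard to generate (id < num_shards)')

print(parser.parse_args([]))                                        # Namespace(num_shards=1, shard_id=0)
print(parser.parse_args(['--num-shards', '8', '--shard-id', '3']))  # Namespace(num_shards=8, shard_id=3)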
@@ -72,6 +76,10 @@ def main():
    itr = dataset.eval_dataloader(
        args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError('--shard-id must be between 0 and num_shards')
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
    num_sentences = 0
    with utils.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
......
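End to end, the intended pattern is to launch one generation process per shard (same command and --num-shards, a different --shard-id for each) and then merge the per-shard outputs. The sketch below simulates that pattern in plain Python; the worker function is a stand-in for the real generation loop, and only the num_shards/shard_id handling and the validation message come from the commit.

# Hypothetical end-to-end sketch: one worker per shard, results merged afterwards.
from multiprocessing import Pool


def generate_shard(args):
    # stand-in for one generation process invoked with --num-shards / --shard-id
    num_shards, shard_id, batches = args
    if shard_id < 0 or shard_id >= num_shards:
        raise ValueError('--shard-id must be between 0 and num_shards')
    # same element selection rule as data.sharded_iterator
    return [(shard_id, b) for i, b in enumerate(batches) if i % num_shards == shard_id]


if __name__ == '__main__':
    batches = list(range(8))   # stand-in for the evaluation batches
    num_shards = 4
    with Pool(num_shards) as pool:
        results = pool.map(generate_shard,
                           [(num_shards, i, batches) for i in range(num_shards)])
    merged = [item for shard in results for item in shard]
    assert len(merged) == len(batches)   # each batch is generated by exactly one shard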