"vscode:/vscode.git/clone" did not exist on "7b18556c4aee85e3905444b4efa93842c727bc1e"
Commit 820f796f authored by Myle Ott

Add `--curriculum` option

parent 3af8ec82
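
This commit adds a simple length-based curriculum. A new `--curriculum N` option (default 0) is registered with the optimization arguments, and a `sort_by_source_size` flag is threaded from `train()` through `LanguageDatasets.dataloader()` into `shuffled_batches_by_size()`: for the first N epochs (`epoch <= args.curriculum`) training batches keep their size-bucketed order instead of being shuffled. `trainer.set_seed(args.seed + epoch)` also moves up so the seed is set before the training iterator is built.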
```diff
@@ -94,7 +94,8 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
                    sample_without_replacement=0, max_positions=1024,
-                   skip_invalid_size_inputs_valid_test=False):
+                   skip_invalid_size_inputs_valid_test=False,
+                   sort_by_source_size=False):
         dataset = self.splits[split]
         if split.startswith('train'):
             with numpy_seed(seed):
@@ -102,7 +103,8 @@ class LanguageDatasets(object):
                     dataset.src, dataset.dst,
                     max_tokens=max_tokens, epoch=epoch,
                     sample=sample_without_replacement,
-                    max_positions=max_positions)
+                    max_positions=max_positions,
+                    sort_by_source_size=sort_by_source_size)
         elif split.startswith('valid'):
             batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
                                                  max_positions=max_positions,
@@ -269,7 +271,8 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
         yield batch


-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
+def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
+                             max_positions=1024, sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -310,6 +313,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
                       "and will be ignored, sample ids={}".format(len(ignored), ignored))

     batches = list(make_batches())
-    np.random.shuffle(batches)
+    if not sort_by_source_size:
+        np.random.shuffle(batches)

     if sample:
@@ -327,9 +331,6 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
             "batch length is not correct {}".format(len(result))
         batches = result
-    else:
-        for _ in range(epoch - 1):
-            np.random.shuffle(batches)

     return batches
```
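
For illustration, here is a minimal, self-contained sketch of the ordering behaviour the new flag controls. `toy_shuffled_batches_by_size` is a hypothetical stand-in, not fairseq's implementation; it assumes the bucketing step yields batches in ascending source-length order, which is what makes keeping that order a curriculum.

```python
import numpy as np

def toy_shuffled_batches_by_size(src_lengths, batch_size, sort_by_source_size=False):
    """Toy stand-in for shuffled_batches_by_size (illustrative only):
    bucket sentence indices by source length, then shuffle the batches
    unless a curriculum epoch asks for length-sorted order."""
    order = np.argsort(src_lengths)  # shortest source sentences first
    batches = [order[i:i + batch_size].tolist()
               for i in range(0, len(order), batch_size)]
    if not sort_by_source_size:      # same gate as the change above
        np.random.shuffle(batches)
    return batches

lengths = [7, 3, 12, 5, 9, 2]
print(toy_shuffled_batches_by_size(lengths, 2, sort_by_source_size=True))   # short -> long
print(toy_shuffled_batches_by_size(lengths, 2, sort_by_source_size=False))  # random order
```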
```diff
@@ -67,6 +67,8 @@ def add_optimization_args(parser):
                        help='If bigger than 0, use that number of mini-batches for each epoch,'
                             ' where each sample is drawn randomly without replacement from the'
                             ' dataset')
+    group.add_argument('--curriculum', default=0, type=int, metavar='N',
+                       help='sort batches by source length for first N epochs')

     return group
```
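
A quick sketch of how the new flag parses, using a throwaway `argparse` parser that mirrors the registration above (in fairseq itself the option is added by `add_optimization_args()`):

```python
import argparse

# Stand-alone parser that mirrors the option added above (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument('--curriculum', default=0, type=int, metavar='N',
                    help='sort batches by source length for first N epochs')

print(parser.parse_args(['--curriculum', '2']).curriculum)  # 2 -> epochs 1 and 2 are length-sorted
print(parser.parse_args([]).curriculum)                     # 0 -> curriculum disabled by default
```

With the default of 0, `epoch <= args.curriculum` is never true, so training behaviour is unchanged unless the flag is set.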
```diff
@@ -120,11 +120,15 @@ def get_perplexity(loss):
 def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     """Train the model for one epoch."""
     torch.manual_seed(args.seed + epoch)
+    trainer.set_seed(args.seed + epoch)
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                              max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                              max_positions=args.max_positions,
                              sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+                             sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()   # sentences per batch
     wpb_meter = AverageMeter()   # words per batch
@@ -133,7 +137,6 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     extra_meters = collections.defaultdict(lambda: AverageMeter())

     desc = '| epoch {:03d}'.format(epoch)
-    trainer.set_seed(args.seed + epoch)
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
```
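
The gate itself is just the comparison `epoch <= args.curriculum`; assuming epochs are counted from 1, `--curriculum 1` sorts only the first epoch. A small illustrative sketch of how the flag flips over a run (the loop below is a stand-in, not fairseq's training loop):

```python
# Illustrative only: how sort_by_source_size evolves with --curriculum 2.
curriculum = 2  # e.g. args.curriculum

for epoch in range(1, 6):                       # epochs counted from 1
    sort_by_source_size = epoch <= curriculum   # same expression as in train()
    order = 'length-sorted' if sort_by_source_size else 'shuffled'
    print('epoch {:03d}: {} batches'.format(epoch, order))
# epochs 1-2 use length-sorted batches, epochs 3-5 are shuffled as before
```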