Commit 820f796f authored by Myle Ott

Add `--curriculum` option

parent 3af8ec82
@@ -94,7 +94,8 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
                    sample_without_replacement=0, max_positions=1024,
-                   skip_invalid_size_inputs_valid_test=False):
+                   skip_invalid_size_inputs_valid_test=False,
+                   sort_by_source_size=False):
         dataset = self.splits[split]
         if split.startswith('train'):
             with numpy_seed(seed):
@@ -102,7 +103,8 @@ class LanguageDatasets(object):
                     dataset.src, dataset.dst,
                     max_tokens=max_tokens, epoch=epoch,
                     sample=sample_without_replacement,
-                    max_positions=max_positions)
+                    max_positions=max_positions,
+                    sort_by_source_size=sort_by_source_size)
         elif split.startswith('valid'):
             batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
                                                  max_positions=max_positions,
@@ -269,7 +271,8 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
             yield batch


-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
+def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
+                             max_positions=1024, sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -310,7 +313,8 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
               "and will be ignored, sample ids={}".format(len(ignored), ignored))

     batches = list(make_batches())
-    np.random.shuffle(batches)
+    if not sort_by_source_size:
+        np.random.shuffle(batches)

     if sample:
         offset = (epoch - 1) * sample
@@ -327,9 +331,6 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
             "batch length is not correct {}".format(len(result))

         batches = result
-    else:
-        for _ in range(epoch - 1):
-            np.random.shuffle(batches)

     return batches
......
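For readers skimming the data.py hunks: the only behavioural change is that the batch-order shuffle is skipped when `sort_by_source_size` is set, so batches are left in the length-bucketed order they were built in. Below is a minimal, self-contained sketch of that idea (toy greedy bucketing under an illustrative `max_tokens` budget, not the fairseq implementation):

```python
import numpy as np

def toy_shuffled_batches_by_size(src_lengths, max_tokens=8, sort_by_source_size=False):
    """Toy stand-in for the changed logic; only the final
    `if not sort_by_source_size` check mirrors the diff above."""
    order = np.argsort(src_lengths)            # shortest source sentences first
    batches, cur, cur_tokens = [], [], 0
    for idx in order:
        length = src_lengths[int(idx)]
        if cur and cur_tokens + length > max_tokens:
            batches.append(cur)                # close the current batch
            cur, cur_tokens = [], 0
        cur.append(int(idx))
        cur_tokens += length
    if cur:
        batches.append(cur)
    if not sort_by_source_size:                # curriculum off: shuffle batch order
        np.random.shuffle(batches)
    return batches

lengths = [3, 9, 2, 7, 4, 5]
print(toy_shuffled_batches_by_size(lengths, sort_by_source_size=True))   # shortest-first
print(toy_shuffled_batches_by_size(lengths, sort_by_source_size=False))  # random order
```

With `sort_by_source_size=True` the shortest-source batches always come first, which is the curriculum behaviour that train.py enables below for the first `--curriculum` epochs.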
@@ -67,6 +67,8 @@ def add_optimization_args(parser):
                        help='If bigger than 0, use that number of mini-batches for each epoch,'
                             ' where each sample is drawn randomly without replacement from the'
                             ' dataset')
+    group.add_argument('--curriculum', default=0, type=int, metavar='N',
+                       help='sort batches by source length for first N epochs')
     return group
......
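The option itself is plain argparse; a stand-alone sketch (the bare parser and the value `2` are illustrative, only the `--curriculum` argument mirrors the hunk above):

```python
import argparse

parser = argparse.ArgumentParser()
# Same argument as added to add_optimization_args() in the diff.
parser.add_argument('--curriculum', default=0, type=int, metavar='N',
                    help='sort batches by source length for first N epochs')

args = parser.parse_args(['--curriculum', '2'])
print(args.curriculum)  # -> 2; epochs 1 and 2 will use length-sorted batches
```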
@@ -120,11 +120,15 @@ def get_perplexity(loss):

 def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     """Train the model for one epoch."""
+    torch.manual_seed(args.seed + epoch)
+    trainer.set_seed(args.seed + epoch)
+
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                              max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                              max_positions=args.max_positions,
                              sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+                             sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()  # sentences per batch
     wpb_meter = AverageMeter()  # words per batch
@@ -133,7 +137,6 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     extra_meters = collections.defaultdict(lambda: AverageMeter())

     desc = '| epoch {:03d}'.format(epoch)
-    trainer.set_seed(args.seed + epoch)
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
......
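The gating expression `sort_by_source_size=(epoch <= args.curriculum)` applies the curriculum only to the first N epochs. A quick check of how it behaves with an illustrative value of 2:

```python
curriculum = 2  # as if the run were started with --curriculum 2
for epoch in range(1, 5):
    sort_by_source_size = (epoch <= curriculum)
    print(epoch, sort_by_source_size)
# 1 True   (batches kept in source-length order)
# 2 True
# 3 False  (normal shuffled batches from here on)
# 4 False
```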