Commit 2f781c5a authored by Myle Ott's avatar Myle Ott
Browse files

Support different max_source_positions and max_target_positions

parent 5fe8ea46
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
import contextlib import contextlib
import itertools import itertools
import numbers
import numpy as np import numpy as np
import os import os
import torch import torch
...@@ -93,7 +94,7 @@ class LanguageDatasets(object): ...@@ -93,7 +94,7 @@ class LanguageDatasets(object):
def dataloader(self, split, batch_size=1, num_workers=0, def dataloader(self, split, batch_size=1, num_workers=0,
max_tokens=None, seed=None, epoch=1, max_tokens=None, seed=None, epoch=1,
sample_without_replacement=0, max_positions=1024, sample_without_replacement=0, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False, skip_invalid_size_inputs_valid_test=False,
sort_by_source_size=False): sort_by_source_size=False):
dataset = self.splits[split] dataset = self.splits[split]
...@@ -205,8 +206,20 @@ class LanguagePairDataset(object): ...@@ -205,8 +206,20 @@ class LanguagePairDataset(object):
return res return res
def _valid_size(src_size, dst_size, max_positions):
if isinstance(max_positions, numbers.Number):
max_src_positions, max_dst_positions = max_positions, max_positions
else:
max_src_positions, max_dst_positions = max_positions
if src_size < 2 or src_size > max_src_positions:
return False
if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
return False
return True
def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
max_positions=1024, ignore_invalid_inputs=False): max_positions=(1024, 1024), ignore_invalid_inputs=False):
"""Returns batches of indices sorted by size. Sequences of different lengths """Returns batches of indices sorted by size. Sequences of different lengths
are not allowed in the same batch.""" are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) assert isinstance(src, IndexedDataset)
...@@ -234,15 +247,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, ...@@ -234,15 +247,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
cur_max_size = 0 cur_max_size = 0
ignored = [] ignored = []
for idx in indices: for idx in indices:
if src.sizes[idx] < 2 or \ if not _valid_size(src.sizes[idx],
(False if dst is None else dst.sizes[idx] < 2) or \ None if dst is None else dst.sizes[idx],
sizes[idx] > max_positions: max_positions):
if ignore_invalid_inputs: if ignore_invalid_inputs:
ignored.append(idx) ignored.append(idx)
continue continue
raise Exception("Unable to handle input id {} of " raise Exception("Unable to handle input id {} of size {} / {}.".format(
"size {} / {}.".format(idx, src.sizes[idx], idx, src.sizes[idx], "none" if dst is None else dst.sizes[idx]))
"none" if dst is None else dst.sizes[idx]))
if yield_batch(idx, cur_max_size * (len(batch) + 1)): if yield_batch(idx, cur_max_size * (len(batch) + 1)):
yield batch yield batch
...@@ -253,14 +265,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, ...@@ -253,14 +265,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
if len(ignored) > 0: if len(ignored) > 0:
print("Warning! {} samples are either too short or too long " print("Warning! {} samples are either too short or too long "
"and will be ignored, sample ids={}".format(len(ignored), ignored)) "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
if len(batch) > 0: if len(batch) > 0:
yield batch yield batch
def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
max_positions=1024, sort_by_source_size=False): max_positions=(1024, 1024), sort_by_source_size=False):
"""Returns batches of indices, bucketed by size and then shuffled. Batches """Returns batches of indices, bucketed by size and then shuffled. Batches
may contain sequences of different lengths.""" may contain sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset) assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
...@@ -278,9 +290,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, ...@@ -278,9 +290,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
sample_len = 0 sample_len = 0
ignored = [] ignored = []
for idx in indices: for idx in indices:
if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \ if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
src.sizes[idx] > max_positions or \
dst.sizes[idx] > max_positions:
ignored.append(idx) ignored.append(idx)
continue continue
sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx]) sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
...@@ -296,7 +306,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, ...@@ -296,7 +306,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
if len(ignored) > 0: if len(ignored) > 0:
print("Warning! {} samples are either too short or too long " print("Warning! {} samples are either too short or too long "
"and will be ignored, sample ids={}".format(len(ignored), ignored)) "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
batches = list(make_batches()) batches = list(make_batches())
if not sort_by_source_size: if not sort_by_source_size:
......
...@@ -381,7 +381,7 @@ def build_model(args, src_dict, dst_dict): ...@@ -381,7 +381,7 @@ def build_model(args, src_dict, dst_dict):
embed_dim=args.encoder_embed_dim, embed_dim=args.encoder_embed_dim,
convolutions=eval(args.encoder_layers), convolutions=eval(args.encoder_layers),
dropout=args.dropout, dropout=args.dropout,
max_positions=args.max_positions, max_positions=args.max_source_positions,
) )
decoder = FConvDecoder( decoder = FConvDecoder(
dst_dict, dst_dict,
...@@ -390,6 +390,6 @@ def build_model(args, src_dict, dst_dict): ...@@ -390,6 +390,6 @@ def build_model(args, src_dict, dst_dict):
out_embed_dim=args.decoder_out_embed_dim, out_embed_dim=args.decoder_out_embed_dim,
attention=eval(args.decoder_attention), attention=eval(args.decoder_attention),
dropout=args.dropout, dropout=args.dropout,
max_positions=args.max_positions, max_positions=args.max_target_positions,
) )
return FConvModel(encoder, decoder) return FConvModel(encoder, decoder)
...@@ -33,8 +33,10 @@ def add_dataset_args(parser): ...@@ -33,8 +33,10 @@ def add_dataset_args(parser):
help='target language') help='target language')
group.add_argument('-j', '--workers', default=1, type=int, metavar='N', group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
help='number of data loading workers (default: 1)') help='number of data loading workers (default: 1)')
group.add_argument('--max-positions', default=1024, type=int, metavar='N', group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the sequence') help='max number of tokens in the source sequence')
group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='Ignore too long or too short lines in valid and test set') help='Ignore too long or too short lines in valid and test set')
return group return group
......
...@@ -40,7 +40,7 @@ class SequenceGenerator(object): ...@@ -40,7 +40,7 @@ class SequenceGenerator(object):
self.vocab_size = len(models[0].dst_dict) self.vocab_size = len(models[0].dst_dict)
self.beam_size = beam_size self.beam_size = beam_size
self.minlen = minlen self.minlen = minlen
self.maxlen = min(maxlen, *[m.decoder.max_positions() for m in self.models]) self.maxlen = min(maxlen, *[m.max_decoder_positions() for m in self.models])
self.stop_early = stop_early self.stop_early = stop_early
self.normalize_scores = normalize_scores self.normalize_scores = normalize_scores
self.len_penalty = len_penalty self.len_penalty = len_penalty
......
...@@ -66,6 +66,11 @@ def main(): ...@@ -66,6 +66,11 @@ def main():
criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict) criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
# The max number of positions can be different for train and valid
# e.g., RNNs may support more positions at test time than seen in training
max_positions_train = (args.max_source_positions, args.max_target_positions)
max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
# Start multiprocessing # Start multiprocessing
trainer = MultiprocessingTrainer(args, model, criterion) trainer = MultiprocessingTrainer(args, model, criterion)
...@@ -89,11 +94,11 @@ def main(): ...@@ -89,11 +94,11 @@ def main():
train_meter.start() train_meter.start()
while lr > args.min_lr and epoch <= max_epoch: while lr > args.min_lr and epoch <= max_epoch:
# train for one epoch # train for one epoch
train(args, epoch, batch_offset, trainer, dataset, num_gpus) train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)
# evaluate on validate set # evaluate on validate set
for k, subset in enumerate(args.valid_subset.split(',')): for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus) val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
if k == 0: if k == 0:
if not args.no_save: if not args.no_save:
# save checkpoint # save checkpoint
...@@ -117,15 +122,15 @@ def get_perplexity(loss): ...@@ -117,15 +122,15 @@ def get_perplexity(loss):
return float('inf') return float('inf')
def train(args, epoch, batch_offset, trainer, dataset, num_gpus): def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
"""Train the model for one epoch.""" """Train the model for one epoch."""
torch.manual_seed(args.seed + epoch) torch.manual_seed(args.seed + epoch)
trainer.set_seed(args.seed + epoch) trainer.set_seed(args.seed + epoch)
itr = dataset.dataloader(args.train_subset, num_workers=args.workers, itr = dataset.dataloader(
max_tokens=args.max_tokens, seed=args.seed, epoch=epoch, args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
max_positions=args.max_positions, seed=args.seed, epoch=epoch, max_positions=max_positions,
sample_without_replacement=args.sample_without_replacement, sample_without_replacement=args.sample_without_replacement,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test, skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
sort_by_source_size=(epoch <= args.curriculum)) sort_by_source_size=(epoch <= args.curriculum))
...@@ -207,12 +212,11 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss): ...@@ -207,12 +212,11 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
trainer.save_checkpoint(last_filename, extra_state) trainer.save_checkpoint(last_filename, extra_state)
def validate(args, epoch, trainer, dataset, subset, ngpus): def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
"""Evaluate the model on the validation set and return the average loss.""" """Evaluate the model on the validation set and return the average loss."""
itr = dataset.dataloader(subset, batch_size=None, itr = dataset.dataloader(
max_tokens=args.max_tokens, subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
max_positions=args.max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test) skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
loss_meter = AverageMeter() loss_meter = AverageMeter()
extra_meters = collections.defaultdict(lambda: AverageMeter()) extra_meters = collections.defaultdict(lambda: AverageMeter())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment