"dgl_sparse/vscode:/vscode.git/clone" did not exist on "4cf5f682a65295ec10c3419cecdcec4388a2235f"
Commit 2f781c5a authored by Myle Ott

Support different max_source_positions and max_target_positions

parent 5fe8ea46
@@ -8,6 +8,7 @@
 import contextlib
 import itertools
+import numbers
 import numpy as np
 import os
 import torch
@@ -93,7 +94,7 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=1024,
+                   sample_without_replacement=0, max_positions=(1024, 1024),
                    skip_invalid_size_inputs_valid_test=False,
                    sort_by_source_size=False):
         dataset = self.splits[split]
@@ -205,8 +206,20 @@ class LanguagePairDataset(object):
         return res


+def _valid_size(src_size, dst_size, max_positions):
+    if isinstance(max_positions, numbers.Number):
+        max_src_positions, max_dst_positions = max_positions, max_positions
+    else:
+        max_src_positions, max_dst_positions = max_positions
+    if src_size < 2 or src_size > max_src_positions:
+        return False
+    if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
+        return False
+    return True
+
+
 def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
-                    max_positions=1024, ignore_invalid_inputs=False):
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
     """Returns batches of indices sorted by size. Sequences of different lengths
     are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset)
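Note on the new helper: the snippet below is a standalone sketch that simply re-runs the _valid_size function from this hunk on some made-up sizes, to show that a scalar limit keeps the old behavior while a (max_source, max_target) pair applies each limit to its own side.

import numbers

def _valid_size(src_size, dst_size, max_positions):
    # Copied from the hunk above: accept either a single limit or a pair.
    if isinstance(max_positions, numbers.Number):
        max_src_positions, max_dst_positions = max_positions, max_positions
    else:
        max_src_positions, max_dst_positions = max_positions
    if src_size < 2 or src_size > max_src_positions:
        return False
    if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
        return False
    return True

# Scalar limit: applies to both sides, matching the old max_positions=1024 behavior.
assert _valid_size(10, 10, 1024)
# Pair: each side is checked against its own limit.
assert _valid_size(500, 1500, (1024, 2048))
assert not _valid_size(500, 1500, (1024, 1024))   # target side too long
# dst_size=None (no target available) only checks the source side.
assert _valid_size(500, None, (1024, 1024))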
@@ -234,15 +247,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     cur_max_size = 0
     ignored = []
     for idx in indices:
-        if src.sizes[idx] < 2 or \
-                (False if dst is None else dst.sizes[idx] < 2) or \
-                sizes[idx] > max_positions:
+        if not _valid_size(src.sizes[idx],
+                           None if dst is None else dst.sizes[idx],
+                           max_positions):
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception("Unable to handle input id {} of "
-                            "size {} / {}.".format(idx, src.sizes[idx],
-                                                   "none" if dst is None else dst.sizes[idx]))
+            raise Exception("Unable to handle input id {} of size {} / {}.".format(
+                idx, src.sizes[idx], "none" if dst is None else dst.sizes[idx]))

         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch
@@ -253,14 +265,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
-              "and will be ignored, sample ids={}".format(len(ignored), ignored))
+              "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))

     if len(batch) > 0:
         yield batch


 def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
-                             max_positions=1024, sort_by_source_size=False):
+                             max_positions=(1024, 1024), sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -278,9 +290,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
         sample_len = 0
         ignored = []
         for idx in indices:
-            if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
-                    src.sizes[idx] > max_positions or \
-                    dst.sizes[idx] > max_positions:
+            if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
                 ignored.append(idx)
                 continue
             sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
@@ -296,7 +306,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
         if len(ignored) > 0:
             print("Warning! {} samples are either too short or too long "
-                  "and will be ignored, sample ids={}".format(len(ignored), ignored))
+                  "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))

     batches = list(make_batches())
     if not sort_by_source_size:
......
@@ -381,7 +381,7 @@ def build_model(args, src_dict, dst_dict):
         embed_dim=args.encoder_embed_dim,
         convolutions=eval(args.encoder_layers),
         dropout=args.dropout,
-        max_positions=args.max_positions,
+        max_positions=args.max_source_positions,
     )
     decoder = FConvDecoder(
         dst_dict,
@@ -390,6 +390,6 @@
         out_embed_dim=args.decoder_out_embed_dim,
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
-        max_positions=args.max_positions,
+        max_positions=args.max_target_positions,
     )
     return FConvModel(encoder, decoder)
@@ -33,8 +33,10 @@ def add_dataset_args(parser):
                        help='target language')
     group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers (default: 1)')
-    group.add_argument('--max-positions', default=1024, type=int, metavar='N',
-                       help='max number of tokens in the sequence')
+    group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
+                       help='max number of tokens in the source sequence')
+    group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
+                       help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='Ignore too long or too short lines in valid and test set')
     return group
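For a quick feel of the new flags, here is a minimal standalone argparse sketch (only the two arguments added above are mirrored; the real parser is assembled elsewhere in fairseq), showing an asymmetric configuration parsing into one limit per side:

import argparse

# Mirror just the two options introduced in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                    help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                    help='max number of tokens in the target sequence')

args = parser.parse_args(['--max-source-positions', '1024',
                          '--max-target-positions', '512'])
# train.py packs the two limits into the tuple that dataloader() now expects.
max_positions_train = (args.max_source_positions, args.max_target_positions)
print(max_positions_train)  # (1024, 512)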
......
@@ -40,7 +40,7 @@ class SequenceGenerator(object):
         self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = min(maxlen, *[m.decoder.max_positions() for m in self.models])
+        self.maxlen = min(maxlen, *[m.max_decoder_positions() for m in self.models])
         self.stop_early = stop_early
         self.normalize_scores = normalize_scores
         self.len_penalty = len_penalty
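To make the one-line change above concrete, a toy illustration (the ToyModel class is invented for this example; only the max_decoder_positions() accessor from the diff is assumed): with an ensemble, maxlen is clipped to the smallest decoder limit across the models.

class ToyModel:
    # Hypothetical stand-in exposing only the accessor used above.
    def __init__(self, limit):
        self.limit = limit

    def max_decoder_positions(self):
        return self.limit

models = [ToyModel(1024), ToyModel(512)]
print(min(200, *[m.max_decoder_positions() for m in models]))   # 200: requested maxlen is below every limit
print(min(2048, *[m.max_decoder_positions() for m in models]))  # 512: clipped to the smallest decoder limit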
......
@@ -66,6 +66,11 @@ def main():
     criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

+    # The max number of positions can be different for train and valid
+    # e.g., RNNs may support more positions at test time than seen in training
+    max_positions_train = (args.max_source_positions, args.max_target_positions)
+    max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
+
     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)
@@ -89,11 +94,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, dataset, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)

         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
@@ -117,18 +122,18 @@ def get_perplexity(loss):
         return float('inf')


-def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     """Train the model for one epoch."""
     torch.manual_seed(args.seed + epoch)
     trainer.set_seed(args.seed + epoch)

-    itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
-                             max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
-                             max_positions=args.max_positions,
-                             sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-                             sort_by_source_size=(epoch <= args.curriculum))
+    itr = dataset.dataloader(
+        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
+        seed=args.seed, epoch=epoch, max_positions=max_positions,
+        sample_without_replacement=args.sample_without_replacement,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+        sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
     wpb_meter = AverageMeter()    # words per batch
@@ -207,13 +212,12 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
     trainer.save_checkpoint(last_filename, extra_state)


-def validate(args, epoch, trainer, dataset, subset, ngpus):
+def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
-    itr = dataset.dataloader(subset, batch_size=None,
-                             max_tokens=args.max_tokens,
-                             max_positions=args.max_positions,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    itr = dataset.dataloader(
+        subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
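As a final sketch of the train/valid split introduced in main() (the ToyModel below is invented; the two accessor names come from the diff): training batches are filtered by the command-line limits, while validation batches are filtered by whatever the built model reports it can handle, which may be larger.

class ToyModel:
    # Hypothetical model exposing the accessors queried in main().
    def max_encoder_positions(self):
        return 1024

    def max_decoder_positions(self):
        return 2048

model = ToyModel()
max_source_positions, max_target_positions = 1024, 1024  # CLI defaults

max_positions_train = (max_source_positions, max_target_positions)
max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
print(max_positions_train)  # (1024, 1024) -- used by train() above
print(max_positions_valid)  # (1024, 2048) -- used by validate() above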
......