Commit f442f896 authored by Myle Ott

Add --max-sentences option for batching based on # sentences

parent 2ef422f6
@@ -14,9 +14,8 @@ from .fairseq_criterion import FairseqCriterion

 class CrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, padding_idx):
-        super().__init__()
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict):
+        super().__init__(args, dst_dict)

     def forward(self, model, sample):
         """Compute the loss for the given sample.
@@ -30,7 +29,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
@@ -11,8 +11,10 @@ from torch.nn.modules.loss import _Loss

 class FairseqCriterion(_Loss):

-    def __init__(self):
+    def __init__(self, args, dst_dict):
         super().__init__()
+        self.args = args
+        self.padding_idx = dst_dict.pad()

     def forward(self, model, sample):
         """Compute the loss for the given sample.
@@ -43,10 +43,9 @@ class LabelSmoothedCrossEntropy(torch.autograd.Function):

 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, eps, padding_idx=None, weights=None):
-        super().__init__()
-        self.eps = eps
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict, weights=None):
+        super().__init__(args, dst_dict)
+        self.eps = args.label_smoothing
         self.weights = weights

     def forward(self, model, sample):
@@ -61,7 +60,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
@@ -97,27 +97,26 @@ class LanguageDatasets(object):
         assert self.src_dict.unk() == self.dst_dict.unk()

     def train_dataloader(self, split, num_workers=0, max_tokens=None,
-                         max_positions=(1024, 1024), seed=None, epoch=1,
-                         sample_without_replacement=0,
+                         max_sentences=None, max_positions=(1024, 1024),
+                         seed=None, epoch=1, sample_without_replacement=0,
                          sort_by_source_size=False):
         dataset = self.splits[split]
         with numpy_seed(seed):
             batch_sampler = shuffled_batches_by_size(
-                dataset.src, dataset.dst,
-                max_tokens=max_tokens, epoch=epoch,
-                sample=sample_without_replacement,
-                max_positions=max_positions,
+                dataset.src, dataset.dst, max_tokens=max_tokens,
+                max_sentences=max_sentences, epoch=epoch,
+                sample=sample_without_replacement, max_positions=max_positions,
                 sort_by_source_size=sort_by_source_size)
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
             batch_sampler=batch_sampler)

-    def eval_dataloader(self, split, num_workers=0, batch_size=1,
-                        max_tokens=None, max_positions=(1024, 1024),
+    def eval_dataloader(self, split, num_workers=0, max_tokens=None,
+                        max_sentences=None, max_positions=(1024, 1024),
                         skip_invalid_size_inputs_valid_test=False):
         dataset = self.splits[split]
         batch_sampler = list(batches_by_size(
-            dataset.src, dataset.dst, batch_size, max_tokens,
+            dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
             ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
@@ -220,29 +219,23 @@ def _valid_size(src_size, dst_size, max_positions):
     return True


-def batches_by_size(src, dst, batch_size=None, max_tokens=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
-    """Returns batches of indices sorted by size. Sequences with different
-    source lengths are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
-    if max_tokens is None:
-        max_tokens = float('Inf')
-    indices = np.argsort(src.sizes, kind='mergesort')
-
+def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
+                  ignore_invalid_inputs=False, allow_different_src_lens=False):
     batch = []

     def yield_batch(next_idx, num_tokens):
         if len(batch) == 0:
             return False
-        if len(batch) == batch_size:
+        if len(batch) == max_sentences:
             return True
-        if src.sizes[batch[0]] != src.sizes[next_idx]:
+        if num_tokens > max_tokens:
             return True
-        if num_tokens >= max_tokens:
+        if not allow_different_src_lens and \
+                (src.sizes[batch[0]] != src.sizes[next_idx]):
             return True
         return False

-    cur_max_size = 0
+    sample_len = 0
     ignored = []
     for idx in indices:
         if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
@@ -253,28 +246,48 @@ def batches_by_size(src, dst, batch_size=None, max_tokens=None,
                     "Unable to handle input id {} of size {} / {}.".format(
                         idx, src.sizes[idx], dst.sizes[idx]))

-        if yield_batch(idx, cur_max_size * (len(batch) + 1)):
+        sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
+        num_tokens = (len(batch) + 1) * sample_len
+        if yield_batch(idx, num_tokens):
             yield batch
             batch = []
-            cur_max_size = 0
+            sample_len = max(src.sizes[idx], dst.sizes[idx])

         batch.append(idx)
-        cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])

-    if len(batch) > 0:
-        yield batch
-
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
               "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))

+    if len(batch) > 0:
+        yield batch
+
+
+def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
+    """Returns batches of indices sorted by size. Sequences with different
+    source lengths are not allowed in the same batch."""
+    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
+    if max_tokens is None:
+        max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')
+    indices = np.argsort(src.sizes, kind='mergesort')
+    return _make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs, allow_different_src_lens=False)


-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
-                             max_positions=(1024, 1024), sort_by_source_size=False):
+def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                             epoch=1, sample=0, max_positions=(1024, 1024),
+                             sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')

     indices = np.random.permutation(len(src))
@@ -282,30 +295,10 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
         indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
         indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]

-    def make_batches():
-        batch = []
-        sample_len = 0
-        ignored = []
-        for idx in indices:
-            if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
-                ignored.append(idx)
-                continue
-
-            sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
-            if len(batch) > 0 and (len(batch) + 1) * sample_len > max_tokens:
-                yield batch
-                batch = []
-                sample_len = max(src.sizes[idx], dst.sizes[idx])
-
-            batch.append(idx)
-
-        if len(batch) > 0:
-            yield batch
-
-        if len(ignored) > 0:
-            print("Warning! {} samples are either too short or too long "
-                  "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
-
-    batches = list(make_batches())
+    batches = list(_make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs=True, allow_different_src_lens=True))

     if not sort_by_source_size:
         np.random.shuffle(batches)
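Both entry points now delegate to _make_batches, which closes a batch as soon as adding the next example would exceed either cap. The standalone sketch below (illustrative names, not part of the commit) shows the same rule on a plain list of example lengths, approximating a batch's token count by its padded size, i.e. (number of sentences) * (longest example seen so far):

def toy_batches(sizes, max_tokens=float('Inf'), max_sentences=float('Inf')):
    """Group example indices so that no batch exceeds either cap (hypothetical helper)."""
    batch, sample_len = [], 0
    for idx, size in enumerate(sizes):
        sample_len = max(sample_len, size)          # longest example if idx joins the batch
        num_tokens = (len(batch) + 1) * sample_len  # padded token count with idx included
        if batch and (len(batch) == max_sentences or num_tokens > max_tokens):
            yield batch                             # a cap would be exceeded: emit the batch
            batch, sample_len = [], size
        batch.append(idx)
    if batch:
        yield batch

# Example: five sentences of lengths 3, 4, 4, 9, 10 with both caps active.
# list(toy_batches([3, 4, 4, 9, 10], max_tokens=12, max_sentences=2))
# -> [[0, 1], [2], [3], [4]]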
@@ -71,6 +71,9 @@ def add_optimization_args(parser):
                            ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--sentence-avg', action='store_true',
+                       help='normalize gradients by the number of sentences in a batch'
+                            ' (default is to normalize by number of tokens)')
     return group
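Per the help text, the flag only changes the normalization denominator: the criteria return a summed loss together with a sample_size, and that sample_size is either the token count (default) or the sentence count. A small illustration with assumed numbers (not taken from the commit):

def sample_size(ntokens, nsentences, sentence_avg):
    # mirrors the criterion change: the normalization denominator is either
    # the number of target tokens (default) or the number of sentences
    return nsentences if sentence_avg else ntokens

# hypothetical batch: 64 sentences, 1536 target tokens
print(sample_size(1536, 64, sentence_avg=False))  # 1536 -> per-token normalization
print(sample_size(1536, 64, sentence_avg=True))   # 64   -> per-sentence normalization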
@@ -30,11 +30,10 @@ def build_model(args, src_dict, dst_dict):

 def build_criterion(args, src_dict, dst_dict):
-    padding_idx = dst_dict.pad()
     if args.label_smoothing > 0:
-        return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
+        return criterions.LabelSmoothedCrossEntropyCriterion(args, dst_dict)
     else:
-        return criterions.CrossEntropyCriterion(padding_idx)
+        return criterions.CrossEntropyCriterion(args, dst_dict)


 def torch_persistent_save(*args, **kwargs):
@@ -68,7 +68,7 @@ def main():
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.eval_dataloader(
-        args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
+        args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
@@ -23,6 +23,8 @@ def main():
     dataset_args = options.add_dataset_args(parser)
     dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                               help='maximum number of tokens in a batch')
+    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
+                              help='maximum number of sentences in a batch')
     dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                               choices=['train', 'valid', 'test'],
                               help='data subset to use for training (train, valid, test)')
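Either cap can be given on its own or both together; whichever limit would be exceeded first closes the batch. A possible invocation (the data directory and the values are placeholders, not from the commit):

python train.py <data-dir> --max-tokens 6000 --max-sentences 64 --sentence-avg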
@@ -59,7 +61,8 @@ def main():
         raise NotImplementedError('Training on CPU is not supported')
     num_gpus = torch.cuda.device_count()

-    print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
+    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
+        num_gpus, args.max_tokens, args.max_sentences))

     # Build model and criterion
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
@@ -130,7 +133,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     trainer.set_seed(seed)

     itr = dataset.train_dataloader(
-        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
+        args.train_subset, num_workers=args.workers,
+        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
         max_positions=max_positions, seed=seed, epoch=epoch,
         sample_without_replacement=args.sample_without_replacement,
         sort_by_source_size=(epoch <= args.curriculum))
@@ -150,9 +154,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix

             ntokens = sum(s['ntokens'] for s in sample)
-            src_size = sum(s['src_tokens'].size(0) for s in sample)
-            loss_meter.update(loss, ntokens)
-            bsz_meter.update(src_size)
+            nsentences = sum(s['src_tokens'].size(0) for s in sample)
+            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
+            bsz_meter.update(nsentences)
             wpb_meter.update(ntokens)
             wps_meter.update(ntokens)
             clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
@@ -216,7 +220,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""

     itr = dataset.eval_dataloader(
-        subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
+        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
+        max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())