Commit f442f896 authored by Myle Ott

Add --max-sentences option for batching based on # sentences

parent 2ef422f6
@@ -14,9 +14,8 @@ from .fairseq_criterion import FairseqCriterion
 class CrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, padding_idx):
-        super().__init__()
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict):
+        super().__init__(args, dst_dict)
 
     def forward(self, model, sample):
         """Compute the loss for the given sample.
@@ -30,7 +29,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
...
@@ -11,8 +11,10 @@ from torch.nn.modules.loss import _Loss
 class FairseqCriterion(_Loss):
 
-    def __init__(self):
+    def __init__(self, args, dst_dict):
         super().__init__()
+        self.args = args
+        self.padding_idx = dst_dict.pad()
 
     def forward(self, model, sample):
         """Compute the loss for the given sample.
...
@@ -43,10 +43,9 @@ class LabelSmoothedCrossEntropy(torch.autograd.Function):
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, eps, padding_idx=None, weights=None):
-        super().__init__()
-        self.eps = eps
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict, weights=None):
+        super().__init__(args, dst_dict)
+        self.eps = args.label_smoothing
         self.weights = weights
 
     def forward(self, model, sample):
@@ -61,7 +60,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
...
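
Both criterion classes now pick sample_size the same way: the number of target sentences when --sentence-avg is set, otherwise the number of target tokens, and the loss is later normalized by that value. A minimal standalone sketch of that choice (not fairseq code; the helper name and tensor shapes below are illustrative only):

# Sketch only: assumes a fairseq-style sample dict with a (batch, tgt_len)
# 'target' tensor and an 'ntokens' count; compute_sample_size is hypothetical.
import torch

def compute_sample_size(sample, sentence_avg):
    # sentences per batch with --sentence-avg, target tokens otherwise
    return sample['target'].size(0) if sentence_avg else sample['ntokens']

sample = {'target': torch.zeros(8, 25, dtype=torch.long), 'ntokens': 200}
print(compute_sample_size(sample, sentence_avg=True))   # 8 (sentences)
print(compute_sample_size(sample, sentence_avg=False))  # 200 (tokens)
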
@@ -97,27 +97,26 @@ class LanguageDatasets(object):
         assert self.src_dict.unk() == self.dst_dict.unk()
 
     def train_dataloader(self, split, num_workers=0, max_tokens=None,
-                         max_positions=(1024, 1024), seed=None, epoch=1,
-                         sample_without_replacement=0,
+                         max_sentences=None, max_positions=(1024, 1024),
+                         seed=None, epoch=1, sample_without_replacement=0,
                          sort_by_source_size=False):
         dataset = self.splits[split]
         with numpy_seed(seed):
             batch_sampler = shuffled_batches_by_size(
-                dataset.src, dataset.dst,
-                max_tokens=max_tokens, epoch=epoch,
-                sample=sample_without_replacement,
-                max_positions=max_positions,
+                dataset.src, dataset.dst, max_tokens=max_tokens,
+                max_sentences=max_sentences, epoch=epoch,
+                sample=sample_without_replacement, max_positions=max_positions,
                 sort_by_source_size=sort_by_source_size)
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
             batch_sampler=batch_sampler)
 
-    def eval_dataloader(self, split, num_workers=0, batch_size=1,
-                        max_tokens=None, max_positions=(1024, 1024),
+    def eval_dataloader(self, split, num_workers=0, max_tokens=None,
+                        max_sentences=None, max_positions=(1024, 1024),
                         skip_invalid_size_inputs_valid_test=False):
         dataset = self.splits[split]
         batch_sampler = list(batches_by_size(
-            dataset.src, dataset.dst, batch_size, max_tokens,
+            dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
             ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
@@ -220,29 +219,23 @@ def _valid_size(src_size, dst_size, max_positions):
     return True
 
-def batches_by_size(src, dst, batch_size=None, max_tokens=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
-    """Returns batches of indices sorted by size. Sequences with different
-    source lengths are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
-    if max_tokens is None:
-        max_tokens = float('Inf')
-    indices = np.argsort(src.sizes, kind='mergesort')
+def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
+                  ignore_invalid_inputs=False, allow_different_src_lens=False):
     batch = []
 
     def yield_batch(next_idx, num_tokens):
         if len(batch) == 0:
             return False
-        if len(batch) == batch_size:
+        if len(batch) == max_sentences:
             return True
-        if src.sizes[batch[0]] != src.sizes[next_idx]:
+        if num_tokens > max_tokens:
             return True
-        if num_tokens >= max_tokens:
+        if not allow_different_src_lens and \
+                (src.sizes[batch[0]] != src.sizes[next_idx]):
             return True
         return False
 
-    cur_max_size = 0
+    sample_len = 0
     ignored = []
     for idx in indices:
         if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
@@ -253,28 +246,48 @@ def batches_by_size(src, dst, batch_size=None, max_tokens=None,
                 "Unable to handle input id {} of size {} / {}.".format(
                     idx, src.sizes[idx], dst.sizes[idx]))
 
-        if yield_batch(idx, cur_max_size * (len(batch) + 1)):
+        sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
+        num_tokens = (len(batch) + 1) * sample_len
+        if yield_batch(idx, num_tokens):
             yield batch
             batch = []
-            cur_max_size = 0
+            sample_len = max(src.sizes[idx], dst.sizes[idx])
 
         batch.append(idx)
-        cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])
-
-    if len(batch) > 0:
-        yield batch
 
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
               "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
 
+    if len(batch) > 0:
+        yield batch
+
+
+def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
+    """Returns batches of indices sorted by size. Sequences with different
+    source lengths are not allowed in the same batch."""
+    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
+    if max_tokens is None:
+        max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')
+    indices = np.argsort(src.sizes, kind='mergesort')
+    return _make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs, allow_different_src_lens=False)
+
 
-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
-                             max_positions=(1024, 1024), sort_by_source_size=False):
+def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                             epoch=1, sample=0, max_positions=(1024, 1024),
+                             sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')
     indices = np.random.permutation(len(src))
@@ -282,30 +295,10 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
     indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
     indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
 
-    def make_batches():
-        batch = []
-        sample_len = 0
-        ignored = []
-        for idx in indices:
-            if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
-                ignored.append(idx)
-                continue
-
-            sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
-            if len(batch) > 0 and (len(batch) + 1) * sample_len > max_tokens:
-                yield batch
-                batch = []
-                sample_len = max(src.sizes[idx], dst.sizes[idx])
-
-            batch.append(idx)
-
-        if len(batch) > 0:
-            yield batch
-
-        if len(ignored) > 0:
-            print("Warning! {} samples are either too short or too long "
-                  "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
-
-    batches = list(make_batches())
+    batches = list(_make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs=True, allow_different_src_lens=True))
+
     if not sort_by_source_size:
         np.random.shuffle(batches)
...
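
The refactored _make_batches caps a batch by whichever limit is hit first: max_sentences samples, or max_tokens counted as (batch length) x (longest sample so far), i.e. the padded token count; the equal-source-length constraint used by batches_by_size is controlled separately via allow_different_src_lens and is omitted below. A self-contained sketch of the dual cap, assuming plain integer lengths instead of IndexedDataset objects (the function name is illustrative, not fairseq's):

# Standalone sketch of the token/sentence dual cap; not fairseq code.
def make_batches_sketch(lengths, max_tokens=float('Inf'), max_sentences=float('Inf')):
    batch, sample_len = [], 0
    for idx, length in enumerate(lengths):
        sample_len = max(sample_len, length)
        # padded size of the batch if this sample were added
        num_tokens = (len(batch) + 1) * sample_len
        if batch and (len(batch) == max_sentences or num_tokens > max_tokens):
            yield batch
            batch, sample_len = [], length
        batch.append(idx)
    if batch:
        yield batch

# 6 sentences, capped at 20 padded tokens or 3 sentences per batch
print(list(make_batches_sketch([5, 7, 9, 4, 12, 3], max_tokens=20, max_sentences=3)))
# -> [[0, 1], [2, 3], [4], [5]]
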
@@ -71,6 +71,9 @@ def add_optimization_args(parser):
                             ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--sentence-avg', action='store_true',
+                       help='normalize gradients by the number of sentences in a batch'
+                            ' (default is to normalize by number of tokens)')
 
     return group
...
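
Together with the --max-sentences argument added further down, the new --sentence-avg flag is plain argparse. A tiny throwaway parser, built only from the add_argument calls shown in this commit, illustrates how the two options read:

# Throwaway parser for illustration; not fairseq's full option setup.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-sentences', type=int, metavar='N',
                    help='maximum number of sentences in a batch')
parser.add_argument('--sentence-avg', action='store_true',
                    help='normalize gradients by the number of sentences in a batch'
                         ' (default is to normalize by number of tokens)')

args = parser.parse_args(['--max-sentences', '32', '--sentence-avg'])
print(args.max_sentences, args.sentence_avg)  # 32 True
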
@@ -30,11 +30,10 @@ def build_model(args, src_dict, dst_dict):
 def build_criterion(args, src_dict, dst_dict):
-    padding_idx = dst_dict.pad()
     if args.label_smoothing > 0:
-        return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
+        return criterions.LabelSmoothedCrossEntropyCriterion(args, dst_dict)
     else:
-        return criterions.CrossEntropyCriterion(padding_idx)
+        return criterions.CrossEntropyCriterion(args, dst_dict)
 
 
 def torch_persistent_save(*args, **kwargs):
...
@@ -68,7 +68,7 @@ def main():
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.eval_dataloader(
-        args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
+        args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
...
@@ -23,6 +23,8 @@ def main():
     dataset_args = options.add_dataset_args(parser)
     dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                               help='maximum number of tokens in a batch')
+    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
+                              help='maximum number of sentences in a batch')
     dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                               choices=['train', 'valid', 'test'],
                               help='data subset to use for training (train, valid, test)')
@@ -59,7 +61,8 @@ def main():
         raise NotImplementedError('Training on CPU is not supported')
     num_gpus = torch.cuda.device_count()
 
-    print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
+    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
+        num_gpus, args.max_tokens, args.max_sentences))
 
     # Build model and criterion
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
@@ -130,7 +133,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     trainer.set_seed(seed)
 
     itr = dataset.train_dataloader(
-        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
+        args.train_subset, num_workers=args.workers,
+        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
         sample_without_replacement=args.sample_without_replacement,
         sort_by_source_size=(epoch <= args.curriculum))
@@ -150,9 +154,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
         del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
         ntokens = sum(s['ntokens'] for s in sample)
-        src_size = sum(s['src_tokens'].size(0) for s in sample)
-        loss_meter.update(loss, ntokens)
-        bsz_meter.update(src_size)
+        nsentences = sum(s['src_tokens'].size(0) for s in sample)
+        loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
+        bsz_meter.update(nsentences)
         wpb_meter.update(ntokens)
         wps_meter.update(ntokens)
         clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
@@ -216,7 +220,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.eval_dataloader(
-        subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
+        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
+        max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
 
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
...
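
With --sentence-avg, loss_meter is now weighted by the number of sentences rather than tokens, so the running average stays consistent with how the criterion normalized the loss. A simplified stand-in for an AverageMeter-style weighted running average (illustrative only, not fairseq's class):

# Simplified AverageMeter-style running average; the weight n plays the role of
# sample_size above (sentences with --sentence-avg, tokens otherwise).
class RunningAverage:
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is assumed to be a per-unit (per-sentence or per-token) loss
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0

meter = RunningAverage()
meter.update(2.5, n=32)  # e.g. a batch of 32 sentences under --sentence-avg
meter.update(2.1, n=48)
print(round(meter.avg, 3))  # 2.26, weighted toward the larger batch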