Commit dd31fa92 authored by Sergey Edunov, committed by Myle Ott

Report log likelihood for label smoothing

parent c5378602
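
The motivation: with label smoothing the training criterion is no longer the model's negative log-likelihood, so exponentiating it does not give a meaningful perplexity. The commit therefore logs the plain NLL alongside the smoothed loss. A minimal sketch of the distinction (illustrative only; eps, the vocabulary size and the exact mixing formula are assumptions, not necessarily the form used by LabelSmoothedNLLLoss):

import torch

eps, vocab = 0.1, 5
lprobs = torch.log_softmax(torch.randn(3, vocab), dim=-1)   # model log-probabilities
target = torch.tensor([1, 4, 2])

nll = -lprobs.gather(1, target.unsqueeze(1)).sum()          # true NLL, summed over tokens
smooth = -lprobs.sum()                                      # penalty spread over all classes
smoothed_loss = (1.0 - eps) * nll + (eps / vocab) * smooth  # one common smoothing form

# Perplexity should be computed from nll, not from smoothed_loss.
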
@@ -65,9 +65,11 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
         loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce)
+        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0] if reduce else loss.data,
+            'nll_loss': nll_loss.data[0] if reduce else nll_loss.data,
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
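
Note that size_average=False together with reduce=True comes from the PyTorch 0.3-era API and yields a sum of token-level NLL with padding excluded; under current PyTorch the same line would read roughly as below (a sketch reusing lprobs, target and self.padding_idx from the hunk above):

import torch.nn.functional as F

# Sum of -log p(target) over non-padding tokens.
nll_loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx, reduction='sum')
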
@@ -78,4 +80,5 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
             'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2),
         }
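
Reading of the aggregation above: the logged values are sums in nats, so dividing by sample_size and by math.log(2) yields an average per token (or per sentence when --sentence-avg is set) in bits, and the reported perplexity is 2 raised to that value. A small worked sketch reusing logging_outputs and sample_size from the hunk above, assuming sample_size counts tokens:

import math

total_nll = sum(log.get('nll_loss', 0) for log in logging_outputs)   # summed NLL in nats
nll_bits_per_token = total_nll / sample_size / math.log(2)           # nats to bits, per token
perplexity = 2 ** nll_bits_per_token
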
@@ -150,6 +150,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         sample_without_replacement=args.sample_without_replacement,
         sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
+    nll_loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
     wpb_meter = AverageMeter()    # words per batch
     wps_meter = TimeMeter()       # words per second
@@ -164,6 +165,11 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
             ntokens = sum(s['ntokens'] for s in sample)
+            if 'nll_loss' in loss_dict:
+                nll_loss = loss_dict['nll_loss']
+                nll_loss_meter.update(nll_loss, ntokens)
             nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
             loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
             bsz_meter.update(nsentences)
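
The new nll_loss_meter is an AverageMeter weighted by token count, mirroring loss_meter. A minimal stand-in with the interface the surrounding code relies on (update(val, n), .avg and .count); the real implementation ships with fairseq:

class AverageMeter(object):
    """Tracks a weighted running average (here: loss per token)."""

    def __init__(self):
        self.sum = 0.0    # weighted sum of reported values
        self.count = 0    # total weight (number of tokens)

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0
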
@@ -193,7 +199,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         t.print(collections.OrderedDict([
             ('train loss', round(loss_meter.avg, 2)),
-            ('train ppl', get_perplexity(loss_meter.avg)),
+            ('train ppl', get_perplexity(nll_loss_meter.avg
+                                         if nll_loss_meter.count > 0
+                                         else loss_meter.avg)),
             ('s/checkpoint', round(wps_meter.elapsed_time)),
             ('words/s', round(wps_meter.avg)),
             ('words/batch', round(wpb_meter.avg)),
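
With this change the progress line prefers the NLL meter whenever the criterion reported one and falls back to the smoothed training loss otherwise. get_perplexity is presumably just the base-2 exponential of the per-token loss, roughly as below (a hedged sketch, not the verbatim helper from train.py):

import math

def get_perplexity(loss_bits):
    # loss_bits is an average loss per token in base 2 (the criterion's
    # aggregation above already divides by math.log(2)).
    try:
        return round(math.pow(2, loss_bits), 2)
    except OverflowError:
        return float('inf')
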
@@ -242,16 +250,21 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
         descending=True,  # largest batch first to warm the caching allocator
     )
     loss_meter = AverageMeter()
+    nll_loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     prefix = 'valid on \'{}\' subset'.format(subset)
     with utils.build_progress_bar(args, itr, epoch, prefix) as t:
         for _, sample in data.skip_group_enumerator(t, args.num_gpus):
             loss_dict = trainer.valid_step(sample)
-            ntokens = sum(s['ntokens'] for s in sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
+            ntokens = sum(s['ntokens'] for s in sample)
+            if 'nll_loss' in loss_dict:
+                nll_loss = loss_dict['nll_loss']
+                nll_loss_meter.update(nll_loss, ntokens)
             loss_meter.update(loss, ntokens)
             extra_postfix = []
@@ -265,7 +278,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
         t.print(collections.OrderedDict([
             ('valid loss', round(loss_meter.avg, 2)),
-            ('valid ppl', get_perplexity(loss_meter.avg)),
+            ('valid ppl', get_perplexity(nll_loss_meter.avg
+                                         if nll_loss_meter.count > 0
+                                         else loss_meter.avg)),
         ] + [
             (k, meter.avg)
             for k, meter in extra_meters.items()