"src/diffusers/models/modeling_utils.py" did not exist on "df90f0ce989dcccd7ef2fe9ff085da3197b2f2ad"
Commit f9362e87 authored by Myle Ott
Browse files

Output correct perplexity when training with --sentence-avg

parent 81ace092
......@@ -33,6 +33,7 @@ class CrossEntropyCriterion(FairseqCriterion):
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': loss.data[0] if reduce else loss.data,
'ntokens': sample['ntokens'],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
......@@ -40,7 +41,12 @@ class CrossEntropyCriterion(FairseqCriterion):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
    """Aggregate logging outputs from data parallel training.

    Args:
        logging_outputs: iterable of dicts, one per worker, each optionally
            carrying 'loss', 'ntokens' and 'sample_size'. Missing keys are
            treated as 0.

    Returns:
        A dict with 'loss' (summed loss divided by the summed sample size
        and by ``math.log(2)``). When the summed sample size differs from
        the summed token count, the dict also has 'nll_loss', the same
        summed loss normalised per token instead.

    Raises:
        ZeroDivisionError: if the summed 'sample_size' is 0, or if
            'nll_loss' is computed and the summed 'ntokens' is 0.
    """
    loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
    ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
    sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
    agg_output = {
        # Dividing by log(2) presumably converts a natural-log loss to base 2.
        'loss': loss_sum / sample_size / math.log(2),
    }
    if sample_size != ntokens:
        # The sample size is not the token count here (e.g. sentence-level
        # averaging), so a per-token value is reported separately.
        agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
    return agg_output
......@@ -70,6 +70,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
logging_output = {
'loss': loss.data[0] if reduce else loss.data,
'nll_loss': nll_loss.data[0] if reduce else loss.data,
'ntokens': sample['ntokens'],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
......@@ -77,8 +78,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
    """Aggregate logging outputs from data parallel training.

    Args:
        logging_outputs: iterable of dicts, one per worker, each optionally
            carrying 'loss', 'nll_loss', 'ntokens' and 'sample_size'.
            Missing keys are treated as 0.

    Returns:
        A dict with 'loss' (summed loss divided by the summed sample size)
        and 'nll_loss' (summed nll_loss divided by the summed token count),
        both further divided by ``math.log(2)``.

    Raises:
        ZeroDivisionError: if the summed 'sample_size' or the summed
            'ntokens' is 0.
    """
    loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
    nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
    ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
    sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
    return {
        'loss': loss_sum / sample_size / math.log(2),
        # Normalised per token rather than per sample, so the value does not
        # depend on how sample_size was chosen upstream.
        'nll_loss': nll_loss_sum / ntokens / math.log(2),
    }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment