Commit 8b4c45a2, authored by Louis Martin, committed by Myle Ott
Browse files

Prevent math overflow when loss is too high

parent ff2e8cf2
...@@ -110,6 +110,13 @@ def main(): ...@@ -110,6 +110,13 @@ def main():
trainer.stop() trainer.stop()
def get_perplexity(loss):
    """Return the perplexity for a base-2 loss, i.e. ``2 ** loss``.

    Args:
        loss: Numeric loss value used as the exponent.

    Returns:
        ``math.pow(2, loss)`` as a float, or ``float('inf')`` when the
        result is too large to be represented as a float.
    """
    try:
        return math.pow(2, loss)
    except OverflowError:
        # math.pow raises instead of saturating when the loss is too high;
        # report an infinite perplexity rather than aborting the caller.
        return float('inf')
def train(args, epoch, batch_offset, trainer, dataset, num_gpus): def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
"""Train the model for one epoch.""" """Train the model for one epoch."""
...@@ -162,7 +169,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus): ...@@ -162,7 +169,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
save_checkpoint(trainer, args, epoch, i + 1) save_checkpoint(trainer, args, epoch, i + 1)
fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format( fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
loss_meter.avg, math.pow(2, loss_meter.avg)) loss_meter.avg, get_perplexity(loss_meter.avg))
fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format( fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg)) round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format( fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
...@@ -227,7 +234,7 @@ def validate(args, epoch, trainer, dataset, subset, ngpus): ...@@ -227,7 +234,7 @@ def validate(args, epoch, trainer, dataset, subset, ngpus):
val_loss = loss_meter.avg val_loss = loss_meter.avg
fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format( fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
val_loss, math.pow(2, val_loss)) val_loss, get_perplexity(val_loss))
fmt += ''.join( fmt += ''.join(
' | {} {:.4f}'.format(k, meter.avg) ' | {} {:.4f}'.format(k, meter.avg)
for k, meter in extra_meters.items() for k, meter in extra_meters.items()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment