Commit d9a13180 authored by Myle Ott

Better failure message when loss explodes during FP16 training

parent a846b213
@@ -130,6 +130,12 @@ class FP16Trainer(Trainer):
         overflow = DynamicLossScaler.has_overflow(grad_norm)
         self.scaler.update_scale(overflow)
         if overflow:
+            if self.scaler.loss_scale <= self.args.min_loss_scale:
+                raise Exception((
+                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
+                    'Try lowering the learning rate, using gradient clipping or '
+                    'increasing the batch size.'
+                ).format(self.args.min_loss_scale))
             raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
         return grad_norm
...
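For context, the pattern being extended above is standard dynamic loss scaling: when the scaled FP16 gradients overflow, the scale is reduced and the batch is retried, and with this commit training now aborts with an actionable message once the scale falls to the configured floor. The following is a simplified, self-contained sketch of that failure path; SimpleLossScaler, check_grad_norm, and the default min_loss_scale value are illustrative assumptions, not the actual fairseq implementation.

# Simplified sketch (not the real fairseq code) of the overflow handling above:
# back off the loss scale on overflow, and fail loudly once the scale cannot
# be reduced any further.

class SimpleLossScaler:
    def __init__(self, init_scale=2.0 ** 15, scale_factor=2.0):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor

    @staticmethod
    def has_overflow(grad_norm):
        # an inf or NaN gradient norm means the scaled FP16 gradients overflowed
        return grad_norm == float('inf') or grad_norm != grad_norm

    def update_scale(self, overflow):
        if overflow:
            # halve the scale so the next attempt is less likely to overflow
            self.loss_scale /= self.scale_factor


def check_grad_norm(grad_norm, scaler, min_loss_scale=1e-4):
    overflow = SimpleLossScaler.has_overflow(grad_norm)
    scaler.update_scale(overflow)
    if overflow:
        if scaler.loss_scale <= min_loss_scale:
            # the scale has hit its floor, so the loss itself is diverging
            raise Exception(
                'Minimum loss scale reached ({}). Your loss is probably exploding. '
                'Try lowering the learning rate, using gradient clipping or '
                'increasing the batch size.'.format(min_loss_scale)
            )
        # otherwise ask the caller to skip this batch and retry at the new scale
        raise OverflowError('setting loss scale to: ' + str(scaler.loss_scale))
    return grad_norm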
@@ -210,6 +210,8 @@ def add_optimization_args(parser):
                        help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
     group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                        help='minimum learning rate')
+    group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
+                       help='minimum loss scale (for FP16 training)')
     return group
...
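A minimal standalone check of how the new flag parses (hypothetical snippet, not part of the commit); in fairseq the option is registered by add_optimization_args and read as args.min_loss_scale inside FP16Trainer.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
                    help='minimum loss scale (for FP16 training)')

args = parser.parse_args(['--min-loss-scale', '1e-3'])
print(args.min_loss_scale)  # 0.001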