Commit 03ef3ab8 authored by Myle Ott, committed by Facebook Github Bot

Add --fp16-scale-tolerance (#397)

Summary:
Let's only decrease the loss scale if a large enough percentage of batches overflow.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/397

Differential Revision: D13355159

Pulled By: myleott

fbshipit-source-id: e17dde73d34a639519b4348c013fdd19d2b314e6
parent 6c006a34
@@ -12,19 +12,29 @@ from fairseq import optim, utils
 
 class DynamicLossScaler:
 
-    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
+    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000, tolerance=0.05):
         self.loss_scale = init_scale
         self.scale_factor = scale_factor
         self.scale_window = scale_window
+        self.tolerance = tolerance
         self._iter = 0
         self._last_overflow_iter = -1
+        self._last_rescale_iter = -1
+        self._overflows_since_rescale = 0
 
     def update_scale(self, overflow):
+        iter_since_rescale = self._iter - self._last_rescale_iter
         if overflow:
-            self.loss_scale /= self.scale_factor
             self._last_overflow_iter = self._iter
+            self._overflows_since_rescale += 1
+            pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
+            if pct_overflow >= self.tolerance:
+                self.loss_scale /= self.scale_factor
+                self._last_rescale_iter = self._iter
+                self._overflows_since_rescale = 0
         elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
             self.loss_scale *= self.scale_factor
+            self._last_rescale_iter = self._iter
         self._iter += 1
 
     @staticmethod
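To make the new behavior concrete, here is a minimal, self-contained sketch: the class body is copied from the hunk above so it runs without fairseq installed, and the simulated overflow pattern is invented for illustration. With the default tolerance of 0.05, an isolated overflow among 100 updates (1%) no longer halves the loss scale; only a sustained run of overflows does.

# Sketch of the new tolerance logic; class copied from the hunk above.
class DynamicLossScaler:

    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000, tolerance=0.05):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self._iter = 0
        self._last_overflow_iter = -1
        self._last_rescale_iter = -1
        self._overflows_since_rescale = 0

    def update_scale(self, overflow):
        iter_since_rescale = self._iter - self._last_rescale_iter
        if overflow:
            self._last_overflow_iter = self._iter
            self._overflows_since_rescale += 1
            pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
            if pct_overflow >= self.tolerance:
                # overflow rate exceeded the tolerance: halve the scale
                self.loss_scale /= self.scale_factor
                self._last_rescale_iter = self._iter
                self._overflows_since_rescale = 0
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            # scale_window clean updates in a row: double the scale
            self.loss_scale *= self.scale_factor
            self._last_rescale_iter = self._iter
        self._iter += 1


scaler = DynamicLossScaler()  # tolerance defaults to 0.05 (5%)

# 99 clean updates, then one overflow: 1/100 = 1% < 5%, so the scale
# survives an isolated overflow instead of being halved immediately.
for _ in range(99):
    scaler.update_scale(overflow=False)
scaler.update_scale(overflow=True)
assert scaler.loss_scale == 2.**15  # unchanged

# Sustained overflows push the rate past 5% and the scale is halved.
while scaler.loss_scale == 2.**15:
    scaler.update_scale(overflow=True)
assert scaler.loss_scale == 2.**14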
@@ -55,6 +65,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
         self.scaler = DynamicLossScaler(
             init_scale=args.fp16_init_scale,
             scale_window=scale_window,
+            tolerance=args.fp16_scale_tolerance,
         )
 
     @staticmethod
...
@@ -133,6 +133,8 @@ def get_parser(desc, default_task='translation'):
                         help='default FP16 loss scale')
     parser.add_argument('--fp16-scale-window', type=int,
                         help='number of updates before increasing loss scale')
+    parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
+                        help='pct of updates that can overflow before decreasing the loss scale')
 
     # Task definitions can be found under fairseq/tasks/
     parser.add_argument('--task', metavar='TASK', default=default_task,
...
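Note that the command-line default is 0.0 rather than the class default of 0.05, so existing training commands (e.g. python train.py ... --fp16) keep the old behavior unless --fp16-scale-tolerance is passed explicitly: with a tolerance of 0.0, every overflow satisfies pct_overflow >= self.tolerance and immediately decreases the scale. Reusing the DynamicLossScaler sketch above:

scaler = DynamicLossScaler(tolerance=0.0)  # matches the new CLI default
scaler.update_scale(overflow=True)
assert scaler.loss_scale == 2.**14  # halved on the very first overflow, as before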