"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "e843d73fc73843fc6eaed11e3387968fb0f79538"
Commit 03ef3ab8, authored by Myle Ott, committed by Facebook Github Bot
Browse files

Add --fp16-scale-tolerance (#397)

Summary:
Let's only decrease the loss scale if a large enough percentage of batches overflow.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/397

Differential Revision: D13355159

Pulled By: myleott

fbshipit-source-id: e17dde73d34a639519b4348c013fdd19d2b314e6
parent 6c006a34
......@@ -12,19 +12,29 @@ from fairseq import optim, utils
class DynamicLossScaler:
def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000, tolerance=0.05):
    """Initialise the loss-scaler state.

    Args:
        init_scale: initial value of ``loss_scale``.
        scale_factor: factor by which ``loss_scale`` is multiplied or
            divided whenever it is rescaled.
        scale_window: ``loss_scale`` is increased whenever the number of
            iterations since the last overflow is a multiple of this value.
        tolerance: fraction of the iterations since the last rescale that
            must overflow before ``loss_scale`` is decreased.
    """
    self.loss_scale = init_scale
    self.scale_factor = scale_factor
    self.scale_window = scale_window
    self.tolerance = tolerance
    # Bookkeeping used by update_scale(); -1 means "has not happened yet".
    self._iter = 0
    self._last_overflow_iter = -1
    self._last_rescale_iter = -1
    self._overflows_since_rescale = 0
def update_scale(self, overflow):
    """Advance one iteration and adjust ``loss_scale`` if warranted.

    On an overflowing iteration the scale is divided by ``scale_factor``
    only when the fraction of overflowing iterations since the last
    rescale has reached ``tolerance``. On a non-overflowing iteration the
    scale is multiplied by ``scale_factor`` whenever the number of
    iterations since the last overflow is a multiple of ``scale_window``.

    Args:
        overflow: truthy if the current iteration overflowed.
    """
    iter_since_rescale = self._iter - self._last_rescale_iter
    if overflow:
        self._last_overflow_iter = self._iter
        self._overflows_since_rescale += 1
        # float() keeps this a true division under Python 2 as well.
        pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
        if pct_overflow >= self.tolerance:
            # Too many overflows since the last rescale: back off and
            # start a fresh counting window.
            self.loss_scale /= self.scale_factor
            self._last_rescale_iter = self._iter
            self._overflows_since_rescale = 0
    elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
        self.loss_scale *= self.scale_factor
        self._last_rescale_iter = self._iter
    self._iter += 1
@staticmethod
......@@ -55,6 +65,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
)
@staticmethod
......
......@@ -133,6 +133,8 @@ def get_parser(desc, default_task='translation'):
help='default FP16 loss scale')
parser.add_argument('--fp16-scale-window', type=int,
help='number of updates before increasing loss scale')
parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
help='pct of updates that can overflow before decreasing the loss scale')
# Task definitions can be found under fairseq/tasks/
parser.add_argument('--task', metavar='TASK', default=default_task,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment