Commit 8ae63102 authored by Michael Carilli

Progress towards materialize_master_grads=False

parent c763f0fe
@@ -90,7 +90,7 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())


-def post_backward_models_are_masters(scaler, params, stashed_grads):
+def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
     # This is a lot of python overhead...
     grads_needing_unscale = []
     grads_needing_unscale_with_stash = []
@@ -111,13 +111,15 @@ def post_backward_models_are_masters(scaler, params, stashed_grads):
             grads_needing_unscale,
             grads_needing_unscale,
             scaler.loss_scale(),
-            models_are_masters=True)
+            models_are_masters=True,
+            scale_override=scale_override)

     if len(grads_needing_unscale_with_stash) > 0:
         scaler.unscale_with_stashed(
             grads_needing_unscale_with_stash,
             stashed,
-            grads_needing_unscale_with_stash)
+            grads_needing_unscale_with_stash,
+            scale_override=scale_override)

     # Clear the stash.
     for i in range(len(stashed_grads)):
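
The effect of scale_override is to let callers divide gradients by a factor other than the scaler's current loss scale. Below is a minimal standalone sketch of that arithmetic, ignoring Apex's fused kernels and overflow checks; the helper name demo_unscale is illustrative and not part of the diff.

import torch

def demo_unscale(grads, loss_scale, scale_override=None):
    # Divide each gradient in place by the effective scale.
    scale = loss_scale if scale_override is None else scale_override
    if scale == 1.0:
        return
    inv_scale = 1.0 / scale
    for g in grads:
        g.mul_(inv_scale)

grads = [torch.full((4,), 1024.0)]
demo_unscale(grads, loss_scale=1024.0)                       # grads become 1.0
demo_unscale(grads, loss_scale=1024.0, scale_override=2.0)   # divide by 2 instead of 1024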
@@ -326,12 +328,26 @@ def post_backward_with_master_weights_FusedSGD(self, scaler):
     self._amp_lazy_init()

+    current_scale = scaler.loss_scale()
+    out_scale = current_scale
+    if self.scale_set_by_backward:
+        out_scale = min(current_scale, self.most_recent_scale)
+    scale_adjustment = out_scale/current_scale
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))

+    # Grads created by this backward pass have been scaled by current_scale.
+    # unscale() implements grads*1/scale, so "scale" should be current_scale/out_scale.
+    # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
+    # stashed_grads are scaled by self.most_recent_scale.
     for params, stashed_grads in split_types:
         post_backward_models_are_masters(scaler, params, stashed_grads)

+    self.most_recent_scale = out_scale
+    self.scale_set_by_backward = True
+

 def prepare_backward_no_master_weights_FusedSGD(self):
     prepare_backward_no_master_weights(self)
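
To make the comment block in the hunk above concrete, here is a worked example with hypothetical dynamic-loss-scale values; the numbers are purely illustrative.

current_scale     = 65536.0   # scale applied to grads by this backward pass
most_recent_scale = 32768.0   # scale the stashed grads still carry
scale_set_by_backward = True

out_scale = min(current_scale, most_recent_scale) if scale_set_by_backward else current_scale
scale_adjustment = out_scale / current_scale                  # 0.5

# New grads are divided by current_scale/out_scale = 2.0, leaving them scaled by 32768.
# Stashed grads are already scaled by most_recent_scale = 32768, so their sum is
# consistently scaled by out_scale, which becomes the new most_recent_scale.
# FusedSGD.step() later multiplies by 1.0/most_recent_scale to finish unscaling.
print(out_scale, scale_adjustment)                            # 32768.0 0.5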
@@ -89,11 +89,13 @@ class LossScaler(object):
                 break

     # unused_scale keeps some of the old API alive for hopefully a short time.
-    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
+    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
         if self._has_overflow:
             return

         scale = self._loss_scale
+        if scale_override is not None:
+            scale = scale_override

         if scale == 1.0 and models_are_masters and not self.dynamic:
             return
@@ -146,11 +148,14 @@ class LossScaler(object):
     def unscale_with_stashed(self,
                              model_grads,
                              stashed_master_grads,
-                             master_grads):
+                             master_grads,
+                             scale_override=None):
         if self._has_overflow:
             return

         scale = self._loss_scale
+        if scale_override is not None:
+            scale = scale_override

         if LossScaler.has_fused_kernel:
             if (not LossScaler.warned_unscaling_non_fp32_grad
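
For reference, a pure-Python sketch of the math unscale_with_stashed performs (grads*(1/scale) + stashed_grads), now with the scale_override hook; the fused kernel is the real code path, and demo_unscale_with_stashed is only an illustrative name.

import torch

def demo_unscale_with_stashed(model_grads, stashed_master_grads, master_grads,
                              loss_scale, scale_override=None):
    scale = loss_scale if scale_override is None else scale_override
    for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
        # master = model * (1/scale) + stashed * 1.0
        master.copy_(model.float() * (1.0 / scale) + stashed.float())

new_grads     = [torch.full((4,), 4096.0, dtype=torch.float16)]  # scaled by 4096
stashed_grads = [torch.ones(4)]                                  # already unscaled
out_grads     = [torch.empty(4)]
demo_unscale_with_stashed(new_grads, stashed_grads, out_grads, loss_scale=4096.0)
print(out_grads[0])   # tensor of 2.0 everywhere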
@@ -68,7 +68,8 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum

-        self.scale = 1.0
+        self.most_recent_scale = 1.0
+        self.scale_set_by_backward = False

         if multi_tensor_applier.available:
             import amp_C
@@ -184,6 +185,9 @@ class FusedSGD(Optimizer):
                 nesterov,
                 first_run,
                 self.wd_after_momentum,
-                self.scale)
+                1.0/self.most_recent_scale)
+
+        self.most_recent_scale = 1.0
+        self.scale_set_by_backward = False

         return loss
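
A toy stand-in for the new FusedSGD bookkeeping, showing why step() multiplies grads by 1.0/most_recent_scale and then resets its state; this class is a sketch for illustration, not the real optimizer.

class ScaleBookkeepingSketch:
    """Mimics only the most_recent_scale / scale_set_by_backward handling."""
    def __init__(self):
        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

    def step(self, grads):
        # The fused kernel receives 1.0/self.most_recent_scale and applies it to
        # the grads while updating the params; here we just apply the factor.
        inv = 1.0 / self.most_recent_scale
        unscaled = [g * inv for g in grads]
        # Forget the scale until the next backward pass records a new one.
        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False
        return unscaled

opt = ScaleBookkeepingSketch()
opt.most_recent_scale, opt.scale_set_by_backward = 32768.0, True  # set in post_backward
print(opt.step([65536.0]))   # [2.0] -- grads still carried most_recent_scale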