Commit 8ae63102 authored by Michael Carilli

Progress towards materialize_master_grads=False

parent c763f0fe
@@ -90,7 +90,7 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())
 
 
-def post_backward_models_are_masters(scaler, params, stashed_grads):
+def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
     # This is a lot of python overhead...
     grads_needing_unscale = []
     grads_needing_unscale_with_stash = []
@@ -111,13 +111,15 @@ def post_backward_models_are_masters(scaler, params, stashed_grads):
             grads_needing_unscale,
             grads_needing_unscale,
             scaler.loss_scale(),
-            models_are_masters=True)
+            models_are_masters=True,
+            scale_override=scale_override)
 
     if len(grads_needing_unscale_with_stash) > 0:
         scaler.unscale_with_stashed(
             grads_needing_unscale_with_stash,
             stashed,
-            grads_needing_unscale_with_stash)
+            grads_needing_unscale_with_stash,
+            scale_override=scale_override)
 
     # Clear the stash.
     for i in range(len(stashed_grads)):
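For readers skimming the diff, a minimal, self-contained sketch of the plumbing this hunk introduces: `post_backward_models_are_masters()` now accepts a `scale_override` and threads it, unchanged, into both `scaler.unscale()` and `scaler.unscale_with_stashed()`. `RecordingScaler` and `toy_post_backward` below are illustrative stand-ins, not apex code.

```python
# Illustrative only: verifies that a scale_override passed to the helper reaches
# both scaler calls untouched, and defaults to None when not supplied.
class RecordingScaler:
    """Test double that records which override each unscale call received."""
    def __init__(self):
        self.seen = []
    def loss_scale(self):
        return 2.**16
    def unscale(self, model_grads, master_grads, unused_scale,
                models_are_masters=False, scale_override=None):
        self.seen.append(("unscale", scale_override))
    def unscale_with_stashed(self, model_grads, stashed_master_grads, master_grads,
                             scale_override=None):
        self.seen.append(("unscale_with_stashed", scale_override))

def toy_post_backward(scaler, fresh_grads, stashed, scale_override=None):
    # Simplified stand-in for the helper above: one bucket of grads with no stash,
    # one bucket that must be combined with previously stashed grads.
    if fresh_grads:
        scaler.unscale(fresh_grads, fresh_grads, scaler.loss_scale(),
                       models_are_masters=True, scale_override=scale_override)
    if stashed:
        scaler.unscale_with_stashed(fresh_grads, stashed, fresh_grads,
                                    scale_override=scale_override)

scaler = RecordingScaler()
toy_post_backward(scaler, fresh_grads=[0.0], stashed=[0.0], scale_override=0.5)
assert all(override == 0.5 for _, override in scaler.seen)
```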
@@ -326,12 +328,26 @@ def post_backward_with_master_weights_FusedSGD(self, scaler):
     self._amp_lazy_init()
 
+    current_scale = scaler.loss_scale()
+    out_scale = current_scale
+    if self.scale_set_by_backward:
+        out_scale = min(current_scale, self.most_recent_scale)
+
+    scale_adjustment = out_scale/current_scale
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
 
+    # Grads created by this backward pass have been scaled by current_scale.
+    # unscale() implements grads*1/scale, so "scale" should be current_scale/out_scale
+    # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
+    # stashed_grads are scaled by self.most_recent_scale.
     for params, stashed_grads in split_types:
         post_backward_models_are_masters(scaler, params, stashed_grads)
 
+    self.most_recent_scale = out_scale
+    self.scale_set_by_backward = True
+
 
 def prepare_backward_no_master_weights_FusedSGD(self):
     prepare_backward_no_master_weights(self)
...
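The comment block in the hunk above compresses a fair amount of arithmetic, and note that `scale_adjustment` is computed but not yet passed into the call below, consistent with the work-in-progress commit title. Below is a rough numeric walk-through of the intended bookkeeping for the case where the loss scale grew between two backward passes (so `out_scale` ends up equal to `most_recent_scale`); plain floats stand in for gradient tensors, and nothing here is the real stash/scaler machinery.

```python
# Rough walk-through of the comment above, assuming the loss scale grew between
# two backward passes. Plain floats stand in for gradients; illustrative only.
most_recent_scale = 2.**15          # scaling carried by the stashed (accumulated) grads
current_scale     = 2.**16          # scaling applied by the backward pass that just ran

out_scale = min(current_scale, most_recent_scale)        # 2.**15
scale_adjustment = out_scale / current_scale             # 0.5

true_grad = 3.0
new_grad  = true_grad * current_scale                    # what .grad holds right now
stashed   = true_grad * most_recent_scale                # what was stashed earlier

# unscale() divides by "scale"; passing current_scale/out_scale leaves the new
# grad expressed at out_scale, matching the stashed grad's scaling:
adjusted_new = new_grad / (current_scale / out_scale)
assert adjusted_new == true_grad * out_scale == stashed

# The sum is therefore uniformly scaled by out_scale, which is what
# self.most_recent_scale is set to afterwards; FusedSGD.step() later divides
# it back out via 1.0/self.most_recent_scale.
accumulated = adjusted_new + stashed
assert accumulated / out_scale == 2 * true_grad
```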
@@ -89,11 +89,13 @@ class LossScaler(object):
                 break
 
     # unused_scale keeps some of the old API alive for hopefully a short time.
-    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
+    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
         if self._has_overflow:
             return
 
         scale = self._loss_scale
+        if scale_override is not None:
+            scale = scale_override
 
         if scale == 1.0 and models_are_masters and not self.dynamic:
             return
@@ -146,11 +148,14 @@ class LossScaler(object):
     def unscale_with_stashed(self,
                              model_grads,
                              stashed_master_grads,
-                             master_grads):
+                             master_grads,
+                             scale_override=None):
         if self._has_overflow:
             return
 
         scale = self._loss_scale
+        if scale_override is not None:
+            scale = scale_override
 
         if LossScaler.has_fused_kernel:
             if (not LossScaler.warned_unscaling_non_fp32_grad
...
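On the `LossScaler` side, the override replaces `self._loss_scale` only for the duration of that one call; the stored loss scale is untouched. A cut-down stand-in showing the effect (`MiniScaler` below is illustrative, not the real class, and skips the fused-kernel path):

```python
# Illustrative only: the override wins for a single call; the stored scale persists.
import torch

class MiniScaler:
    def __init__(self, loss_scale):
        self._loss_scale = loss_scale
        self._has_overflow = False

    def unscale(self, model_grads, master_grads, unused_scale,
                models_are_masters=False, scale_override=None):
        if self._has_overflow:
            return
        scale = self._loss_scale
        if scale_override is not None:
            scale = scale_override
        for model, master in zip(model_grads, master_grads):
            master.copy_(model * (1.0 / scale))

s = MiniScaler(loss_scale=1024.0)
g = torch.full((3,), 4096.0)        # pretend this came from a scaled backward
out = torch.empty(3)

s.unscale([g], [out], None)                         # divides by the scaler's own 1024.0
assert torch.allclose(out, torch.full((3,), 4.0))

s.unscale([g], [out], None, scale_override=2048.0)  # override wins for this call
assert torch.allclose(out, torch.full((3,), 2.0))
assert s._loss_scale == 1024.0                      # stored loss scale is unchanged
```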
@@ -68,7 +68,8 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum
 
-        self.scale = 1.0
+        self.most_recent_scale = 1.0
+        self.scale_set_by_backward = False
 
         if multi_tensor_applier.available:
             import amp_C
@@ -184,6 +185,9 @@ class FusedSGD(Optimizer):
                     nesterov,
                     first_run,
                     self.wd_after_momentum,
-                    self.scale)
+                    1.0/self.most_recent_scale)
+        self.most_recent_scale = 1.0
+        self.scale_set_by_backward = False
 
         return loss
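On the FusedSGD side, the fused kernel now receives `1.0/self.most_recent_scale`, so the weight update behaves as if the gradients had already been unscaled, and the bookkeeping is reset once the scale has been consumed. A rough sketch of that contract, with `fused_sgd_step` and `ToyFusedSGD` as hypothetical stand-ins for the `multi_tensor_applier` path, not real apex functions:

```python
# Illustrative only: the step multiplies grads by 1/most_recent_scale, then resets
# the bookkeeping so a plain (non-amp) step is unaffected.
def fused_sgd_step(param, grad, lr, grad_scale):
    # The real kernel fuses this over many tensors; per element the math is just:
    return param - lr * (grad * grad_scale)

class ToyFusedSGD:
    def __init__(self, lr=0.1):
        self.lr = lr
        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False

    def step(self, param, grad):
        new_param = fused_sgd_step(param, grad, self.lr, 1.0 / self.most_recent_scale)
        # Mirror the diff: once the scale has been consumed, reset the bookkeeping.
        self.most_recent_scale = 1.0
        self.scale_set_by_backward = False
        return new_param

opt = ToyFusedSGD(lr=0.1)
opt.most_recent_scale = 2.**15                 # set by the post_backward hook above
opt.scale_set_by_backward = True

param, scaled_grad = 1.0, 3.0 * 2.**15         # grad still carries the loss scale
param = opt.step(param, scaled_grad)
assert abs(param - (1.0 - 0.1 * 3.0)) < 1e-12  # same result as an unscaled SGD step
assert opt.most_recent_scale == 1.0 and opt.scale_set_by_backward is False
```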