Commit 61b8a0fd authored by Michael Carilli

Rough cut, control flow should work for scaleout testing

parent dda59354
......@@ -217,6 +217,55 @@ def post_backward_no_master_weights(self, scaler):
    post_backward_models_are_masters(scaler, params, stashed_grads)

def prepare_backward_with_master_weights_fused(self, scaler):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

def post_backward_with_master_weights_fused(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    stash.grads = [[param.grad.data for param in group] for group in self.fp16_groups]
    stash.output_params = [[param for param in group] for group in self.fp16_groups]

    norm_groups = []
    skip = False
    for grad_group in stash.grads:
        norm = multi_tensor_applier(
            stash.multi_tensor_l2norm,
            stash.dummy_overflow_buf,
            [grad_group])
        # Still syncing here for now.
        norm = float(norm)
        norm_groups.append(norm)
        # inf, -inf, or NaN in a group norm means this iteration's gradients overflowed.
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            skip = True

    if skip:
        # Signal the loss scaler that this iteration overflowed so the step can be skipped.
        scaler._overflow_buf.fill_(1.)
        scaler._has_overflow = True

    self._amp_stash.grad_norms = norm_groups
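
The skip test above is plain IEEE-754 float logic, so it can be exercised without CUDA. A minimal sketch, with a hypothetical is_overflow helper that is not part of apex:

def is_overflow(x):
    # inf compares equal to itself; NaN is the only value unequal to itself,
    # so "x != x" is a NaN test.
    return x == float('inf') or x == -float('inf') or x != x

assert is_overflow(float('inf'))
assert is_overflow(-float('inf'))
assert is_overflow(float('nan'))
assert not is_overflow(1.5)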

def prepare_backward_no_master_weights_fused(self, scaler):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True
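
Both prepare_backward_*_fused hooks use the same one-time guard: run the expensive setup on the first backward pass, then flip a flag on the stash. A standalone sketch of that pattern, with hypothetical names:

class Stash:
    lazy_init_called = False

def prepare(stash):
    if not stash.lazy_init_called:
        # one-time setup (e.g. creating master-weight copies) would go here
        stash.lazy_init_called = True

s = Stash()
prepare(s)   # first backward: runs the setup
prepare(s)   # later backwards: no-op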

def post_backward_no_master_weights_fused(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    # Nothing is pre-flattened here; the fused optimizer reads param.grad directly.
    stash.grads = None
    stash.output_params = None
    stash.grad_norms = None

def _master_params_to_model_params(self):
    stash = self._amp_stash
    if multi_tensor_applier.available:
......@@ -252,6 +301,7 @@ def _process_optimizer(optimizer, properties):
    if multi_tensor_applier.available:
        import amp_C
        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0])

    if properties.master_weights:
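
For context, a rough sketch of how a stashed multi-tensor op gets invoked, assuming apex is built with the amp_C extension and a CUDA device is available (illustrative sizes and scale only):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])   # flipped to nonzero if a kernel hits inf/nan
src = [torch.randn(1024, device='cuda') for _ in range(4)]
dst = [torch.empty_like(t) for t in src]

# One fused launch computes dst[i] = src[i] * 0.5 across every tensor in the lists,
# instead of launching one scale kernel per tensor.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 0.5)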
......@@ -261,16 +311,16 @@ def _process_optimizer(optimizer, properties):
        optimizer._master_params_to_model_params = types.MethodType(
            _master_params_to_model_params, optimizer)

        old_step = optimizer.step
        def new_step(self):
            retval = old_step()
            if not isinstance(self, FusedAdam):
                self._master_params_to_model_params()
            # Clear the master grads that wouldn't be zeroed by model.zero_grad()
            for param in self._amp_stash.all_fp32_from_fp16_params:
                param.grad = None
            return retval
        optimizer.step = types.MethodType(new_step, optimizer)

        old_zero_grad = optimizer.zero_grad
        def new_zero_grad(self):
......
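
The step patch above follows the usual bound-method wrapping pattern. A generic sketch against a plain SGD optimizer (hypothetical example, not apex code):

import types
import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
old_step = opt.step

def new_step(self, closure=None):
    retval = old_step(closure)
    # post-step bookkeeping (for amp: master -> model param copy, clearing master grads)
    return retval

opt.step = types.MethodType(new_step, opt)
opt.step()   # runs SGD's real step, then the extra bookkeeping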
......@@ -78,6 +78,11 @@ class FusedAdam(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()

        grads = self._amp_stash.grads
        output_params = self._amp_stash.output_params
        scale = self._amp_stash.scale
        grad_norms = self._amp_stash.grad_norms

        if grads is None:
            grads_group = [None]*len(self.param_groups)
        # backward compatibility
......
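
The backward-compatibility branch above normalizes the possibly-missing grads into one entry per param group. A standalone sketch of that normalization, with hypothetical names:

def expand_grads(grads, param_groups):
    if grads is None:
        return [None] * len(param_groups)    # each group falls back to reading param.grad
    if not isinstance(grads[0], list):
        return [grads]                       # a single flat list counts as one group
    return grads                             # already one gradient list per param group

print(expand_grads(None, [{}, {}]))          # -> [None, None]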