Commit 61b8a0fd authored by Michael Carilli

Rough cut, control flow should work for scaleout testing

parent dda59354
@@ -217,6 +217,55 @@ def post_backward_no_master_weights(self, scaler):
    post_backward_models_are_masters(scaler, params, stashed_grads)

def prepare_backward_with_master_weights_fused(self, scaler):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

def post_backward_with_master_weights_fused(self, scaler):
    stash = self._amp_stash

    # Stash what FusedAdam.step will pull from _amp_stash:  the loss scale,
    # the raw fp16 grads, the fp16 params to update, and per-group grad norms.
    stash.scale = scaler.loss_scale()
    stash.grads = [[param.grad.data for param in group] for group in self.fp16_groups]
    stash.output_params = [[param for param in group] for group in self.fp16_groups]

    norm_groups = []
    skip = False
    for grad_group in stash.grads:
        norm = multi_tensor_applier(
            stash.multi_tensor_l2norm,
            stash.dummy_overflow_buf,
            [grad_group])
        # Still syncing here for now.
        norm = float(norm)
        norm_groups.append(norm)
        # Flag an overflow if the norm is +/-inf or NaN (norm != norm).
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            skip = True

    if skip:
        scaler._overflow_buf.fill_(1.)
        scaler._has_overflow = True

    self._amp_stash.grad_norms = norm_groups

def prepare_backward_no_master_weights_fused(self, scaler):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

def post_backward_no_master_weights_fused(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    stash.grads = None
    stash.output_params = None
    stash.grad_norms = None

def _master_params_to_model_params(self):
    stash = self._amp_stash
    if multi_tensor_applier.available:
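For reference (not part of this commit), a minimal sketch of the per-group norm computation the fused post_backward path performs, assuming multi_tensor_l2norm returns a single-element tensor as it is used above; the pure-PyTorch fallback is only illustrative.

import torch
from apex.multi_tensor_apply import multi_tensor_applier

def group_grad_norm(grads):
    # grads: list of fp16 gradient tensors for one param group.
    if multi_tensor_applier.available:
        import amp_C
        overflow_buf = torch.cuda.IntTensor([0])
        # Assumed to return a single-element CUDA tensor holding the combined L2 norm.
        norm = multi_tensor_applier(amp_C.multi_tensor_l2norm, overflow_buf, [grads])
        return float(norm)  # host sync, matching the "still syncing here" note above
    # Hypothetical fallback: combine per-tensor norms on the host.
    return float(torch.norm(torch.stack([g.float().norm() for g in grads])))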
@@ -252,6 +301,7 @@ def _process_optimizer(optimizer, properties):
    if multi_tensor_applier.available:
        import amp_C
        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);
    if properties.master_weights:
@@ -261,16 +311,16 @@ def _process_optimizer(optimizer, properties):
        optimizer._master_params_to_model_params = types.MethodType(
            _master_params_to_model_params, optimizer)

        old_step = optimizer.step
        def new_step(self):
            retval = old_step()
            if not isinstance(self, FusedAdam):
                self._master_params_to_model_params()
                # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                for param in self._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
            return retval
        optimizer.step = types.MethodType(new_step, optimizer)

        old_zero_grad = optimizer.zero_grad
        def new_zero_grad(self):
...
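The step patching above relies on rebinding a bound method on the optimizer instance with types.MethodType; a small standalone illustration of that pattern (independent of apex; the Opt class and names are made up).

import types

class Opt:
    def step(self):
        return "stepped"

opt = Opt()
old_step = opt.step              # capture the original bound method

def new_step(self):
    retval = old_step()          # run the original update first
    # post-step bookkeeping (e.g. master->model copies) would go here
    return retval

opt.step = types.MethodType(new_step, opt)   # rebind on this instance only
assert opt.step() == "stepped"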
@@ -78,6 +78,11 @@ class FusedAdam(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()

        grads = self._amp_stash.grads
        output_params = self._amp_stash.output_params
        scale = self._amp_stash.scale
        grad_norms = self._amp_stash.grad_norms

        if grads is None:
            grads_group = [None]*len(self.param_groups)
        # backward compatibility
...
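For scale-out testing, the new control flow would be driven through the usual amp frontend; a minimal sketch, assuming a CUDA-enabled apex build with FusedAdam installed (model, data, and hyperparameters are placeholders).

import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)
# O2 keeps fp32 master weights; the fused prepare/post_backward hooks above
# are intended for this FusedAdam-with-master-weights path.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

data = torch.randn(32, 1024, device="cuda")
loss = model(data).float().pow(2).mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()   # the prepare_/post_backward hooks fire around this
optimizer.step()             # FusedAdam.step reads grads/scale/norms from _amp_stash
optimizer.zero_grad()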