Commit e0f2ffa5 authored by Michael Carilli

Initial organization

parent bf4aa847
@@ -3,7 +3,7 @@ from ..fp16_utils import master_params_to_model_params
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import maybe_print
import torch
from ..optimizers import FusedAdam
from ..optimizers import FusedAdam, FusedSGD
class AmpOptimizerState(object):
@@ -217,7 +217,11 @@ def post_backward_no_master_weights(self, scaler):
    post_backward_models_are_masters(scaler, params, stashed_grads)

def prepare_backward_with_master_weights_fused(self):
#####################################################################################
# FusedAdam versions
#####################################################################################
def prepare_backward_with_master_weights_FusedAdam(self):
    stash = self._amp_stash
    if not stash.lazy_init_called:
@@ -225,7 +229,7 @@ def prepare_backward_with_master_weights_fused(self):
        stash.lazy_init_called = True

def post_backward_with_master_weights_fused(self, scaler):
def post_backward_with_master_weights_FusedAdam(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
@@ -250,7 +254,7 @@ def post_backward_with_master_weights_fused(self, scaler):
    stash.grad_norms = norm_groups

def prepare_backward_no_master_weights_fused(self):
def prepare_backward_no_master_weights_FusedAdam(self):
    stash = self._amp_stash
    if not stash.lazy_init_called:
@@ -258,7 +262,63 @@ def prepare_backward_no_master_weights_fused(self):
        stash.lazy_init_called = True

def post_backward_no_master_weights_fused(self, scaler):
def post_backward_no_master_weights_FusedAdam(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    stash.grads = None
    stash.output_params = None
    stash.grad_norms = None

#####################################################################################
# FusedSGD versions
# Eat this ugly code duplication for now. First make it work, then make it clean.
# It's difficult to anticipate what can be unified between the FusedAdam and FusedSGD
# implementations until I have them both working.
#####################################################################################

def prepare_backward_with_master_weights_FusedSGD(self):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

def post_backward_with_master_weights_FusedSGD(self, scaler):
    stash = self._amp_stash

    stash.scale = scaler.loss_scale()
    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
    stash.output_params = [[param for param in group] for group in stash.fp16_groups]

    norm_groups = []
    skip = False
    for grad_group in stash.grads:
        norm = multi_tensor_applier(
            stash.multi_tensor_l2norm,
            stash.dummy_overflow_buf,
            [grad_group])
        # Still syncing here for now.
        norm = float(norm)
        norm_groups.append(norm)
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            skip = True

    if skip:
        scaler._overflow_buf.fill_(1.)
        scaler._has_overflow = True

    stash.grad_norms = norm_groups

def prepare_backward_no_master_weights_FusedSGD(self):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

def post_backward_no_master_weights_FusedSGD(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    stash.grads = None
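The FusedSGD post-backward hook added above computes one L2 norm per fp16 parameter group and flags the step to be skipped when any norm comes back inf or NaN. Below is a minimal sketch of that overflow check using plain PyTorch ops instead of the fused amp_C.multi_tensor_l2norm kernel; the standalone function and its name are illustrative, not part of this commit.

```python
import torch

def group_grad_norms_and_overflow(fp16_groups):
    """Illustrative only: per-group L2 norms plus an inf/NaN overflow flag,
    mirroring the logic of post_backward_with_master_weights_FusedSGD."""
    norm_groups = []
    skip = False
    for group in fp16_groups:
        grads = [p.grad.data.float() for p in group if p.grad is not None]
        # Same quantity multi_tensor_l2norm reduces in a single fused kernel:
        # the L2 norm over every gradient element in the group.
        norm = float(torch.norm(torch.cat([g.reshape(-1) for g in grads])))
        norm_groups.append(norm)
        # inf or NaN means the current loss scale overflowed the fp16 range,
        # so the scaler should skip this step and back off the scale.
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            skip = True
    return norm_groups, skip
```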
@@ -314,7 +374,7 @@ def _process_optimizer(optimizer, properties):
        old_step = optimizer.step
        def new_step(self):
            retval = old_step()
            if not isinstance(self, FusedAdam):
            if not (isinstance(self, FusedAdam) or isinstance(self, FusedSGD)):
                self._master_params_to_model_params()
            # Clear the master grads that wouldn't be zeroed by model.zero_grad()
            for param in self._amp_stash.all_fp32_from_fp16_params:
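For context on the hunk above: fused optimizers write their updates into the fp16 model parameters directly, so the master-to-model copy after step() is only needed on the unfused path. A rough sketch of that wrapping pattern follows; the helper function and its parameters are made up for illustration, and `_master_params_to_model_params` is assumed to already be bound on the optimizer, as it is in this file.

```python
import types

def wrap_step(optimizer, fused_types=()):
    """Illustrative only: wrap optimizer.step so that unfused optimizers copy
    fp32 master params back into the fp16 model params after each step."""
    old_step = optimizer.step  # already bound, so it is called without args
    def new_step(self):
        retval = old_step()
        if not isinstance(self, fused_types):
            # Helper assumed to be installed on the optimizer by amp.
            self._master_params_to_model_params()
        return retval
    optimizer.step = types.MethodType(new_step, optimizer)
```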
@@ -344,9 +404,14 @@ def _process_optimizer(optimizer, properties):
        if isinstance(optimizer, FusedAdam):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_with_master_weights_fused, optimizer)
                prepare_backward_with_master_weights_FusedAdam, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_with_master_weights_FusedAdam, optimizer)
        elif isinstance(optimizer, FusedSGD):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_with_master_weights_FusedSGD, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_with_master_weights_fused, optimizer)
                post_backward_with_master_weights_FusedSGD, optimizer)
        else:
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_with_master_weights, optimizer)
@@ -358,9 +423,14 @@ def _process_optimizer(optimizer, properties):
        if isinstance(optimizer, FusedAdam):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_no_master_weights_fused, optimizer)
                prepare_backward_no_master_weights_FusedAdam, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_no_master_weights_FusedAdam, optimizer)
        elif isinstance(optimizer, FusedSGD):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_no_master_weights_FusedSGD, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_no_master_weights_fused, optimizer)
                post_backward_no_master_weights_FusedSGD, optimizer)
        else:
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_no_master_weights, optimizer)
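The two dispatch hunks above pick a prepare/post backward pair per optimizer type and bind the chosen free functions onto the optimizer instance. A toy sketch of that binding mechanism is below; the class and function names are invented for illustration and are not part of this commit.

```python
import types

class ToyOptimizer(object):
    pass

def prepare_backward_toy(self):
    # Runs as a bound method: self is the optimizer instance.
    print("prepare hook on", type(self).__name__)

def post_backward_toy(self, scaler):
    print("post hook with loss scale", scaler)

opt = ToyOptimizer()
# types.MethodType attaches a free function to a single instance, which is
# how amp installs optimizer-specific backward hooks without subclassing.
opt._prepare_amp_backward = types.MethodType(prepare_backward_toy, opt)
opt._post_amp_backward = types.MethodType(post_backward_toy, opt)

opt._prepare_amp_backward()      # -> prepare hook on ToyOptimizer
opt._post_amp_backward(65536.0)  # -> post hook with loss scale 65536.0
```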
@@ -73,7 +73,7 @@ class FusedSGD(Optimizer):
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
        else:
            raise RuntimeError('apex.optim.SGD requires cuda extensions')
            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
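The last hunk only corrects the error string so it names the class that actually raises it. The surrounding pattern is the usual guarded use of the amp_C CUDA extension; a minimal sketch of that pattern follows, with a stand-in class name and a try/except import guard rather than apex's multi_tensor_applier.available check.

```python
try:
    import amp_C
    _HAS_AMP_C = True
except ImportError:
    _HAS_AMP_C = False

class FusedSGDSketch(object):
    """Illustrative stand-in, not apex's FusedSGD: fail loudly at construction
    time if the fused CUDA kernels are unavailable."""
    def __init__(self, params):
        if not _HAS_AMP_C:
            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
        self.params = list(params)
        self.multi_tensor_sgd = amp_C.multi_tensor_sgd
```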