"tests/vscode:/vscode.git/clone" did not exist on "916d375ba3d62e018231633ca74e33ce128085c3"
Commit e0f2ffa5 authored by Michael Carilli

Initial organization

parent bf4aa847
@@ -3,7 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
-from ..optimizers import FusedAdam
+from ..optimizers import FusedAdam, FusedSGD


 class AmpOptimizerState(object):
@@ -217,7 +217,11 @@ def post_backward_no_master_weights(self, scaler):
     post_backward_models_are_masters(scaler, params, stashed_grads)


-def prepare_backward_with_master_weights_fused(self):
+#####################################################################################
+# FusedAdam versions
+#####################################################################################
+
+def prepare_backward_with_master_weights_FusedAdam(self):
     stash = self._amp_stash
     if not stash.lazy_init_called:
@@ -225,7 +229,7 @@ def prepare_backward_with_master_weights_fused(self):
         stash.lazy_init_called = True


-def post_backward_with_master_weights_fused(self, scaler):
+def post_backward_with_master_weights_FusedAdam(self, scaler):
     stash = self._amp_stash
     stash.scale = scaler.loss_scale()
     stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
@@ -250,7 +254,7 @@ def post_backward_with_master_weights_fused(self, scaler):
     stash.grad_norms = norm_groups


-def prepare_backward_no_master_weights_fused(self):
+def prepare_backward_no_master_weights_FusedAdam(self):
     stash = self._amp_stash
     if not stash.lazy_init_called:
@@ -258,7 +262,63 @@ def prepare_backward_no_master_weights_fused(self):
         stash.lazy_init_called = True


-def post_backward_no_master_weights_fused(self, scaler):
+def post_backward_no_master_weights_FusedAdam(self, scaler):
+    stash = self._amp_stash
+    stash.scale = scaler.loss_scale()
+    stash.grads = None
+    stash.output_params = None
+    stash.grad_norms = None
+
+
+#####################################################################################
+# FusedSGD versions
+# Eat this ugly code duplication for now. First make it work, then make it clean.
+# It's difficult to anticipate what can be unified between the FusedAdam and FusedSGD
+# implementations until I have them both working.
+#####################################################################################
+
+def prepare_backward_with_master_weights_FusedSGD(self):
+    stash = self._amp_stash
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_with_master_weights_FusedSGD(self, scaler):
+    stash = self._amp_stash
+    stash.scale = scaler.loss_scale()
+    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
+    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
+
+    norm_groups = []
+    skip = False
+    for grad_group in stash.grads:
+        norm = multi_tensor_applier(
+            stash.multi_tensor_l2norm,
+            stash.dummy_overflow_buf,
+            [grad_group])
+        # Still syncing here for now.
+        norm = float(norm)
+        norm_groups.append(norm)
+        if norm == float('inf') or norm == -float('inf') or norm != norm:
+            skip = True
+    if skip:
+        scaler._overflow_buf.fill_(1.)
+        scaler._has_overflow = True
+    stash.grad_norms = norm_groups
+
+
+def prepare_backward_no_master_weights_FusedSGD(self):
+    stash = self._amp_stash
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_no_master_weights_FusedSGD(self, scaler):
     stash = self._amp_stash
     stash.scale = scaler.loss_scale()
     stash.grads = None
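
Note (illustrative, not part of the diff): the overflow test in post_backward_with_master_weights_FusedSGD relies on IEEE float behavior, where norm != norm is true only for NaN. A minimal standalone sketch of the same check, using a hypothetical helper name (overflowed):

    import math

    def overflowed(norm):
        # Hypothetical helper mirroring the check above: flag +/-inf, or NaN
        # (NaN is the only float that compares unequal to itself).
        return norm == float('inf') or norm == -float('inf') or norm != norm

    assert overflowed(float('inf')) and overflowed(float('nan'))
    assert not overflowed(1.5)
    assert overflowed(3.0) == (not math.isfinite(3.0))  # equivalent to a finiteness test
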
@@ -314,7 +374,7 @@ def _process_optimizer(optimizer, properties):
         old_step = optimizer.step
         def new_step(self):
             retval = old_step()
-            if not isinstance(self, FusedAdam):
+            if not (isinstance(self, FusedAdam) or isinstance(self, FusedSGD)):
                 self._master_params_to_model_params()
             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
             for param in self._amp_stash.all_fp32_from_fp16_params:
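
Note (illustrative, not part of the diff): new_step above follows the usual bound-method wrapping pattern, capturing the original optimizer.step and rebinding a wrapper onto the instance so fused optimizers can skip the master-to-model parameter copy. A toy sketch of the pattern with made-up names (_ToyOptimizer, _patch_step, post_hook):

    import types

    class _ToyOptimizer(object):
        def step(self):
            return "stepped"

    def _patch_step(optimizer, post_hook):
        # Capture the original bound method, then rebind a wrapper onto this
        # instance so optimizer.step() also runs post-step bookkeeping.
        old_step = optimizer.step
        def new_step(self):
            retval = old_step()
            post_hook(self)
            return retval
        optimizer.step = types.MethodType(new_step, optimizer)

    opt = _ToyOptimizer()
    _patch_step(opt, lambda self: print("post-step hook ran"))
    print(opt.step())  # hook message, then "stepped"
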
@@ -344,9 +404,14 @@ def _process_optimizer(optimizer, properties):
         if isinstance(optimizer, FusedAdam):
             optimizer._prepare_amp_backward = types.MethodType(
-                prepare_backward_with_master_weights_fused, optimizer)
+                prepare_backward_with_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_FusedSGD, optimizer)
             optimizer._post_amp_backward = types.MethodType(
-                post_backward_with_master_weights_fused, optimizer)
+                post_backward_with_master_weights_FusedSGD, optimizer)
         else:
             optimizer._prepare_amp_backward = types.MethodType(
                 prepare_backward_with_master_weights, optimizer)
@@ -358,9 +423,14 @@ def _process_optimizer(optimizer, properties):
         if isinstance(optimizer, FusedAdam):
             optimizer._prepare_amp_backward = types.MethodType(
-                prepare_backward_no_master_weights_fused, optimizer)
+                prepare_backward_no_master_weights_FusedAdam, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_FusedAdam, optimizer)
+        elif isinstance(optimizer, FusedSGD):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_FusedSGD, optimizer)
             optimizer._post_amp_backward = types.MethodType(
-                post_backward_no_master_weights_fused, optimizer)
+                post_backward_no_master_weights_FusedSGD, optimizer)
         else:
             optimizer._prepare_amp_backward = types.MethodType(
                 prepare_backward_no_master_weights, optimizer)
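
Note (illustrative, not part of the diff): the isinstance dispatch above attaches different module-level prepare/post functions to each optimizer instance via types.MethodType, so amp's backward pass can call optimizer._prepare_amp_backward() uniformly. A toy sketch with stand-in classes and a hypothetical attach_prepare helper (not apex's real ones):

    import types

    class FakeFusedAdam(object): pass
    class FakeFusedSGD(object): pass
    class FakeOptimizer(object): pass

    def _prepare_generic(self):    print("generic prepare")
    def _prepare_fused_adam(self): print("FusedAdam prepare")
    def _prepare_fused_sgd(self):  print("FusedSGD prepare")

    def attach_prepare(optimizer):
        # Pick the variant matching the optimizer's type, then bind it to the instance.
        if isinstance(optimizer, FakeFusedAdam):
            fn = _prepare_fused_adam
        elif isinstance(optimizer, FakeFusedSGD):
            fn = _prepare_fused_sgd
        else:
            fn = _prepare_generic
        optimizer._prepare_amp_backward = types.MethodType(fn, optimizer)

    for opt in (FakeFusedAdam(), FakeFusedSGD(), FakeOptimizer()):
        attach_prepare(opt)
        opt._prepare_amp_backward()  # same call site, per-type behavior
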
...
@@ -73,7 +73,7 @@ class FusedSGD(Optimizer):
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
             self.multi_tensor_sgd = amp_C.multi_tensor_sgd
         else:
-            raise RuntimeError('apex.optim.SGD requires cuda extensions')
+            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')

     def __setstate__(self, state):
         super(SGD, self).__setstate__(state)
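
Note (illustrative, not part of the diff): the corrected message is raised when the compiled amp_C extension is unavailable; the condition guarding this else branch is outside the hunk. A hedged sketch of that kind of availability check, with a hypothetical _load_fused_sgd_kernel helper:

    def _load_fused_sgd_kernel():
        # Sketch only: fail with a clear error when the fused CUDA kernels
        # (amp_C.multi_tensor_sgd) were never built/installed.
        try:
            import amp_C
            return amp_C.multi_tensor_sgd
        except ImportError:
            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
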
...