Commit 00dbe4b4 authored by Michael Carilli

test_fused_sgd.py passing

parent 72bce160
@@ -75,6 +75,8 @@ def lazy_init_with_master_weights(self):
     for group in stash.fp32_from_fp32_groups:
         stash.all_fp32_from_fp32_params += group
 
+    # all_fp16_grad_stash is only needed for fused optimizers.
+    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
     stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
@@ -125,9 +127,7 @@ def post_backward_models_are_masters(scaler, params, stashed_grads):
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash
 
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
         # Set up to leverage grad copy elision:
@@ -145,6 +145,8 @@ def prepare_backward_with_master_weights(self):
 def post_backward_with_master_weights(self, scaler):
     stash = self._amp_stash
 
+    self._amp_lazy_init()
+
     # This is a lot of python overhead...
     fp16_grads_needing_unscale = []
     new_fp32_grads = []
@@ -206,9 +208,7 @@ def lazy_init_no_master_weights(self):
 def prepare_backward_no_master_weights(self):
     stash = self._amp_stash
 
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
         stash.all_fp16_grad_stash[i] = param.grad
@@ -224,6 +224,8 @@ def prepare_backward_no_master_weights(self):
 def post_backward_no_master_weights(self, scaler):
     stash = self._amp_stash
 
+    self._amp_lazy_init()
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))
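The prepare/post pairs above rely on a stash-and-restore pattern for gradients. As a rough illustration (illustrative names and simplified logic, not apex's actual hooks): before backward the existing .grad is stashed and cleared so backward writes a fresh gradient, and afterwards the fresh gradient is unscaled and combined with the stashed one so gradient accumulation still behaves as expected.

import torch

def prepare(params, grad_stash):
    # Stash any existing gradient and clear it so backward writes a fresh one.
    for i, p in enumerate(params):
        grad_stash[i] = p.grad        # may be None before the first backward
        p.grad = None

def post(params, grad_stash, loss_scale):
    # Unscale the fresh gradient and fold the stashed gradient back in.
    for i, p in enumerate(params):
        if p.grad is None:
            continue
        p.grad.div_(loss_scale)
        if grad_stash[i] is not None:
            p.grad.add_(grad_stash[i])
        grad_stash[i] = None

params = [torch.nn.Parameter(torch.zeros(3))]
grad_stash = [None for _ in params]
prepare(params, grad_stash)
(params[0] * 2.0).sum().backward()
post(params, grad_stash, loss_scale=1.0)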
@@ -238,13 +240,14 @@ def post_backward_no_master_weights(self, scaler):
 def prepare_backward_with_master_weights_FusedAdam(self):
     stash = self._amp_stash
 
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
 
 def post_backward_with_master_weights_FusedAdam(self, scaler):
     stash = self._amp_stash
 
+    self._amp_lazy_init()
+
     stash.scale = scaler.loss_scale()
     stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
     stash.output_params = [[param for param in group] for group in stash.fp16_groups]
@@ -271,13 +274,14 @@ def post_backward_with_master_weights_FusedAdam(self, scaler):
 def prepare_backward_no_master_weights_FusedAdam(self):
     stash = self._amp_stash
 
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
 
 def post_backward_no_master_weights_FusedAdam(self, scaler):
     stash = self._amp_stash
 
+    self._amp_lazy_init()
+
     stash.scale = scaler.loss_scale()
     stash.grads = None
     stash.output_params = None
@@ -296,9 +300,7 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):
 def prepare_backward_with_master_weights_FusedSGD(self):
     stash = self._amp_stash
 
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
         stash.all_fp16_grad_stash[i] = param.grad
@@ -314,6 +316,8 @@ def prepare_backward_with_master_weights_FusedSGD(self):
 def post_backward_with_master_weights_FusedSGD(self, scaler):
     stash = self._amp_stash
 
+    self._amp_lazy_init()
+
     split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                    (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
@@ -329,6 +333,14 @@ def post_backward_no_master_weights_FusedSGD(self, scaler):
     post_backward_no_master_weights(self, scaler)
 
 
+def _amp_lazy_init(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
 def _process_optimizer(optimizer, properties):
     if hasattr(optimizer, "_amp_stash"):
         raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
@@ -342,7 +354,8 @@ def _process_optimizer(optimizer, properties):
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
                  "_prepare_amp_backward",
-                 "_post_amp_backward"):
+                 "_post_amp_backward",
+                 "_amp_lazy_init"):
         if hasattr(optimizer, name):
             raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
@@ -374,9 +387,7 @@ def _process_optimizer(optimizer, properties):
     old_zero_grad = optimizer.zero_grad
     def new_zero_grad(self):
         stash = self._amp_stash
-        if not stash.lazy_init_called:
-            self._lazy_init_maybe_master_weights()
-            stash.lazy_init_called = True
+        self._amp_lazy_init()
         # Zero the model grads.
         for param in stash.all_fp16_params:
             if param.grad is not None:
@@ -426,4 +437,6 @@ def _process_optimizer(optimizer, properties):
         optimizer._post_amp_backward = types.MethodType(
             post_backward_no_master_weights, optimizer)
 
+    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)
+
     return optimizer
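The added binding uses types.MethodType, which attaches a module-level function to one existing instance as a bound method, so its self parameter is supplied automatically. A small illustrative sketch (names are not apex's):

import types

def _lazy_init(self):
    # `self` is the instance the function gets bound to below.
    if not getattr(self, "initialized", False):
        self.initialized = True

class FakeOptimizer:
    pass

opt = FakeOptimizer()
opt._lazy_init = types.MethodType(_lazy_init, opt)   # bind to this instance only
opt._lazy_init()
assert opt.initialized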
@@ -76,7 +76,7 @@ class FusedSGD(Optimizer):
             raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
 
     def __setstate__(self, state):
-        super(SGD, self).__setstate__(state)
+        super(FusedSGD, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('nesterov', False)
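The one-line FusedSGD change fixes the two-argument super() form: super(C, self) starts the method lookup after C in type(self)'s MRO, so C must be the class the method is defined in. The old super(SGD, self) would fail here, with a NameError if SGD is not in scope or a TypeError if FusedSGD is not a subclass of SGD. A toy illustration with hypothetical classes:

class Base(object):
    def __setstate__(self, state):
        self.__dict__.update(state)

class Child(Base):
    def __setstate__(self, state):
        # Correct: name the defining class, then delegate up the MRO.
        super(Child, self).__setstate__(state)
        self.restored = True

c = Child()
c.__setstate__({"x": 1})
assert c.x == 1 and c.restored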