Commit fc6c5a25 authored by Michael Carilli

some cleanup

parent 683b6e0e
@@ -107,22 +107,6 @@ def check_optimizers(optimizers):
                        "on the specified opt_level (and optional overridden properties).")
 
-def wrap_fused_adam(optimizer, properties):
-    msg = 'Currently, the usage of FusedAdam is restricted to '\
-          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
-          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
-
-    assert properties.master_weights is True, msg
-    assert properties.cast_model_type is torch.float16, msg
-    assert (properties.keep_batchnorm_fp32 is False or
-            properties.keep_batchnorm_fp32 is None), msg
-
-    if properties.loss_scale == "dynamic":
-        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else:
-        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
-
 def _initialize(models, optimizers, properties, num_losses=1):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
@@ -184,10 +168,6 @@ def _initialize(models, optimizers, properties, num_losses=1):
             optimizer.load_state_dict(optimizer.state_dict())
 
     for i, optimizer in enumerate(optimizers):
-        # Still need to special case this for the first pass
-        if isinstance(optimizer, FusedAdam):
-            optimizers[i] = wrap_fused_adam(optimizer, properties)
-        else:
-            optimizers[i] = _process_optimizer(optimizer, properties)
+        optimizers[i] = _process_optimizer(optimizer, properties)
 
     _amp_state.loss_scalers = []
...
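With wrap_fused_adam removed, FusedAdam is no longer wrapped in the fused FP16_Optimizer during initialization; it is sent through _process_optimizer like any other optimizer. Below is a minimal usage sketch of the entry point this file changes. It is not part of the commit; it assumes a CUDA-capable machine with an apex build that includes FusedAdam, and it reuses the settings named in the removed assertion message.

```python
# Hedged sketch (not from this commit): FusedAdam passed directly to
# amp.initialize with the opt_level="O2", keep_batchnorm_fp32=False,
# loss_scale="dynamic" combination the old wrapper used to require.
import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(128, 128).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

# After this commit, FusedAdam flows through _process_optimizer
# instead of wrap_fused_adam / FP16_Optimizer_for_fused.
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O2",
                                  keep_batchnorm_fp32=False,
                                  loss_scale="dynamic")
```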
@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
+from ..optimizers import FusedAdam
 
 class AmpOptimizerState(object):
@@ -73,6 +74,40 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())
 
+
+def post_backward_models_are_masters(scaler, params, stashed_grads):
+    # This is a lot of python overhead...
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(params, stashed_grads):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else: # param.grad is None and stashed_grad is None
+            continue
+
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            scaler.loss_scale(),
+            models_are_masters=True)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash)
+
+    # Clear the stash.
+    for i in range(len(stashed_grads)):
+        stashed_grads[i] = None
+
+
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash
@@ -129,37 +164,10 @@ def post_backward_with_master_weights(self, scaler):
             preexisting_fp32_grads)
 
     # fp32 params can be treated as they would be in the "no_master_weights" case.
-    grads_needing_unscale = []
-    grads_needing_unscale_with_stash = []
-    stashed = []
-    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
-                                   stash.all_fp32_from_fp32_grad_stash):
-        if param.grad is None and stashed_grad is not None:
-            param.grad = stashed_grad
-        elif param.grad is not None and stashed_grad is None:
-            grads_needing_unscale.append(param.grad)
-        elif param.grad is not None and stashed_grad is not None:
-            grads_needing_unscale_with_stash.append(param.grad)
-            stashed.append(stashed_grad)
-        else: # param.grad is None and stashed_grad is None:
-            continue
-
-    if len(grads_needing_unscale) > 0:
-        scaler.unscale(
-            grads_needing_unscale,
-            grads_needing_unscale,
-            scaler.loss_scale(),
-            models_are_masters=True)
-
-    if len(grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            grads_needing_unscale_with_stash,
-            stashed,
-            grads_needing_unscale_with_stash)
-
-    # Clear the stash.
-    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
-        stash.all_fp32_from_fp32_grad_stash[i] = None
+    post_backward_models_are_masters(
+        scaler,
+        stash.all_fp32_from_fp32_params,
+        stash.all_fp32_from_fp32_grad_stash)
 
 
 def lazy_init_no_master_weights(self):
@@ -206,37 +214,7 @@ def post_backward_no_master_weights(self, scaler):
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))
 
     for params, stashed_grads in split_types:
-        # This is a lot of python overhead...
-        grads_needing_unscale = []
-        grads_needing_unscale_with_stash = []
-        stashed = []
-        for param, stashed_grad in zip(params, stashed_grads):
-            if param.grad is None and stashed_grad is not None:
-                param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
-                grads_needing_unscale.append(param.grad)
-            elif param.grad is not None and stashed_grad is not None:
-                grads_needing_unscale_with_stash.append(param.grad)
-                stashed.append(stashed_grad)
-            else: # param.grad is None and stashed_grad is None
-                continue
-
-        if len(grads_needing_unscale) > 0:
-            scaler.unscale(
-                grads_needing_unscale,
-                grads_needing_unscale,
-                scaler.loss_scale(),
-                models_are_masters=True)
-
-        if len(grads_needing_unscale_with_stash) > 0:
-            scaler.unscale_with_stashed(
-                grads_needing_unscale_with_stash,
-                stashed,
-                grads_needing_unscale_with_stash)
-
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
+        post_backward_models_are_masters(scaler, params, stashed_grads)
 
 
 def _master_params_to_model_params(self):
@@ -283,6 +261,7 @@ def _process_optimizer(optimizer, properties):
         optimizer._master_params_to_model_params = types.MethodType(
             _master_params_to_model_params, optimizer)
 
+        if not isinstance(optimizer, FusedAdam):
             old_step = optimizer.step
             def new_step(self):
                 retval = old_step()
@@ -313,18 +292,28 @@ def _process_optimizer(optimizer, properties):
                 param.grad = None
         optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
 
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_fused, optimizer)
+        else:
             optimizer._prepare_amp_backward = types.MethodType(
                 prepare_backward_with_master_weights, optimizer)
             optimizer._post_amp_backward = types.MethodType(
                 post_backward_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(
             lazy_init_no_master_weights, optimizer)
 
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_fused, optimizer)
+        else:
             optimizer._prepare_amp_backward = types.MethodType(
                 prepare_backward_no_master_weights, optimizer)
             optimizer._post_amp_backward = types.MethodType(
                 post_backward_no_master_weights, optimizer)
...
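The new post_backward_models_are_masters helper factors out the unscale logic that was previously duplicated in post_backward_with_master_weights and post_backward_no_master_weights. The toy sketch below illustrates the scaler contract the helper relies on; ToyScaler is an illustration, not apex's real LossScaler, and the accumulation inside unscale_with_stashed is an assumption about its behavior. It also assumes a full apex CUDA build so that apex.amp._process_optimizer (which now imports FusedAdam) can be imported.

```python
# Hedged sketch: a stand-in scaler whose method signatures match the calls made
# by post_backward_models_are_masters in the diff above; the arithmetic is a
# simplifying assumption.
import torch
from apex.amp._process_optimizer import post_backward_models_are_masters

class ToyScaler(object):
    def __init__(self, scale=128.0):
        self._loss_scale = scale

    def loss_scale(self):
        return self._loss_scale

    def unscale(self, model_grads, master_grads, scale, models_are_masters=False):
        # With models_are_masters=True, model_grads and master_grads alias the
        # same tensors, so unscaling happens in place.
        for grad in master_grads:
            grad.div_(scale)

    def unscale_with_stashed(self, model_grads, stashed_master_grads, master_grads):
        # Assumed behavior: unscale the freshly produced grad and add the grad
        # stashed from an earlier backward pass of the same step.
        for model_grad, stashed, master_grad in zip(
                model_grads, stashed_master_grads, master_grads):
            master_grad.copy_(model_grad / self._loss_scale + stashed)

# Two params: one with a stashed grad, one without.
params = [torch.nn.Parameter(torch.zeros(3)) for _ in range(2)]
params[0].grad = torch.full((3,), 256.0)
params[1].grad = torch.full((3,), 512.0)
stashed_grads = [torch.ones(3), None]

post_backward_models_are_masters(ToyScaler(), params, stashed_grads)
print(params[0].grad)  # 256/128 + 1 = 3.0 (unscale_with_stashed path)
print(params[1].grad)  # 512/128 = 4.0     (unscale path)
print(stashed_grads)   # [None, None]      (stash cleared)
```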
@@ -6,8 +6,6 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
-from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
 
 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@@ -82,11 +80,6 @@ def scale_loss(loss,
     if isinstance(optimizers, torch.optim.Optimizer):
         optimizers = [optimizers]
 
-    # this is what happens when i have to support tools from different sources under the same API...
-    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-    if isinstance(optimizers, FP16_Optimizer_for_fused):
-        loss_scale = optimizers.cur_scale
-    else:
-        loss_scaler = _amp_state.loss_scalers[loss_id]
-        loss_scale = loss_scaler.loss_scale()
+    loss_scaler = _amp_state.loss_scalers[loss_id]
+    loss_scale = loss_scaler.loss_scale()
@@ -113,8 +106,8 @@ def scale_loss(loss,
         for optimizer in optimizers:
            optimizer._amp_stash.params_have_scaled_gradients = True
     else:
-        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
-        if not isinstance(optimizers, FP16_Optimizer_for_fused):
+        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
+        # if not isinstance(optimizers, FP16_Optimizer_for_fused):
             loss_scaler.clear_overflow_state()
             for optimizer in optimizers:
                 optimizer._post_amp_backward(loss_scaler)
...
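On the handle.py side, scale_loss now always reads the scale from _amp_state.loss_scalers, and the FusedAdam special case is reduced to a commented-out guard, with unscaling routed through the optimizers' _post_amp_backward hooks. The sketch below shows the backward-pass call pattern this hunk serves. It is not part of the commit; it repeats the same hypothetical setup as the earlier amp.initialize sketch, and the loss computation is arbitrary.

```python
# Hedged sketch of a training step after this change; assumes a CUDA machine
# and an apex build with FusedAdam, as before.
import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(128, 128).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  keep_batchnorm_fp32=False,
                                  loss_scale="dynamic")

data = torch.randn(32, 128).cuda()
target = torch.randn(32, 128).cuda()

optimizer.zero_grad()
# Under O2 the model runs in fp16; cast the output up before computing the loss.
loss = (model(data).float() - target).pow(2).mean()

# scale_loss now consults _amp_state.loss_scalers[loss_id] for every optimizer
# type, FusedAdam included, instead of branching on FP16_Optimizer_for_fused.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```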