Unverified commit 3f87614f, authored by mcarilli, committed by GitHub

WIP: Handle arbitrary combinations of optimizers/models/losses (#232)

* Refactor to allow more flexible treatment of multiple optimizers/models/losses

* Adding _process_optimizers.py

* Created L0 tests (now passing).

* fix: minor print typo (#234)

* make L1 results easier to read

* L0 multiple model/optimizer/loss test fleshed out

* Adding test that master params remain synced across distributed processes

* Docstring updates

* Docstring updates
parent 214fda42
@@ -2,4 +2,4 @@ from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize
- from ._amp_state import master_params
+ from ._amp_state import master_params, _amp_state
@@ -17,6 +17,7 @@ else:
class AmpState(object):
    def __init__(self):
        self.hard_override=False
+       self.allow_incoming_model_not_fp32 = False
        self.verbosity=1
......
@@ -6,6 +6,7 @@ import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts
from .scaler import LossScaler
+ from ._process_optimizer import _process_optimizer
from apex.fp16_utils import convert_network
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
@@ -122,7 +123,7 @@ def wrap_fused_adam(optimizer, properties):
        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)

- def _initialize(models, optimizers, properties):
+ def _initialize(models, optimizers, properties, num_losses=1):
    from apex.parallel import DistributedDataParallel as apex_DDP
    from .amp import init as amp_init
@@ -146,6 +147,7 @@ def _initialize(models, optimizers, properties):
    check_models(models)

-     check_params_fp32(models)
+     if not _amp_state.allow_incoming_model_not_fp32:
+         check_params_fp32(models)

    check_optimizers(optimizers)
@@ -181,21 +183,16 @@ def _initialize(models, optimizers, properties):
    for optimizer in optimizers:
        optimizer.load_state_dict(optimizer.state_dict())

-     if properties.master_weights:
-         for i, optimizer in enumerate(optimizers):
-             if isinstance(optimizer, FusedAdam):
-                 optimizers[i] = wrap_fused_adam(optimizer, properties)
-             else:
-                 if properties.loss_scale == "dynamic":
-                     optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                            dynamic_loss_scale=True,
-                                                            verbose=False)
-                 else:
-                     optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                            static_loss_scale=properties.loss_scale,
-                                                            verbose=False)
-     else:
-         for optimizer in optimizers:
-             optimizer.loss_scaler = LossScaler(properties.loss_scale)
+     for i, optimizer in enumerate(optimizers):
+         # Still need to special case this for the first pass
+         if isinstance(optimizer, FusedAdam):
+             optimizers[i] = wrap_fused_adam(optimizer, properties)
+         else:
+             optimizers[i] = _process_optimizer(optimizer, properties)
+
+     _amp_state.loss_scalers = []
+     for _ in range(num_losses):
+         _amp_state.loss_scalers.append(LossScaler(properties.loss_scale))

    if properties.patch_torch_functions:
        # handle is unused here. It's accessible later through a global value anyway.
......
import types
from ..fp16_utils import master_params_to_model_params
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import maybe_print
import torch
class AmpOptimizerState(object):
def __init__(self):
pass
def lazy_init_with_master_weights(self):
stash = self._amp_stash
stash.fp16_groups = []
stash.fp32_from_fp16_groups = []
stash.fp32_from_fp32_groups = []
for i, param_group in enumerate(self.param_groups):
# maybe_print("FP16_Optimizer processing param group {}:".format(i))
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
# maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
# .format(param.size()))
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
# Reset existing state dict key to the new master param.
# We still need to recast per-param state tensors, if any, to FP32.
if param in self.state:
self.state[master_param] = self.state.pop(param)
elif param.type() == 'torch.cuda.FloatTensor':
# maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
# .format(param.size()))
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
"Received {}".format(param.type()))
stash.fp16_groups.append(fp16_params_this_group)
stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
stash.fp32_from_fp32_groups.append(fp32_params_this_group)
stash.all_fp16_params = []
for group in stash.fp16_groups:
stash.all_fp16_params += group
stash.all_fp32_from_fp16_params = []
for group in stash.fp32_from_fp16_groups:
stash.all_fp32_from_fp16_params += group
stash.all_fp32_from_fp32_params = []
for group in stash.fp32_from_fp32_groups:
stash.all_fp32_from_fp32_params += group
# stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
for param in stash.all_fp32_from_fp16_params:
param.grad = None
for param in stash.all_fp32_from_fp32_params:
param.grad = None
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
self.load_state_dict(self.state_dict())
def prepare_backward_with_master_weights(self):
stash = self._amp_stash
if not stash.lazy_init_called:
self._lazy_init_maybe_master_weights()
stash.lazy_init_called = True
for i, param in enumerate(stash.all_fp16_params):
# Set up to leverage grad copy elision:
param.grad = None
# for i, param in enumerate(stash.all_fp32_from_fp16_params):
# stash.all_fp32_from_fp16_grad_stash[i] = param.grad
for i, param in enumerate(stash.all_fp32_from_fp32_params):
stash.all_fp32_from_fp32_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
def post_backward_with_master_weights(self, scaler):
stash = self._amp_stash
# This is a lot of python overhead...
fp16_grads_needing_unscale = []
new_fp32_grads = []
fp16_grads_needing_unscale_with_stash = []
preexisting_fp32_grads = []
for fp16_param, fp32_param in zip(stash.all_fp16_params,
stash.all_fp32_from_fp16_params):
if fp16_param.grad is None and fp32_param.grad is not None:
continue
elif fp16_param.grad is not None and fp32_param.grad is None:
fp32_param.grad = torch.empty_like(fp32_param)
fp16_grads_needing_unscale.append(fp16_param.grad)
new_fp32_grads.append(fp32_param.grad)
elif fp16_param.grad is not None and fp32_param.grad is not None:
fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
preexisting_fp32_grads.append(fp32_param.grad)
else: # fp16_param.grad is None and fp32_param.grad is None:
continue
if len(fp16_grads_needing_unscale) > 0:
scaler.unscale(
fp16_grads_needing_unscale,
new_fp32_grads,
scaler.loss_scale(),
models_are_masters=False)
if len(fp16_grads_needing_unscale_with_stash) > 0:
scaler.unscale_with_stashed(
fp16_grads_needing_unscale_with_stash,
preexisting_fp32_grads,
preexisting_fp32_grads)
# fp32 params can be treated as they would be in the "no_master_weights" case.
grads_needing_unscale = []
grads_needing_unscale_with_stash = []
stashed = []
for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
stash.all_fp32_from_fp32_grad_stash):
if param.grad is None and stashed_grad is not None:
param.grad = stashed_grad
elif param.grad is not None and stashed_grad is None:
grads_needing_unscale.append(param.grad)
elif param.grad is not None and stashed_grad is not None:
grads_needing_unscale_with_stash.append(param.grad)
stashed.append(stashed_grad)
else: # param.grad is None and stashed_grad is None:
continue
if len(grads_needing_unscale) > 0:
scaler.unscale(
grads_needing_unscale,
grads_needing_unscale,
scaler.loss_scale(),
models_are_masters=True)
if len(grads_needing_unscale_with_stash) > 0:
scaler.unscale_with_stashed(
grads_needing_unscale_with_stash,
stashed,
grads_needing_unscale_with_stash)
# Clear the stash.
for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
stash.all_fp32_from_fp32_grad_stash[i] = None
def lazy_init_no_master_weights(self):
stash = self._amp_stash
stash.all_fp16_params = []
stash.all_fp32_params = []
for i, param_group in enumerate(self.param_groups):
for i, param in enumerate(param_group['params']):
if param.type() == 'torch.cuda.HalfTensor':
stash.all_fp16_params.append(param)
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
"Received {}".format(param.type()))
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
def prepare_backward_no_master_weights(self):
stash = self._amp_stash
if not stash.lazy_init_called:
self._lazy_init_maybe_master_weights()
stash.lazy_init_called = True
for i, param in enumerate(stash.all_fp16_params):
stash.all_fp16_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
for i, param in enumerate(stash.all_fp32_params):
stash.all_fp32_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
def post_backward_no_master_weights(self, scaler):
stash = self._amp_stash
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
(stash.all_fp32_params, stash.all_fp32_grad_stash))
for params, stashed_grads in split_types:
# This is a lot of python overhead...
grads_needing_unscale = []
grads_needing_unscale_with_stash = []
stashed = []
for param, stashed_grad in zip(params, stashed_grads):
if param.grad is None and stashed_grad is not None:
param.grad = stashed_grad
elif param.grad is not None and stashed_grad is None:
grads_needing_unscale.append(param.grad)
elif param.grad is not None and stashed_grad is not None:
grads_needing_unscale_with_stash.append(param.grad)
stashed.append(stashed_grad)
else: # param.grad is None and stashed_grad is None
continue
if len(grads_needing_unscale) > 0:
scaler.unscale(
grads_needing_unscale,
grads_needing_unscale,
scaler.loss_scale(),
models_are_masters=True)
if len(grads_needing_unscale_with_stash) > 0:
scaler.unscale_with_stashed(
grads_needing_unscale_with_stash,
stashed,
grads_needing_unscale_with_stash)
# Clear the stash.
for i in range(len(stashed_grads)):
stashed_grads[i] = None
def _master_params_to_model_params(self):
stash = self._amp_stash
if multi_tensor_applier.available:
if len(stash.all_fp16_params) > 0:
multi_tensor_applier(
stash.multi_tensor_scale,
stash.dummy_overflow_buf,
[stash.all_fp32_from_fp16_params, stash.all_fp16_params],
1.0)
else:
for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
master_params_to_model_params(fp16_group, fp32_from_fp16_group)
def _process_optimizer(optimizer, properties):
if hasattr(optimizer, "_amp_stash"):
raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
else:
optimizer._amp_stash = AmpOptimizerState()
optimizer._amp_stash.lazy_init_called = False
optimizer._amp_stash.already_patched = False
for name in ("_lazy_init_maybe_master_weights",
"_master_params_to_model_params",
"_prepare_amp_backward",
"_post_amp_backward"):
if hasattr(optimizer, name):
raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
# TODO: Centralize exposure and import error checking for the C backend.
if multi_tensor_applier.available:
import amp_C
optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);
if properties.master_weights:
optimizer._lazy_init_maybe_master_weights = types.MethodType(
lazy_init_with_master_weights, optimizer)
optimizer._master_params_to_model_params = types.MethodType(
_master_params_to_model_params, optimizer)
old_step = optimizer.step
def new_step(self):
retval = old_step()
self._master_params_to_model_params()
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
for param in self._amp_stash.all_fp32_from_fp16_params:
param.grad = None
return retval
optimizer.step = types.MethodType(new_step, optimizer)
old_zero_grad = optimizer.zero_grad
def new_zero_grad(self):
stash = self._amp_stash
if not stash.lazy_init_called:
self._lazy_init_maybe_master_weights()
stash.lazy_init_called = True
# Zero the model grads.
for param in stash.all_fp16_params:
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
for param in stash.all_fp32_from_fp32_params:
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
# Clear the master grads that are independent of model grads
for param in self._amp_stash.all_fp32_from_fp16_params:
param.grad = None
optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
optimizer._prepare_amp_backward = types.MethodType(
prepare_backward_with_master_weights, optimizer)
optimizer._post_amp_backward = types.MethodType(
post_backward_with_master_weights, optimizer)
else:
optimizer._lazy_init_maybe_master_weights = types.MethodType(
lazy_init_no_master_weights, optimizer)
optimizer._prepare_amp_backward = types.MethodType(
prepare_backward_no_master_weights, optimizer)
optimizer._post_amp_backward = types.MethodType(
post_backward_no_master_weights, optimizer)
return optimizer
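As an aside for readers unfamiliar with the pattern used throughout ``_process_optimizer``: it attaches a per-optimizer state object (``_amp_stash``) to an existing optimizer instance and rebinds methods such as ``step`` and ``zero_grad`` with ``types.MethodType``. The following is a minimal, self-contained sketch of that pattern on a plain PyTorch optimizer; the names ``ExtraState`` and ``new_step`` are illustrative only and are not part of Apex.

    import types
    import torch

    # Minimal sketch of the monkey-patching pattern above: attach a state object
    # to an existing optimizer instance and wrap its step() method.
    class ExtraState(object):  # illustrative stand-in for AmpOptimizerState
        def __init__(self):
            self.steps_taken = 0

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer._extra_state = ExtraState()

    old_step = optimizer.step

    def new_step(self, closure=None):
        retval = old_step(closure)
        # Post-step bookkeeping, analogous to syncing master params back to the model.
        self._extra_state.steps_taken += 1
        return retval

    optimizer.step = types.MethodType(new_step, optimizer)

    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()
    assert optimizer._extra_state.steps_taken == 1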
@@ -53,6 +53,7 @@ class Properties(object):
        # print("setting {} {}".format(name, value))
        if name == "cast_model_type":
            if self.opt_level == "O1" and value is not None:
+               if value is not False:
                    if value is not torch.float32:
                        warn_or_err("O1 inserts casts around Torch functions rather than "
                                    "model weights, so with O1, the model weights themselves "
@@ -200,20 +201,28 @@ def initialize(
    keep_batchnorm_fp32=None,
    master_weights=None,
    loss_scale=None,
+   num_losses=1,
    verbosity=1,
    ):
    """
    Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
    chosen ``opt_level`` and overridden properties, if any.

-   ``amp.initialize`` must be called **after** you have finished constructing your model(s) and
+   ``amp.initialize`` should be called **after** you have finished constructing your model(s) and
    optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
    See `Distributed training`_ in the Imagenet example.

+   Currently, ``amp.initialize`` should only be called **once**,
+   although it can process an arbitrary number of
+   models and optimizers (see the corresponding `Advanced Amp Usage topic`_).
+   If you think your use case requires ``amp.initialize`` to be called more than once,
+   `let us know`_.

    Any property keyword argument that is not ``None`` will be interpreted as a manual override.

    To prevent having to rewrite anything else in your script, name the returned models/optimizers
-   to replace the passed models/optimizers, as in the Usage below.
+   to replace the passed models/optimizers, as in the code sample below.

    Args:
        models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
@@ -229,8 +238,15 @@ def initialize(
        keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
            passed as a string, must be the string "True" or "False".
        master_weights (bool, optional, default=None): Optional property override.
-       loss_scale (float or str, default=None): Optional property override. If passed as a string,
+       loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
            must be a string representing a number, e.g., "128.0", or the string "dynamic".
+       num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
+           passes you plan to use. When used in conjunction with the ``loss_id`` argument to
+           ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
+           which can improve stability. See "Multiple models/optimizers/losses"
+           under `Advanced Amp Usage`_ for examples. If ``num_losses`` is left to 1, Amp will still
+           support multiple losses/backward passes, but use a single global loss scale
+           for all of them.
        verbosity (int, default=1): Set to 0 to suppress Amp-related output.

    Returns:
@@ -238,7 +254,7 @@ def initialize(
        If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
        also be a list.

-   Usage::
+   Permissible invocations::

        model, optim = amp.initialize(model, optim,...)
        model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
@@ -268,6 +284,15 @@ def initialize(
    .. _`Imagenet example`:
        https://github.com/NVIDIA/apex/tree/master/examples/imagenet

+   .. _`Advanced Amp Usage`:
+       https://nvidia.github.io/apex/advanced.html
+
+   .. _`Advanced Amp Usage topic`:
+       https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses
+
+   .. _`let us know`:
+       https://github.com/NVIDIA/apex/issues
    """
    _amp_state.opt_properties = Properties()
    _amp_state.verbosity = verbosity
@@ -308,7 +333,7 @@ def initialize(
    for k, v in _amp_state.opt_properties.options.items():
        maybe_print("{:22} : {}".format(k, v), True)

-   return _initialize(models, optimizers, _amp_state.opt_properties)
+   return _initialize(models, optimizers, _amp_state.opt_properties, num_losses)


# TODO: is this necessary/useful?
......
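For orientation, here is a small usage-level sketch of the ``num_losses`` argument documented above. It uses a hypothetical two-module, two-optimizer setup and assumes apex is installed and a CUDA device is available.

    import torch
    from apex import amp

    backbone = torch.nn.Linear(16, 16).cuda()
    head = torch.nn.Linear(16, 4).cuda()
    optim0 = torch.optim.SGD(backbone.parameters(), lr=0.1)
    optim1 = torch.optim.Adam(head.parameters(), lr=1e-3)

    # num_losses=3 asks Amp to maintain three independent loss scalers,
    # one per planned backward pass.
    [backbone, head], [optim0, optim1] = amp.initialize(
        [backbone, head], [optim0, optim1], opt_level="O1", num_losses=3)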
@@ -13,7 +13,8 @@ from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@contextlib.contextmanager
def scale_loss(loss,
-              optimizer,
+              optimizers,
+              loss_id=0,
               model=None,
               delay_unscale=False):
    """
@@ -44,18 +45,29 @@ def scale_loss(loss,
        manager yields is simply ``loss.float()*loss_scale``, so in principle
        ``loss`` could have more than one element, as long as you call
        ``backward()`` on ``scaled_loss`` appropriately within the context manager body.
-       optimizer: Must be an optimizer returned from an earlier call to ``amp.initialize``.
+       optimizers: All optimizer(s) for which the current backward pass is creating gradients.
+           Must be an optimizer or list of optimizers returned from an earlier call
+           to ``amp.initialize``. For example use with multiple optimizers, see
+           "Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
+       loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument
+           to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id``
+           must be an integer between 0 and ``num_losses`` that tells Amp which loss is
+           being used for the current backward pass. See "Multiple models/optimizers/losses"
+           under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp
+           will use the default global loss scaler for this backward pass.
        model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
            optimizations.
-       delay_unscale(bool, default=False): Don't unscale the gradients or perform model->master
-           gradient copies on context manager exit. `Advanced Amp Usage`_ illustrates
-           situations where this is necessary.
+       delay_unscale(bool, optional, default=False): ``delay_unscale`` is a ninja option that only
+           serves as a minor performance optimization, so only use it if you know what you're doing.
+           If ``True``, Amp will not unscale the gradients or perform model->master
+           gradient copies on context manager exit.
+           "Gradient accumulation across iterations" under `Advanced Amp Usage`_
+           illustrates a situation where this CAN (but does not need to) be used.

    .. warning::
        If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
        called yet after context manager exit, and must wait for another, later backward context
        manager invocation with ``delay_unscale`` left to False.
-       See `Advanced Amp Usage`_ for examples.

    .. _`Advanced Amp Usage`:
        https://nvidia.github.io/apex/advanced.html
@@ -64,18 +76,19 @@ def scale_loss(loss,
        yield loss
        return

-   if optimizer.loss_scaler is None:
-       raise RuntimeError("optimizer passed to scale_loss does not have a loss_scaler.")
+   if isinstance(optimizers, torch.optim.Optimizer):
+       optimizers = [optimizers]

    # this is what happens when i have to support tools from different sources under the same API...
    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-   if isinstance(optimizer, FP16_Optimizer_for_fused):
-       loss_scale = optimizer.cur_scale
+   if isinstance(optimizers, FP16_Optimizer_for_fused):
+       loss_scale = optimizers.cur_scale
    else:
-       loss_scale = optimizer.loss_scaler.loss_scale()
+       loss_scaler = _amp_state.loss_scalers[loss_id]
+       loss_scale = loss_scaler.loss_scale()

    if ((not _amp_state.opt_properties.master_weights)
-           and (not optimizer.loss_scaler.dynamic)
+           and (not loss_scaler.dynamic)
            and loss_scale == 1.0):
        yield loss.float()
        # Needing to drop the cache here as well is an ugly gotcha.
@@ -85,34 +98,42 @@ def scale_loss(loss,
            _amp_state.handle._clear_cache()
        return

+   if isinstance(optimizers, list):
+       for optimizer in optimizers:
+           optimizer._prepare_amp_backward()

    yield (loss.float())*loss_scale

-   # this isn't pretty but it unifies things. Once I deprecate the old API entirely,
-   # I will have freedom to clean this up. Maybe instead of wrapping optimizers,
-   # I can simply construct a set of attributes (e.g. master params) and assign them
-   # directly to optimizer instances.
    if not delay_unscale:
-       # The FP16_Optimizer for FusedAdam will take care of unscaling as part of
-       # its step() method.
-       if not isinstance(optimizer, FP16_Optimizer_for_fused):
-           if isinstance(optimizer, FP16_Optimizer_general):
-               optimizer.update_master_grads()
-           else:
-               optimizer.loss_scaler.clear_overflow_state()
-               optimizer.loss_scaler.unscale(
-                   master_params(optimizer),
-                   master_params(optimizer),
-                   loss_scale)
+       # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
+       if not isinstance(optimizers, FP16_Optimizer_for_fused):
+           loss_scaler.clear_overflow_state()
+           for optimizer in optimizers:
+               optimizer._post_amp_backward(loss_scaler)
            # For future fused optimizers that enable sync-free dynamic loss scaling,
            # should_skip will always be False.
-           should_skip = optimizer.loss_scaler.update_scale()
+           should_skip = loss_scaler.update_scale()
            if should_skip:
-               optimizer_step = optimizer.step
-               def skip_step():
-                   maybe_print("Gradient overflow.  Skipping step, reducing " +
-                               "loss scale to {}".format(optimizer.loss_scaler.loss_scale()))
-                   optimizer.step = optimizer_step
-               optimizer.step = skip_step
+               for optimizer in optimizers:
+                   if not optimizer._amp_stash.already_patched:
+                       # Close on loss_scaler and loss_id as well, to be safe. Probably not
+                       # necessary because amp.scale_loss is already creating a temporary scope.
+                       def patch_step(opt, loss_scaler, loss_id):
+                           opt_step = opt.step
+                           def skip_step():
+                               maybe_print(("Gradient overflow.  Skipping step, loss scaler " +
+                                            "{} reducing loss scale to {}").format(loss_id,
+                                            loss_scaler.loss_scale()))
+                               if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
+                                   # Clear the master grads that wouldn't be zeroed by model.zero_grad()
+                                   for param in opt._amp_stash.all_fp32_from_fp16_params:
+                                       param.grad = None
+                               opt.step = opt_step
+                               opt._amp_stash.already_patched = False
+                           return skip_step
+                       optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
+                       optimizer._amp_stash.already_patched = True

    # Probably ok to skip this if not delay_unscale
    if _amp_state.opt_properties.patch_torch_functions:
        _amp_state.handle._clear_cache()
@@ -151,6 +172,10 @@ class AmpHandle(object):
    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
+       raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, "
+           "documented here: https://nvidia.github.io/apex/amp.html. Transition guide: "
+           "https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users")
+
        if not self.is_active():
            yield loss
            return
......
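Continuing the initialization sketch above, each backward pass names the optimizers whose params receive gradients and, optionally, its own ``loss_id``; this mirrors the examples in the Advanced Amp Usage docs further down. The losses here are placeholders.

    x = torch.randn(8, 16).cuda()

    # loss0 only creates gradients for params owned by optim0 (the backbone).
    loss0 = backbone(x).sum()
    with amp.scale_loss(loss0, optim0, loss_id=0) as scaled_loss:
        scaled_loss.backward()

    # loss2 creates gradients for params owned by both optimizers,
    # so both optimizers must be passed to the context manager.
    loss2 = head(backbone(x)).pow(2).mean()
    with amp.scale_loss(loss2, [optim0, optim1], loss_id=2) as scaled_loss:
        scaled_loss.backward()

    optim0.step()
    optim1.step()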
@@ -3,7 +3,7 @@ from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import _amp_state, master_params, maybe_print
from itertools import product

- def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=False):
+ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
    # Exception handling for 18.04 compatibility
    if check_overflow:
        cpu_sum = float(model_grad.float().sum())
@@ -16,6 +16,21 @@ def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=F
        master_grad.mul_(scale)
    return False

+ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
+     # Exception handling for 18.04 compatibility
+     if check_overflow:
+         cpu_sum = float(model_grad.float().sum())
+         if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+             return True
+
+     # if master_grad is not model_grad: # copy_ probably internally short-circuits this
+     #     master_grad.copy_(model_grad)
+     assert stashed_grad.dtype == master_grad.dtype
+     converted_model_grad = model_grad.to(master_grad.dtype)
+     stashed_grad.add_(scale, converted_model_grad)
+     master_grad.data = stashed_grad.data
+     return False
+
class LossScaler(object):
    warned_no_fused_kernel = False
    warned_unscaling_non_fp32_grad = False
@@ -41,6 +56,7 @@ class LossScaler(object):
            import amp_C
            LossScaler.has_fused_kernel = multi_tensor_applier.available
            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
+           LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
        else:
            if not LossScaler.warned_no_fused_kernel:
                maybe_print(
@@ -55,108 +71,124 @@ class LossScaler(object):
    def loss_scale(self):
        return self._loss_scale

-   def unscale_grads_python(self, model_grads, master_grads, scale):
+   def unscale_python(self, model_grads, master_grads, scale):
        for model, master in zip(model_grads, master_grads):
            if model is not None:
                if not LossScaler.warned_unscaling_non_fp32_grad:
-                   if master.type() != "torch.cuda.FloatTensor":
+                   if master.dtype != torch.float32:
                        maybe_print(
                            "Attempting to unscale a grad with type {} ".format(master.type()) +
                            "Unscaling non-fp32 grads may indicate an error. "
                            "When using Amp, you don't need to call .half() on your model.")
                        LossScaler.warned_unscaling_non_fp32_grad = True
-               self._has_overflow = scale_check_overflow_python(
-                   model,
-                   1./scale,
+               self._has_overflow = scale_check_overflow_python(model,
                    master,
+                   1./scale,
                    self.dynamic)
                if self._has_overflow and self.dynamic:
                    break

-   def clear_overflow_state(self):
-       self._has_overflow = False
-       if self.has_fused_kernel:
-           self._overflow_buf.zero_()
-
-   def unscale(self, model_params, master_params, scale):
+   # unused_scale keeps some of the old API alive for hopefully a short time.
+   def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
        if self._has_overflow:
            return

-       # Lots of defensive list processing going on here.  Way more less efficient than
-       # consuming the iterator directly.  Need to examine Python overhead.
-       model_master_params = [(model, master) for model, master
-           in zip(model_params, master_params)] # some of these may be None
+       scale = self._loss_scale
+
+       if scale == 1.0 and models_are_masters and not self.dynamic:
+           return

        if LossScaler.has_fused_kernel:
-           # TODO:  Make these lists permanent attributes of self, so they don't need to be created
-           # or garbage collected.  Profiler shows that garbage collection overhead may be
-           # substantial (200-300 usec).
-           # This may be tricky because right now the lists need to be packed densely.
-           # Maybe this could be handled within the multi_tensor_apply wrapper
-           # (allow some Tensors to be None using at::optional).
-           src_dst_pairs = {torch.float16 : {torch.float16 : [[],[]], torch.float32 : [[],[]]},
-                            torch.float32 : {torch.float16 : [[],[]], torch.float32 : [[],[]]}}
-
-           for model, master in model_master_params:
-               # Sync the None-ness of model and master params
-               if model.grad is None and master.grad is not None:
-                   master.grad = None
-               if model.grad is not None and master.grad is None:
-                   master.grad = torch.empty_like(master)
-
-               if model.grad is not None:
-                   if model.grad is master.grad and scale == 1.0 and not self.dynamic:
-                       continue
-                   else:
-                       src_dst_pairs[model.dtype][master.dtype][0].append(model.grad.data)
-                       src_dst_pairs[model.dtype][master.dtype][1].append(master.grad.data)
-
-           assert len(src_dst_pairs[torch.float32][torch.float16][0]) == 0, "The loss scaler is "\
-               "being asked to unscale FP32 model gradients into FP16 master gradients.  This is "\
-               "almost certainly an error."
-
-           for src, dst in product((torch.float16, torch.float32),
-                                   (torch.float16, torch.float32)):
-               if len(src_dst_pairs[src][dst][0]) > 0:
-                   if not LossScaler.warned_unscaling_non_fp32_grad and dst is torch.float16:
-                       print("Warning:  unscaling grads that are not FP32. "
-                             "Unscaling non-fp32 grads may indicate an error. "
-                             "When using Amp, you don't need to call .half() on your model.")
-                       # Setting this to True unconditionally allows the possibility of an escape
-                       # if never-before-seen non-fp32 grads are created in some later iteration.
-                       LossScaler.warned_unscaling_non_fp32_grad = True
-                   multi_tensor_applier(
-                       LossScaler.multi_tensor_scale_cuda,
-                       self._overflow_buf,
-                       src_dst_pairs[src][dst],
-                       1./scale)
+           # if (not LossScaler.warned_unscaling_non_fp32_grad
+           #     and master_grads[0].dtype == torch.float16):
+           #     print("Warning:  unscaling grads that are not FP32. "
+           #           "Unscaling non-fp32 grads may indicate an error. "
+           #           "When using Amp, you don't need to call .half() on your model.")
+           # # Setting this to True unconditionally allows the possibility of an escape
+           # # if never-before-seen non-fp32 grads are created in some later iteration.
+           # LossScaler.warned_unscaling_non_fp32_grad = True
+           multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
+                                self._overflow_buf,
+                                [model_grads, master_grads],
+                                1./scale)
        else:
-           # Sync the None-ness of model and master params.
-           all_same = True
-           for model, master in model_master_params:
-               if model.grad is None and master.grad is not None:
-                   master.grad = None
-               if model.grad is not None and master.grad is None:
-                   master.grad = torch.empty_like(master)
-               if model.grad is not master.grad:
-                   all_same = False
-
-           if scale == 1.0 and all_same and not self.dynamic:
-               return
-
-           # TODO:  Make these lists permanent attributes of self, so they don't need to be created
-           # or garbage collected?
-           model_grads = [mmp[0].grad.data for mmp in model_master_params if mmp[0].grad is not None]
-           master_grads = [mmp[1].grad.data for mmp in model_master_params if mmp[1].grad is not None]
-
-           self.unscale_grads_python(model_grads, master_grads, scale)
-
-       # If the fused kernel is available, we only need one D2H memcopy and sync.
-       if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
-           self._has_overflow = self._overflow_buf.item()
+           self.unscale_python(model_grads, master_grads, scale)
+
+       # Defer to update_scale
+       # If the fused kernel is available, we only need one D2H memcopy and sync.
+       # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
+       #     self._has_overflow = self._overflow_buf.item()
+
+   def unscale_with_stashed_python(self,
+                                   model_grads,
+                                   stashed_master_grads,
+                                   master_grads,
+                                   scale):
+       for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
+           if model is None and stashed is None:
+               continue
+           else:
+               if not LossScaler.warned_unscaling_non_fp32_grad:
+                   if master.dtype != torch.float32:
+                       maybe_print(
+                           "Attempting to unscale a grad with type {} ".format(master.type()) +
+                           "Unscaling non-fp32 grads may indicate an error. "
+                           "When using Amp, you don't need to call .half() on your model.")
+                       LossScaler.warned_unscaling_non_fp32_grad = True
+               self._has_overflow = axpby_check_overflow_python(model,
+                                                                stashed,
+                                                                master,
+                                                                1./scale,
+                                                                self.dynamic)
+               if self._has_overflow and self.dynamic:
+                   break
+
+   def unscale_with_stashed(self,
+                            model_grads,
+                            stashed_master_grads,
+                            master_grads):
+       if self._has_overflow:
+           return
+
+       scale = self._loss_scale
+
+       if LossScaler.has_fused_kernel:
+           if (not LossScaler.warned_unscaling_non_fp32_grad
+               and master_grads[0].dtype == torch.float16):
+               print("Warning:  unscaling grads that are not FP32. "
+                     "Unscaling non-fp32 grads may indicate an error. "
+                     "When using Amp, you don't need to call .half() on your model.")
+               # Setting this to True unconditionally allows the possibility of an escape
+               # if never-before-seen non-fp32 grads are created in some later iteration.
+               LossScaler.warned_unscaling_non_fp32_grad = True
+           multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
+                                self._overflow_buf,
+                                [model_grads, stashed_master_grads, master_grads],
+                                1./scale,
+                                1.0,
+                                0) # check only arg 0, aka the incoming model grads, for infs
+       else:
+           self.unscale_with_stashed_python(model_grads,
+                                            stashed_master_grads,
+                                            master_grads,
+                                            scale)
+
+       # Defer to update_scale
+       # If the fused kernel is available, we only need one D2H memcopy and sync.
+       # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
+       #     self._has_overflow = self._overflow_buf.item()
+
+   def clear_overflow_state(self):
+       self._has_overflow = False
+       if self.has_fused_kernel:
+           self._overflow_buf.zero_()
+
+   # Separate so unscale() can be called more that once before updating.
+   def update_scale(self):
+       # If the fused kernel is available, we only need one D2H memcopy and sync.
+       if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
+           self._has_overflow = self._overflow_buf.item()

        if self._has_overflow and self.dynamic:
            should_skip = True
            self._loss_scale /= 2.
......
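As a plain-PyTorch illustration of what ``unscale_with_stashed`` computes per gradient: the new master gradient is ``(1/scale) * model_grad + 1.0 * stashed_master_grad``, and (with ``arg_to_check=0``) only the incoming model gradient is inspected for infs/NaNs. The fused ``multi_tensor_axpby`` kernel applies the same axpby across many tensors at once. The sketch below is a single-tensor reference under that assumption, not Apex code.

    import torch

    def unscale_with_stashed_reference(model_grad, stashed_master_grad, scale):
        # Mirrors the axpby above: new_master = (1/scale)*model_grad + 1.0*stashed_master_grad.
        # Only the incoming model grad is checked for overflow (arg_to_check == 0).
        if not torch.isfinite(model_grad.float()).all():
            return None, True   # overflow: the subsequent optimizer step would be skipped
        master = stashed_master_grad + model_grad.float() * (1.0 / scale)
        return master, False

    model_grad = (torch.randn(4) * 128.0).half()   # scaled fp16 gradient from backward()
    stashed = torch.randn(4)                       # fp32 gradient accumulated earlier
    master, overflow = unscale_with_stashed_reference(model_grad, stashed, 128.0)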
@@ -540,18 +540,37 @@ class FP16_Optimizer(object):
        if len(self.all_fp16_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp16_params])
+           # I'm ONLY writing this as an incremental way to make some tests pass until
+           # I can refactor the tests as well.
+           # FP16_Optimizer should not be used by anyone.
+           model_grads = []
+           master_grads = []
+           for model_param, master_param in zip(self.all_fp16_params,
+                                                self.all_fp32_from_fp16_params):
+               if model_param.grad is not None:
+                   model_grads.append(model_param.grad)
+                   if master_param.grad is None:
+                       master_param.grad = torch.empty_like(master_param)
+                   master_grads.append(master_param.grad)
            self.loss_scaler.unscale(
-               self.all_fp16_params,
-               self.all_fp32_from_fp16_params,
+               model_grads,
+               master_grads,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp16_params])
        if len(self.all_fp32_from_fp32_params) > 0:
+           model_grads = []
+           master_grads = []
+           for model_param, master_param in zip(self.all_fp32_from_fp32_params,
+                                                self.all_fp32_from_fp32_params):
+               if model_param.grad is not None:
+                   model_grads.append(model_param.grad)
+                   master_grads.append(master_param.grad)
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
            self.loss_scaler.unscale(
-               self.all_fp32_from_fp32_params,
-               self.all_fp32_from_fp32_params,
+               model_grads,
+               master_grads,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
......
@@ -11,7 +11,8 @@ void multi_tensor_axpby_cuda(
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
- float b);
+ float b,
+ int arg_to_check);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
......
@@ -21,7 +21,8 @@ struct AxpbyFunctor
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    float a,
-   float b)
+   float b,
+   int arg_to_check)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
@@ -68,11 +69,16 @@ struct AxpbyFunctor
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
-         if(isfinite(xs[ii]) && isfinite(ys[ii]))
-           out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
-         else
        {
          out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
+         bool finite = true;
+         if(arg_to_check == -1)
+           finite = (isfinite(xs[ii]) && isfinite(ys[ii]));
+         if(arg_to_check == 0)
+           finite = isfinite(xs[ii]);
+         if(arg_to_check == 1)
+           finite = isfinite(ys[ii]);
+         if(!finite)
            *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
        }
      }
@@ -85,7 +91,8 @@ void multi_tensor_axpby_cuda(
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
- float b)
+ float b,
+ int arg_to_check)
{
  using namespace at;
  // The output (downscaled) type is always float.
@@ -102,7 +109,8 @@ void multi_tensor_axpby_cuda(
      tensor_lists,
      AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
      a,
-     b); )))
+     b,
+     arg_to_check); )))

  AT_CUDA_CHECK(cudaGetLastError());
......
@@ -65,11 +65,9 @@ struct ScaleFunctor
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
-         if(isfinite(incoming_vals[ii]))
-           out[i] = static_cast<out_t>(incoming_vals[ii]*scale);
-         else
        {
          out[i] = static_cast<out_t>(incoming_vals[ii]*scale);
+         if(!isfinite(incoming_vals[ii]))
            *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
        }
      }
......
@@ -68,8 +68,11 @@ Forcing particular layers/functions to a desired type
I'm still working on a generalizable exposure for this that won't require user-side code divergence
across different ``opt-level``\ s.

- Multiple models/optimizers
- --------------------------
+ Multiple models/optimizers/losses
+ ---------------------------------
+
+ Initialization with multiple models/optimizers
+ **********************************************

``amp.initialize``'s optimizer argument may be a single optimizer or a list of optimizers,
as long as the output you accept has the same type.
@@ -77,35 +80,88 @@ Similarly, the ``model`` argument may be a single model or a list of models, as
output matches.  The following calls are all legal::

    model, optim = amp.initialize(model, optim,...)
-   model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
-   [model1, model2], optim = amp.initialize([model1, model2], optim,...)
-   [model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)
+   model, [optim0, optim1] = amp.initialize(model, [optim0, optim1],...)
+   [model0, model1], optim = amp.initialize([model0, model1], optim,...)
+   [model0, model1], [optim0, optim1] = amp.initialize([model0, model1], [optim0, optim1],...)

- Whenever you invoke a backward pass, the optimizer you should pass to ``amp.scaled_loss`` is whatever
- optimizer owns the parameters for which this particular backward pass is creating gradients.
-
- Multiple backward passes per iteration
- --------------------------------------
-
- If you want to accumulate gradients from multiple losses for the params owned by a given optimizer,
- you must invoke ``with amp.scale_loss(..., delay_unscale=True)`` for all backward passes except
- the last::
-
-     # delay_unscale=True for the first two losses
-     with amp.scale_loss(loss1, optimizer, delay_unscale=True) as scaled_loss:
-         scaled_loss.backward()
-     with amp.scale_loss(loss2, optimizer, delay_unscale=True) as scaled_loss:
-         scaled_loss.backward()
-     # Don't delay_unscale for the final loss
-     with amp.scale_loss(loss3, optimizer) as scaled_loss:
-         scaled_loss.backward()
-     optimizer.step()
+ Backward passes with multiple optimizers
+ ****************************************
+
+ Whenever you invoke a backward pass, the ``amp.scale_loss`` context manager must receive
+ **all the optimizers that own any params for which the current backward pass is creating gradients.**
+ This is true even if each optimizer owns only some, but not all, of the params that are about to
+ receive gradients.
+
+ If, for a given backward pass, there's only one optimizer whose params are about to receive gradients,
+ you may pass that optimizer directly to ``amp.scale_loss``.  Otherwise, you must pass the
+ list of optimizers whose params are about to receive gradients::
+
+     # loss0 accumulates gradients only into params owned by optim0:
+     with amp.scale_loss(loss0, optim0) as scaled_loss:
+         scaled_loss.backward()
+
+     # loss1 accumulates gradients only into params owned by optim1:
+     with amp.scale_loss(loss1, optim1) as scaled_loss:
+         scaled_loss.backward()
+
+     # loss2 accumulates gradients into some params owned by optim0
+     # and some params owned by optim1
+     with amp.scale_loss(loss2, [optim0, optim1]) as scaled_loss:
+         scaled_loss.backward()
+
+ Optionally have Amp use a different loss scaler per-loss
+ ********************************************************
+
+ By default, Amp maintains a single global loss scaler that will be used for all backward passes
+ (all invocations of ``with amp.scale_loss(...)``).  No additional arguments to ``amp.initialize``
+ or ``amp.scale_loss`` are required to use the global loss scaler.  The code snippets above with
+ multiple optimizers/backward passes use the single global loss scaler under the hood,
+ and they should "just work."
+
+ However, you can optionally tell Amp to maintain a loss scaler per-loss, which gives Amp increased
+ numerical flexibility.  This is accomplished by supplying the ``num_losses`` argument to
+ ``amp.initialize`` (which tells Amp how many backward passes you plan to invoke, and therefore
+ how many loss scalers Amp should create), then supplying the ``loss_id`` argument to each of your
+ backward passes (which tells Amp the loss scaler to use for this particular backward pass)::
+
+     model, [optim0, optim1] = amp.initialize(model, [optim0, optim1], ..., num_losses=3)
+
+     with amp.scale_loss(loss0, optim0, loss_id=0) as scaled_loss:
+         scaled_loss.backward()
+
+     with amp.scale_loss(loss1, optim1, loss_id=1) as scaled_loss:
+         scaled_loss.backward()
+
+     with amp.scale_loss(loss2, [optim0, optim1], loss_id=2) as scaled_loss:
+         scaled_loss.backward()
+
+ ``num_losses`` and ``loss_id``\ s should be specified purely based on the set of
+ losses/backward passes.  The use of multiple optimizers, or association of single or
+ multiple optimizers with each backward pass, is unrelated.

Gradient accumulation across iterations
---------------------------------------

- Pass ``delay_unscale=True`` to ``amp.scale_loss`` until you're ready to ``step()``::
+ The following should "just work," and properly accommodate multiple models/optimizers/losses, as well as
+ gradient clipping via the `instructions above`_::
+
+     if iter%iters_to_accumulate == 0:
+         # Every iters_to_accumulate iterations, unscale and step
+         with amp.scale_loss(loss, optimizer) as scaled_loss:
+             scaled_loss.backward()
+         # Gradient clipping if desired:
+         # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
+         optimizer.step()
+         optimizer.zero_grad()
+     else:
+         # Otherwise, accumulate gradients, don't unscale or step.
+         with amp.scale_loss(loss, optimizer) as scaled_loss:
+             scaled_loss.backward()
+
+ As a minor performance optimization, you can pass ``delay_unscale=True``
+ to ``amp.scale_loss`` until you're ready to ``step()``.  You should only attempt ``delay_unscale=True``
+ if you're sure you know what you're doing, because the interaction with gradient clipping and
+ multiple models/optimizers/losses can become tricky.::

    if iter%iters_to_accumulate == 0:
        # Every iters_to_accumulate iterations, unscale and step
@@ -114,10 +170,12 @@ Pass ``delay_unscale=True`` to ``amp.scale_loss`` until you're ready to ``step()
        optimizer.step()
        optimizer.zero_grad()
    else:
-       # Otherwise, just accumulate gradients, don't unscale or step.
+       # Otherwise, accumulate gradients, don't unscale or step.
        with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:
            scaled_loss.backward()

+ .. _`instructions above`:
+    https://nvidia.github.io/apex/advanced.html#gradient-clipping
+
Custom data batch types
-----------------------
......
@@ -15,7 +15,10 @@ is under construction.
If you already implemented Amp based on the instructions below, but it isn't behaving as expected,
please review `Advanced Amp Usage`_ to see if any topics match your use case.  If that doesn't help,
- file an issue.
+ `file an issue`_.
+
+ .. _`file an issue`:
+    https://github.com/NVIDIA/apex/issues

``opt_level``\ s and Properties
-------------------------------
@@ -109,9 +112,8 @@ Your incoming model should be FP32 already, so this is likely a no-op.
|
|

- ``O1``:  Conservative Mixed Precision
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ ``O1``:  Mixed Precision (recommended for typical use)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Patch all Torch functions and Tensor methods to cast their inputs according to a whitelist-blacklist
model.  Whitelist ops (for example, Tensor Core-friendly ops like GEMMs and convolutions) are performed
in FP16.  Blacklist ops that benefit from FP32 precision (for example, softmax)
@@ -126,8 +128,8 @@ are performed in FP32.  ``O1`` also uses dynamic loss scaling, unless overridden
|
|

- ``O2``:  Fast Mixed Precision
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ ``O2``:  "Almost FP16" Mixed Precision
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``O2`` casts the model weights to FP16,
patches the model's ``forward`` method to cast input
data to FP16, keeps batchnorms in FP32, maintains FP32 master weights,
......
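To ground the ``O1`` description above, here is a minimal end-to-end training-loop sketch. It assumes apex is installed and a CUDA device is available; the model and data are placeholders.

    import torch
    from apex import amp

    model = torch.nn.Linear(32, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # O1: patched Torch functions plus dynamic loss scaling; model weights stay FP32.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    data = torch.randn(64, 32).cuda()
    target = torch.randint(0, 2, (64,)).cuda()
    loss_fn = torch.nn.CrossEntropyLoss()

    for _ in range(10):
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()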
@@ -4,6 +4,7 @@ import functools as ft
import itertools as it

from apex import amp
+ from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
@@ -60,24 +61,27 @@ class PromoteModule(torch.nn.Module):
class TestCache(unittest.TestCase):
    def setUp(self):
-       self.handle = amp.init(enabled=True)
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
-       self.handle._deactivate()
+       pass

    def train_eval_train_test(self, module, t):
        model = module(t).cuda()
-       dummy_optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
+       optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
+
+       _amp_state.allow_incoming_model_not_fp32 = True
+       model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
+       _amp_state.allow_incoming_model_not_fp32 = False

        def training_step():
            for param in model.parameters():
                param.grad = None

            loss = model(self.x).sum()
-           self.handle._default_scaler._loss_scale = 4.0
-           with self.handle.scale_loss(loss, dummy_optimizer) as scaled_loss:
+           _amp_state.loss_scalers[0]._loss_scale = 4.0
+           with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
@@ -106,6 +110,8 @@ class TestCache(unittest.TestCase):
        # Simulates resuming training after eval
        training_step()

+       _amp_state.handle._deactivate()
+
    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
......
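The rewritten test above exercises the new unified frontend: ``amp.initialize`` replaces the old ``amp.init`` handle, and per-loss scalers live in ``_amp_state.loss_scalers``. A minimal sketch of the multi-loss path this PR adds is below; the ``num_losses`` and ``loss_id`` keyword names follow this PR's frontend and should be treated as assumptions if you are reading a different revision:
import torch
from apex import amp
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
# Request one loss scaler per loss so each can back off independently on overflow.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", num_losses=2)
x = torch.ones(2, 8, device='cuda')
losses = [model(x).sum(), (model(x) ** 2).mean()]
for i, loss in enumerate(losses):
    # loss_id selects the corresponding scaler in _amp_state.loss_scalers[i].
    with amp.scale_loss(loss, optimizer, loss_id=i) as scaled_loss:
        scaled_loss.backward()
optimizer.step()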
...@@ -54,7 +54,7 @@ class TestMultiTensorAxpby(unittest.TestCase): ...@@ -54,7 +54,7 @@ class TestMultiTensorAxpby(unittest.TestCase):
else: else:
out_list = [out.clone().to(out_type)*3.0 for out in y_list] out_list = [out.clone().to(out_type)*3.0 for out in y_list]
applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b) applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)
self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]), self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors, msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
......
...@@ -6,6 +6,7 @@ parser.add_argument('--opt-level', type=str) ...@@ -6,6 +6,7 @@ parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--fused-adam', action='store_true') parser.add_argument('--fused-adam', action='store_true')
parser.add_argument('--use_baseline', action='store_true')
args = parser.parse_args() args = parser.parse_args()
base_file = str(args.opt_level) + "_" +\ base_file = str(args.opt_level) + "_" +\
...@@ -15,16 +16,24 @@ base_file = str(args.opt_level) + "_" +\ ...@@ -15,16 +16,24 @@ base_file = str(args.opt_level) + "_" +\
file_e = "True_" + base_file file_e = "True_" + base_file
file_p = "False_" + base_file file_p = "False_" + base_file
if args.use_baseline:
file_b = "baselines/True_" + base_file
dict_e = torch.load(file_e) dict_e = torch.load(file_e)
dict_p = torch.load(file_p) dict_p = torch.load(file_p)
if args.use_baseline:
dict_b = torch.load(file_b)
torch.set_printoptions(precision=10) torch.set_printoptions(precision=10)
print(file_e) print(file_e)
print(file_p) print(file_p)
if args.use_baseline:
print(file_b)
for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): # ugly duplication here...
if not args.use_baseline:
for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)
loss_e = dict_e["Loss"][n] loss_e = dict_e["Loss"][n]
...@@ -36,3 +45,20 @@ for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): ...@@ -36,3 +45,20 @@ for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
loss_p, loss_p,
dict_e["Speed"][n], dict_e["Speed"][n],
dict_p["Speed"][n])) dict_p["Speed"][n]))
else:
for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)
loss_e = dict_e["Loss"][n]
loss_p = dict_p["Loss"][n]
loss_b = dict_b["Loss"][n]
assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p)
assert loss_e == loss_b, "Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b)
print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format(
i_e,
loss_b,
loss_e,
loss_p,
dict_b["Speed"][n],
dict_e["Speed"][n],
dict_p["Speed"][n]))
...@@ -7,12 +7,13 @@ print_banner() { ...@@ -7,12 +7,13 @@ print_banner() {
print_banner "Distributed status: $1" print_banner "Distributed status: $1"
echo $2 echo $2
if [ -n "$2" ] DATADIR=$2
if [ -n "$3" ]
then then
DATADIR="$2" USE_BASELINE=""
else else
# DATADIR="/home/mcarilli/Desktop/pt18data/apex/examples/imagenet/bare_metal_train_val/" USE_BASELINE="--use_baseline"
DATADIR="/opt/home/apex/examples/imagenet/"
fi fi
if [ "$1" == "single_gpu" ] if [ "$1" == "single_gpu" ]
...@@ -130,7 +131,7 @@ do ...@@ -130,7 +131,7 @@ do
fi fi
echo "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR" echo "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR"
set -x set -x
python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --use_baseline
set +x set +x
done done
done done
......
#!/bin/bash #!/bin/bash
DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
# DATADIR="/opt/home/apex/examples/imagenet/"
cp ../common/* . cp ../common/* .
bash run_test.sh single_gpu $1 bash run_test.sh single_gpu $1 $DATADIR yes
import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel
parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed:
# FOR DISTRIBUTED: Set the device according to local_rank.
torch.cuda.set_device(args.local_rank)
# FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
# environment variables, and requires that you use init_method=`env://`.
torch.distributed.init_process_group(backend='nccl',
init_method='env://')
torch.manual_seed(torch.distributed.get_rank())
torch.backends.cudnn.benchmark = True
N, D_in, D_out = 64, 1024, 16
# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
if args.distributed:
# FOR DISTRIBUTED: After amp.initialize, wrap the model with
# apex.parallel.DistributedDataParallel.
model = DistributedDataParallel(model)
# torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
# model = torch.nn.parallel.DistributedDataParallel(model,
# device_ids=[args.local_rank],
# output_device=args.local_rank)
loss_fn = torch.nn.MSELoss()
for t in range(500):
optimizer.zero_grad()
y_pred = model(x)
loss = loss_fn(y_pred, y)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
if args.local_rank == 0:
print("final loss = ", loss)
torch.save(list(model.parameters()), "rank{}model.pth".format(torch.distributed.get_rank()))
torch.save(list(amp.master_params(optimizer)), "rank{}master.pth".format(torch.distributed.get_rank()))
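The per-rank ``torch.save`` calls above dump each process's model params and FP32 master params so they can be compared offline. A hypothetical check script for two ranks might look like the following; the file names mirror the dumps above, everything else is illustrative and not part of the commit:
import torch
# Load the tensors saved by rank 0 and rank 1 after the run above.
model0, model1 = torch.load("rank0model.pth"), torch.load("rank1model.pth")
master0, master1 = torch.load("rank0master.pth"), torch.load("rank1master.pth")
for p0, p1 in zip(model0, model1):
    assert torch.allclose(p0.float(), p1.float()), "model params diverged across ranks"
for p0, p1 in zip(master0, master1):
    assert torch.allclose(p0, p1), "master params diverged across ranks"
print("model and master params match across ranks")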