Unverified commit 47144979, authored by mcarilli, committed by GitHub

Merge pull request #173 from NVIDIA/api_refactor

Unified mixed precision API + backend performance improvements
parents 1603407b 6644c6e6
# Introduction

This repository holds NVIDIA-maintained utilities to streamline
@@ -19,31 +14,20 @@ users as quickly as possible.
### amp: Automatic Mixed Precision

`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.

[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`.)

[API Documentation](https://nvidia.github.io/apex/amp.html)

[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)

[Moving to the new Amp API] (for users of the deprecated tools formerly called "Amp" and "FP16_Optimizer")
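For orientation, here is a minimal sketch of the 3-line integration (the tiny model and optimizer below are placeholders, not taken from the Apex examples):
```
import torch
from apex import amp

# Placeholder model/optimizer so the sketch is self-contained
model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Added line 1: let Amp patch the model/optimizer for the chosen opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()

# Added lines 2-3: scale the loss so FP16 gradients don't underflow
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```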
## 2. Distributed Training
@@ -57,69 +41,60 @@ optimized for NVIDIA's NCCL communication library.

[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed)

The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
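A sketch of the intended ordering (assuming `model` and `optimizer` are constructed earlier in the script, and the processes are launched with `torch.distributed.launch` or similar):
```
import torch
from apex import amp
from apex.parallel import DistributedDataParallel

torch.distributed.init_process_group(backend="nccl", init_method="env://")
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
# Wrap with DDP only after amp.initialize
model = DistributedDataParallel(model)
```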
### Synchronized Batch Normalization

`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
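A sketch of typical use (assuming the `convert_syncbn_model` helper in `apex.parallel`, which swaps `torch.nn` batchnorm layers for `SyncBatchNorm`):
```
import apex

# model is an ordinary torch.nn.Module containing BatchNorm layers
model = apex.parallel.convert_syncbn_model(model)
model = apex.parallel.DistributedDataParallel(model)
```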
# Requirements

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer. The CUDA and C++ extensions require pytorch 1.0 or newer.

We recommend the latest stable release, obtainable from
[https://pytorch.org/](https://pytorch.org/). We also test against the latest master branch, obtainable from [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch).

If you have any problems building, please file an issue.
It's often convenient to use Apex in Docker containers. Compatible options include:
* [NVIDIA Pytorch containers from NGC](https://ngc.nvidia.com/catalog/containers/nvidia%2Fpytorch), which come with Apex preinstalled. To use the latest Amp API, you may need to `pip uninstall apex` then reinstall Apex using the **Quick Start** commands below.
* [official Pytorch -devel Dockerfiles](https://hub.docker.com/r/pytorch/pytorch/tags), e.g. `docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7`, in which you can install Apex using the **Quick Start** commands.
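For example (the container tag below is illustrative; pick a current tag from the NGC registry):
```
$ docker pull nvcr.io/nvidia/pytorch:19.03-py3
$ docker run --runtime=nvidia -it --rm nvcr.io/nvidia/pytorch:19.03-py3
```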
# Quick Start

### Linux

For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
```

Apex also supports a Python-only build (required with Pytorch 0.4) via
```
$ pip install -v --no-cache-dir .
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.

`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
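To sanity-check which build you ended up with, one option (relying on the extension module names `amp_C` and `apex_C` that appear later in this diff) is:
```
$ python -c "import amp_C, apex_C" && echo "CUDA and C++ extensions are installed"
```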
### Windows support

Windows support is experimental, and Linux is recommended. `python setup.py install --cpp_ext --cuda_ext` may work if you were able to build Pytorch from source
on your system. `python setup.py install` (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment,
make sure to install Apex in that same environment.
from . import parallel
from . import amp
from . import fp16_utils

# For optimizers and normalization there is no Python fallback.
# Absence of cuda backend is a hard error.
...
# amp: Automatic Mixed Precision

## This README documents the deprecated (pre-unified) API.
## Documentation for the current unified API can be found [here](https://nvidia.github.io/apex/)

amp is an experimental tool to enable mixed precision training in
PyTorch with extreme simplicity and overall numerical safety. It
does so by employing a whitelist / blacklist model:
- Any function on the whitelist casts its input arguments to
fp16. These are functions like `torch.conv2d` that can take
...
from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize
from ._amp_state import master_params
# This is a "header object" that allows different amp modules to communicate.
# I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like.
# But apparently it's ok:
# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
class AmpState(object):
def __init__(self):
self.hard_override=False
# Attribute stash. Could also just stash things as global module attributes.
_amp_state = AmpState()
def warn_or_err(msg):
if _amp_state.hard_override:
print("Warning: " + msg)
else:
raise RuntimeError(msg)
# I'm not sure if allowing hard_override is a good idea.
# + " If you're sure you know what you're doing, supply " +
# "hard_override=True to amp.initialize.")
# def iter_params(param_groups):
# for group in param_groups:
# for p in group['params']:
# yield p
def master_params(optimizer):
"""
Generator expression that iterates over the params owned by ``optimizer``.
Args:
optimizer: An optimizer previously returned from ``amp.initialize``.
"""
for group in optimizer.param_groups:
for p in group['params']:
yield p
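# Illustrative sketch (not part of the diff): master_params exists mainly so utilities
# like gradient clipping can be pointed at whatever parameters the optimizer actually
# owns (FP32 master params under O2, the ordinary model params under O1).
def _example_clip_master_grads(optimizer, max_norm=1.0):
    import torch
    return torch.nn.utils.clip_grad_norm_(master_params(optimizer), max_norm)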
import torch
from torch._six import container_abcs, string_classes
import functools
from ._amp_state import _amp_state, warn_or_err
from .handle import disable_casts
from .scaler import LossScaler
from apex.fp16_utils import convert_network
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
from ..optimizers import FusedAdam
from ..parallel import DistributedDataParallel as apex_DDP
def to_type(dtype, t):
if not t.is_cuda:
# This should not be a hard error, since it may be legitimate.
print("Warning: An input tensor was not cuda. ")
if t.requires_grad:
# This should be a hard-ish error.
warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
"its gradients will not be properly allreduced by DDP.")
if t.is_floating_point():
return t.to(dtype)
return t
# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
def applier(value, fn):
if isinstance(value, torch.Tensor):
return fn(value)
elif isinstance(value, string_classes):
return value
elif isinstance(value, container_abcs.Mapping):
return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(applier(v, fn) for v in value)
else:
return value
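# Illustrative sketch (not part of the diff): applier recursively walks nested
# containers and applies fn to any Tensor it finds, leaving other values untouched.
# For example, casting every floating-point tensor in a batch structure to FP16:
def _example_cast_batch_to_fp16(batch):
    # batch may be an arbitrary nest of dicts/lists/tuples of Tensors and non-Tensors.
    return applier(batch, functools.partial(to_type, torch.float16))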
def check_models(models):
for model in models:
parallel_type = None
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
parallel_type = "torch.nn.parallel.DistributedDataParallel"
if isinstance(model, apex_DDP):
parallel_type = "apex.parallel.DistributedDataParallel"
if isinstance(model, torch.nn.parallel.DataParallel):
parallel_type = "torch.nn.parallel.DataParallel"
if parallel_type is not None:
raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
"Parallel wrappers should only be applied to the model(s) AFTER \n"
"the model(s) have been returned from amp.initialize.")
def check_params_fp32(models):
for model in models:
for name, param in model.named_parameters():
if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you do not need to call .half() on your model\n"
"before passing it, no matter what optimization level you choose.".format(
name, param.type()))
for name, buf in model.named_buffers():
if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you do not need to call .half() on your model\n"
"before passing it, no matter what optimization level you choose.".format(
name, buf.type()))
def check_optimizers(optimizers):
for optim in optimizers:
bad_optim_type = None
if isinstance(optim, FP16_Optimizer_general):
bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
if isinstance(optim, FP16_Optimizer_for_fused):
bad_optim_type = "apex.optimizers.FP16_Optimizer"
if bad_optim_type is not None:
raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
"The optimizer(s) passed to amp.initialize() must be bare \n"
"instances of either ordinary Pytorch optimizers, or Apex fused \n"
"optimizers (currently just FusedAdam, but FusedSGD will be added \n"
"soon). You should not manually wrap your optimizer in either \n"
"apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
"amp.initialize will take care of that for you (if necessary) based \n"
"on the specified opt_level (and optional overridden properties).")
def wrap_fused_adam(optimizer, properties):
msg = 'Currently, the usage of FusedAdam is restricted to '\
'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
assert properties.master_weights is True, msg
assert properties.cast_model_type is torch.float16, msg
assert (properties.keep_batchnorm_fp32 is False or
properties.keep_batchnorm_fp32 is None), msg
if properties.loss_scale == "dynamic":
return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
else:
return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
def _initialize(models, optimizers, properties):
from apex.parallel import DistributedDataParallel as apex_DDP
from .amp import init as amp_init
if isinstance(optimizers, torch.optim.Optimizer):
optimizers_was_list = False
optimizers = [optimizers]
elif isinstance(optimizers, list):
optimizers_was_list = True
else:
raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
if isinstance(models, torch.nn.Module):
models_was_list = False
models = [models]
elif isinstance(models, list):
models_was_list = True
else:
raise TypeError("models must be either a single model or a list of models.")
check_models(models)
check_params_fp32(models)
check_optimizers(optimizers)
# In the future, when FP16_Optimizer can be deprecated and master weights can
# become an attribute, remember to stash master weights before casting the model.
if properties.cast_model_type:
if properties.keep_batchnorm_fp32:
for model in models:
convert_network(model, properties.cast_model_type)
else:
for model in models:
model.to(properties.cast_model_type)
caster = functools.partial(to_type, properties.cast_model_type)
# Patch the forward method to cast incoming data to the correct type.
# I like writing things explicitly more than decorators.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
return old_fwd(*applier(args, caster),
**applier(kwargs, caster))
return new_fwd
model.forward = patch_forward(model.forward)
# State dict trick to recast any preexisting per-param state tensors
for optimizer in optimizers:
optimizer.load_state_dict(optimizer.state_dict())
if properties.master_weights:
for i, optimizer in enumerate(optimizers):
if isinstance(optimizer, FusedAdam):
optimizers[i] = wrap_fused_adam(optimizer, properties)
if properties.loss_scale == "dynamic":
optimizers[i] = FP16_Optimizer_general(optimizer,
dynamic_loss_scale=True,
verbose=False)
else:
optimizers[i] = FP16_Optimizer_general(optimizer,
static_loss_scale=properties.loss_scale,
verbose=False)
else:
for optimizer in optimizers:
optimizer.loss_scaler = LossScaler(properties.loss_scale)
if properties.patch_torch_functions:
# handle is unused here. It's accessible later through a global value anyway.
handle = amp_init(loss_scale=properties.loss_scale)
for optimizer in optimizers:
# Disable Amp casting for the optimizer step, because it should only be
# applied to FP32 master params anyway.
def patch_step(old_step):
def new_step(*args, **kwargs):
with disable_casts():
output = old_step(*args, **kwargs)
return output
return new_step
optimizer.step = patch_step(optimizer.step)
if optimizers_was_list:
if models_was_list:
return models, optimizers
else:
return models[0], optimizers
else:
if models_was_list:
return models, optimizers[0]
else:
return models[0], optimizers[0]
from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
from ._amp_state import _amp_state
from .frontend import *

import functools
import itertools

import torch

_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()


def _decorator_helper(orig_fn, cast_fn, wrap_fn):
    def wrapper(*args, **kwargs):
        handle = _DECORATOR_HANDLE
@@ -21,19 +25,23 @@ def _decorator_helper(orig_fn, cast_fn, wrap_fn):
        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
    return wrapper


# Decorator form
def half_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
    return _decorator_helper(fn, utils.maybe_half, wrap_fn)


def float_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)


def promote_function(fn):
    wrap_fn = functools.partial(wrap.make_promote_wrapper)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)


# Registry form
def register_half_function(module, name):
    if not hasattr(module, name):
@@ -41,20 +49,23 @@ def register_half_function(module, name):
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))


def register_float_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))


def register_promote_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_PROMOTE_REGISTRY.add((module, name))
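# Illustrative usage sketch (not part of the diff): putting user or third-party
# functions on the cast lists via the registry form, before amp.init()/amp.initialize()
# is called. `my_lib`, `fused_attention`, and `stable_softmax` are hypothetical names.
def _example_register_user_functions(my_lib):
    register_half_function(my_lib, "fused_attention")  # run in FP16 (Tensor Core friendly)
    register_float_function(my_lib, "stable_softmax")  # keep in FP32 for numerical safety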
# Top-level function to insert _all_ the hooks.
def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
    global _DECORATOR_HANDLE

    if not enabled:
@@ -62,7 +73,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
        _DECORATOR_HANDLE = handle
        return handle

    handle = AmpHandle(loss_scale, enable_caching, verbose)

    # 0) Force-{fp16, fp32} for user-annotated functions
    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
@@ -160,4 +171,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
        wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)

    _DECORATOR_HANDLE = handle

    _amp_state.handle = handle

    return handle
import torch
from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err
class Properties(object):
"""
This class has two purposes: to establish a set of default properties,
and to route setting of these attributes through __setattr__ so that (in theory)
they can be checked for consistency with other existing args.
"""
def __init__(self):
self.options = {
"enabled" : False,
"opt_level" : None,
"cast_model_type" : None,
"patch_torch_functions" : False,
"keep_batchnorm_fp32" : None,
"master_weights" : None,
"loss_scale" : 1.0,
# Reserved for future functionality
# "fused_optimizer" : False,
# "enable_ddp_interop" : False,
}
"""
This function allows updating several options at a time without routing through
__setattr__ checks, to avoid "you can't get there from here" scenarios.
Currently not intended to be exposed; users are expected to select an opt_level
and apply consistent modifications.
"""
def _update_options_dict(new_options):
for k, v in new_options.items():
if k in self.options:
self.options[k] = v
else:
raise ValueError("Tried to set unexpected option {}".format(k))
"""
The members of "options" are not direct attributes of self, so access attempts
will roll down to __getattr__. This borrows from the logic in torch.nn.Module.
"""
def __getattr__(self, name):
if "options" in self.__dict__:
options = self.__dict__["options"]
if name in options:
return options[name]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
def __setattr__(self, name, value):
if "options" in self.__dict__:
if name in self.options:
# print("setting {} {}".format(name, value))
if name == "cast_model_type":
if self.opt_level == "O1" and value is not None:
if value is not torch.float32:
warn_or_err("O1 inserts casts around Torch functions rather than "
"model weights, so with O1, the model weights themselves "
"should remain FP32. If you wish to cast the model to a "
"different type, use opt_level='O2' or 'O3'. " +
"cast_model_type was {}".format(value))
self.options[name] = value
elif name == "patch_torch_functions":
if self.opt_level != "O1" and value:
warn_or_err("Currently, patch_torch_functions=True should only be set by "
"selecting opt_level='O1'.")
self.options[name] = value
elif name == "keep_batchnorm_fp32":
if self.opt_level == "O1" and value is not None:
warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
"to run in FP32, so keep_batchnorm_fp32 should be None." +
"keep_batchnorm_fp32 was {}".format(keep_batchnorm_fp32))
if value == "False":
self.options[name] = False
elif value == "True":
self.options[name] = True
else:
assert (value is True or value is False or value is None),\
"keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
"or None, found keep_batchnorm_fp32={}".format(keep_batchnorm_fp32)
self.options[name] = value
elif name == "master_weights":
if self.opt_level == "O1" and value is not None:
warn_or_err("It doesn't make sense to use master_weights with O1. "
"With O1, your model weights themselves should be FP32.")
self.options[name] = value
elif name == "loss_scale":
if value == "dynamic":
self.options[name] = value
else:
self.options[name] = float(value)
else:
self.options[name] = value
else:
super(Properties, self).__setattr__(name, value)
""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
class O3:
brief = "O3: Pure FP16 training."
more = "Calls .half() on your model, converting the entire model to FP16.\n"\
"A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
"so you don't need to change your data pipeline.\n"\
"This mode is useful for establishing a performance ceiling.\n"\
"It's also possible training may 'just work' in this mode.\n"\
"If not, try other optimization levels."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O3"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.keep_batchnorm_fp32 = False
properties.master_weights = False
properties.loss_scale = 1.0
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O2:
brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
"to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
"The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
"your data pipeline.\n"\
"O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
"these master weights, then copy the master weights into the FP16 model weights.\n"\
"Master weights can also improve convergence and stability."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O2"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.keep_batchnorm_fp32 = True
properties.master_weights = True
properties.loss_scale = "dynamic"
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O1:
brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
more = "The type of your model's weights is not altered. However, internally,\n"\
"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
"while operations that might benefit from the additional stability of FP32 are patched\n"\
"to cast their inputs to fp32.\n"\
"O1 is the safest way to try mixed precision training, and is recommended when\n"\
"trying mixed precision training for the first time."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O1"
properties.cast_model_type = None
properties.patch_torch_functions = True
properties.keep_batchnorm_fp32 = None
properties.master_weights = None
properties.loss_scale = "dynamic"
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O0:
brief = "O0: Pure FP32 training.\n"
more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
"types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
"FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O0"
properties.cast_model_type = torch.float32
properties.patch_torch_functions = False
properties.keep_batchnorm_fp32 = None
properties.master_weights = False
properties.loss_scale = 1.0
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
opt_levels = {"O3": O3(),
"O2": O2(),
"O1": O1(),
"O0": O0()}
# allow user to directly pass Properties struct as well?
def initialize(
models,
optimizers,
enabled=True,
opt_level=None,
cast_model_type=None,
patch_torch_functions=None,
keep_batchnorm_fp32=None,
master_weights=None,
loss_scale=None
):
"""
Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
chosen ``opt_level`` and overridden properties, if any.
``amp.initialize`` must be called **after** you have finished constructing your model(s) and
optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
See `Distributed training`_ in the Imagenet example.
Any property keyword argument that is not ``None`` will be interpreted as a manual override.
To prevent having to rewrite anything else in your script, name the returned models/optimizers
to replace the passed models/optimizers, as in the Usage below.
Args:
models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
optimizers (torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast.
enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
should run as if Amp were not present.
opt_level(str, required): Pure or mixed precision optimization level. Accepted values are
"O0", "O1", "O2", and "O3", explained in detail above.
cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
above.
patch_torch_functions (bool, optional, default=None): Optional property override.
keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
passed as a string, must be the string "True" or "False".
master_weights (bool, optional, default=None): Optional property override.
loss_scale(float or str, default=None): Optional property override. If passed as a string,
must be a string representing a number, e.g., "128.0", or the string "dynamic".
Returns:
Model(s) and optimizer(s) modified according to the ``opt_level``.
If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
also be a list.
Usage::
model, optim = amp.initialize(model, optim,...)
model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
[model1, model2], optim = amp.initialize([model1, model2], optim,...)
[model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)
# This is not an exhaustive list of the cross product of options that are possible,
# just a set of examples.
model, optim = amp.initialize(model, optim, opt_level="O0")
model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O1") # uses loss_scale="dynamic" default
model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O2") # uses loss_scale="dynamic" default
model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")
model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default
model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")
The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.
.. _`Distributed training`:
https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training
.. _`Imagenet example`:
https://github.com/NVIDIA/apex/tree/master/examples/imagenet
"""
if not enabled:
_amp_state.opt_properties = Properties()
return models, optimizers
if opt_level not in opt_levels:
raise RuntimeError(
"Unexpected optimization level {}. ".format(opt_level) +
"Options are 'O0', 'O1', 'O2', 'O3'.")
else:
_amp_state.opt_properties = opt_levels[opt_level](Properties())
print("Selected optimization level {}".format(opt_levels[opt_level].brief))
print("Defaults for this optimization level are:")
print(_amp_state.opt_properties.options)
for k, v in _amp_state.opt_properties.options.items():
print("{:22} : {}".format(k, v))
print("Processing user overrides (additional kwargs that are not None)...")
# I chose to have the keyword arguments listed directly in the argument list, so I
# can't use kwargs.items() here.
if enabled is not None:
_amp_state.opt_properties.enabled = enabled
if opt_level is not None:
_amp_state.opt_properties.opt_level = opt_level
if cast_model_type is not None:
_amp_state.opt_properties.cast_model_type = cast_model_type
if patch_torch_functions is not None:
_amp_state.opt_properties.patch_torch_functions = patch_torch_functions
if keep_batchnorm_fp32 is not None:
_amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
if master_weights is not None:
_amp_state.opt_properties.master_weights = master_weights
if loss_scale is not None:
_amp_state.opt_properties.loss_scale = loss_scale
print("After processing overrides, optimization options are:")
for k, v in _amp_state.opt_properties.options.items():
print("{:22} : {}".format(k, v))
return _initialize(models, optimizers, _amp_state.opt_properties)
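# Illustrative end-to-end sketch (not part of the diff) pulling the pieces together:
# call amp.initialize() once after model/optimizer construction (and before any DDP
# wrapping), then wrap each backward pass in amp.scale_loss(). The toy model, data,
# and hyperparameters below are placeholders.
def _example_training_loop():
    import torch
    from apex import amp
    model = torch.nn.Linear(10, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    for _ in range(10):
        inp = torch.randn(4, 10).cuda()
        loss = model(inp).pow(2).mean()
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()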
# TODO: is this necessary/useful?
# def check_option_consistency(enabled=True,
# opt_level=None,
# cast_model_type=None,
# patch_torch_functions=None,
# keep_batchnorm_fp32=None,
# master_weights=None,
# loss_scale=None,
# enable_ddp_interop=None,
# hard_override=False):
# """
# Utility function that enables users to quickly check if the option combination they intend
# to use is permitted. ``check_option_consistency`` does not require models or optimizers
# to be constructed, and can be called at any point in the script. ``check_option_consistency``
# is totally self-contained; it does not set any amp global state or affect anything outside
# of itself.
# """
#
# if not enabled:
# return
#
# if opt_level not in opt_levels:
# raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
# else:
# opt_properties = opt_levels[opt_level](Properties())
# print("Selected optimization level {}", opt_levels[opt_level].brief)
# print("Defaults for this optimization level are:")
# for k, v in opt_properties.options:
# print("{:22} : {}".format(k, v))
#
# print("Processing user overrides (additional kwargs that are not None)...")
# for k, v in kwargs:
# if k not in _amp_state.opt_properties.options:
# raise RuntimeError("Unexpected kwarg {}".format(k))
# if v is not None:
# setattr(opt_properties, k, v)
#
# print("After processing overrides, optimization options are:")
# for k, v in opt_properties.options:
# print("{:22} : {}".format(k, v))
import contextlib
import logging
import warnings
import torch
from . import utils from . import utils
from .opt import OptimWrapper from .opt import OptimWrapper
from .scaler import LossScaler from .scaler import LossScaler
from ._amp_state import _amp_state, master_params
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@contextlib.contextmanager
def scale_loss(loss,
optimizer,
model=None,
delay_unscale=False):
"""
On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
and unscaled, so that ``optimizer.step()`` can be called.
.. note::
If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
any FP16 gradients are copied to FP32 master gradients before being unscaled.
``optimizer.step()`` will then apply the unscaled master gradients to the master params.
.. warning::
If Amp is using explicit FP32 master params, only the FP32 master gradients will be
unscaled. The direct ``.grad`` attributes of any FP16
model params will remain scaled after context manager exit.
This subtlety affects gradient clipping. See "Gradient clipping" under
"Advanced use cases" for best practices.
Args:
loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context
manager yields is simply ``loss.float()*loss_scale``, so in principle
``loss`` could have more than one element, as long as you call
``backward()`` on ``scaled_loss`` appropriately within the context manager body.
optimizer: Must be an optimizer returned from an earlier call to ``amp.initialize``.
model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
optimizations.
delay_unscale(bool, default=False): Don't unscale the gradients or perform model->master
gradient copies on context manager exit. "Advanced use cases" illustrates
situations where this is necessary.
.. warning:: If ``True``, ``optimizer.step()`` cannot be
called yet after this context manager exits; it must wait for another, later
``scale_loss`` invocation with ``delay_unscale`` left to False.
See "Advanced use cases" for examples.
"""
if not _amp_state.opt_properties.enabled:
yield loss
return
if optimizer.loss_scaler is None:
raise RuntimeError("optimizer passed to scale_loss does not have a loss_scaler.")
# this is what happens when i have to support tools from different sources under the same API...
# TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
if isinstance(optimizer, FP16_Optimizer_for_fused):
loss_scale = optimizer.cur_scale
else:
loss_scale = optimizer.loss_scaler.loss_scale()
if ((not _amp_state.opt_properties.master_weights)
and (not optimizer.loss_scaler.dynamic)
and loss_scale == 1.0):
yield loss.float()
# Needing to drop the cache here as well is an ugly gotcha.
# But for now I think it's necessary to short-circuit.
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()
return
yield (loss.float())*loss_scale
# this isn't pretty but it unifies things. Once I deprecate the old API entirely,
# I will have freedom to clean this up. Maybe instead of wrapping optimizers,
# I can simply construct a set of attributes (e.g. master params) and assign them
# directly to optimizer instances.
if not delay_unscale:
# The FP16_Optimizer for FusedAdam will take care of unscaling as part of
# its step() method.
if not isinstance(optimizer, FP16_Optimizer_for_fused):
if isinstance(optimizer, FP16_Optimizer_general):
optimizer.update_master_grads()
else:
optimizer.loss_scaler.clear_overflow_state()
optimizer.loss_scaler.unscale(
master_params(optimizer),
master_params(optimizer),
loss_scale)
# For future fused optimizers that enable sync-free dynamic loss scaling,
# should_skip will always be False.
should_skip = optimizer.loss_scaler.update_scale()
if should_skip:
optimizer_step = optimizer.step
def skip_step():
logger = logging.getLogger('apex.amp')
logger.warning("Gradient overflow. Skipping step, reducing " +
"loss scale to {}".format(optimizer.loss_scaler.loss_scale()))
optimizer.step = optimizer_step
optimizer.step = skip_step
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()
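# Illustrative sketch (not part of the diff): several backward passes feeding a single
# optimizer.step(). delay_unscale=True leaves gradients scaled (and defers any
# model->master copies) until the final scale_loss() of the iteration.
def _example_multiple_losses(model_a, model_b, optimizer, inp):
    loss1 = model_a(inp).mean()
    with scale_loss(loss1, optimizer, delay_unscale=True) as scaled_loss:
        scaled_loss.backward()
    loss2 = model_b(inp).mean()
    with scale_loss(loss2, optimizer) as scaled_loss:  # last one unscales
        scaled_loss.backward()
    optimizer.step()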
# Free function version of AmpHandle.disable_casts, another step on the
# path to removing the concept of "AmpHandle"
@contextlib.contextmanager
def disable_casts():
_amp_state.handle._is_active = False
yield
_amp_state.handle._is_active = True
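# Illustrative sketch (not part of the diff): temporarily run a region with Amp's
# automatic casting switched off, e.g. a routine that must see its inputs untouched.
# custom_fp32_op is a hypothetical user function.
def _example_run_without_casts(custom_fp32_op, x):
    with disable_casts():
        return custom_fp32_op(x.float())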
class AmpHandle(object):
    def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
        self._enable_caching = enable_caching
        self._verbose = verbose
        self._cache = dict()
        self._default_scaler = LossScaler(loss_scale)
        self._is_active = True
        self._all_wrappers = []
@@ -43,8 +162,12 @@ class AmpHandle(object):
        loss_scale = self._default_scaler.loss_scale()

        yield loss * loss_scale

        self._default_scaler.clear_overflow_state()
        self._default_scaler.unscale(
            master_params(optimizer),
            master_params(optimizer),
            loss_scale)
        should_skip = self._default_scaler.update_scale()
        if should_skip:
            optimizer_step = optimizer.step
            def skip_step():
@@ -106,5 +229,8 @@ class NoOpHandle(object):
    def verbose(self):
        return False

    def _clear_cache(self):
        pass

    def _deactivate(self):
        pass
@@ -2,7 +2,7 @@ import contextlib
import logging
import warnings

from .scaler import LossScaler, master_params

import numpy as np
@@ -27,7 +27,7 @@ class OptimWrapper(object):
        # all mixed together.
        cached_grads = []
        if self._loss_idx > 0:
            for p in master_params(self._optimizer):
                if p.grad is not None:
                    cached_grads.append(p.grad.data.detach().clone())
                else:
@@ -37,12 +37,16 @@ class OptimWrapper(object):
        loss_scale = self._cur_loss_scaler().loss_scale()

        yield loss * loss_scale

        self._cur_loss_scaler().clear_overflow_state()
        self._cur_loss_scaler().unscale(
            master_params(self._optimizer),
            master_params(self._optimizer),
            loss_scale)
        self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
        self._loss_idx += 1

        if len(cached_grads) > 0:
            for p, cached_grad in zip(master_params(self._optimizer),
                                      cached_grads):
                if cached_grad is not None:
                    p.grad.data.add_(cached_grad)
...
import torch
import logging
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import _amp_state, master_params
from itertools import product

# from apex_C import scale_check_overflow

def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=False):
    # Exception handling for 18.04 compatibility
    if check_overflow:
        cpu_sum = float(model_grad.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True

    if master_grad is not model_grad: # copy_ probably internally short-circuits this
        master_grad.copy_(model_grad)
    if scale != 1.0:
        master_grad.mul_(scale)
    return False
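# Illustrative sketch (not part of the diff): unscale one scaled FP16 "model" gradient
# into an FP32 "master" gradient buffer, checking for inf/NaN as dynamic scaling would.
def _example_unscale_single_grad(loss_scale=128.0):
    model_grad = torch.full((4,), 0.5 * loss_scale, device="cuda", dtype=torch.float16)
    master_grad = torch.empty(4, device="cuda", dtype=torch.float32)
    overflow = scale_check_overflow_python(model_grad, 1. / loss_scale, master_grad,
                                           check_overflow=True)
    # overflow is False here; master_grad now holds 0.5 in FP32.
    return overflow, master_grad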
class LossScaler(object):
    warned_no_fused_kernel = False
    warned_unscaling_non_fp32_grad = False
    has_fused_kernel = False

    def __init__(self,
                 loss_scale,
                 init_scale=2.**16,
                 scale_factor=2.,
                 scale_window=2000):
        if loss_scale == "dynamic":
            self.dynamic = True
            self._loss_scale = init_scale
        else:
            self.dynamic = False
            self._loss_scale = loss_scale
        self._max_loss_scale = 2.**24
        self._scale_seq_len = scale_window
        self._unskipped = 0
        self._has_overflow = False
        self._overflow_buf = torch.cuda.IntTensor([0])
        if multi_tensor_applier.available:
            import amp_C
            LossScaler.has_fused_kernel = multi_tensor_applier.available
            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
        else:
            if not LossScaler.warned_no_fused_kernel:
                print("Warning: multi_tensor_applier fused unscale kernel is unavailable, "
                      "possibly because apex was installed without --cuda_ext --cpp_ext. "
                      "Using Python fallback. Original ImportError was: ",
                      multi_tensor_applier.import_err)
            LossScaler.has_fused_kernel = False
            LossScaler.warned_no_fused_kernel = True

    def loss_scale(self):
        return self._loss_scale

    def unscale_grads_python(self, model_grads, master_grads, scale):
        for model, master in zip(model_grads, master_grads):
            if model is not None:
                if not LossScaler.warned_unscaling_non_fp32_grad:
                    if master.type() != "torch.cuda.FloatTensor":
                        logger = logging.getLogger("apex.amp")
                        logger.warning(
                            "Attempting to unscale a grad with type {} ".format(master.type()) +
                            "Unscaling non-fp32 grads may indicate an error. "
                            "When using Amp, you don't need to call .half() on your model.")
                        LossScaler.warned_unscaling_non_fp32_grad = True
                self._has_overflow = scale_check_overflow_python(
                    model,
                    1./scale,
                    master,
                    self.dynamic)
                if self._has_overflow and self.dynamic:
                    break

    def clear_overflow_state(self):
        self._has_overflow = False
        if self.has_fused_kernel:
            self._overflow_buf.zero_()

    def unscale(self, model_params, master_params, scale):
        if self._has_overflow:
            return

        # Lots of defensive list processing going on here. Much less efficient than
        # consuming the iterator directly. Need to examine Python overhead.
        model_master_params = [(model, master) for model, master
                               in zip(model_params, master_params)] # some of these may be None

        if LossScaler.has_fused_kernel:
            # TODO: Make these lists permanent attributes of self, so they don't need to be created
            # or garbage collected. Profiler shows that garbage collection overhead may be
            # substantial (200-300 usec).
            # This may be tricky because right now the lists need to be packed densely.
            # Maybe this could be handled within the multi_tensor_apply wrapper
            # (allow some Tensors to be None using at::optional).
            src_dst_pairs = {torch.float16 : {torch.float16 : [[],[]], torch.float32 : [[],[]]},
                             torch.float32 : {torch.float16 : [[],[]], torch.float32 : [[],[]]}}
            for model, master in model_master_params:
                # Sync the None-ness of model and master params
                if model.grad is None and master.grad is not None:
                    master.grad = None
                if model.grad is not None and master.grad is None:
                    master.grad = torch.empty_like(master)
                if model.grad is not None:
                    if model.grad is master.grad and scale == 1.0 and not self.dynamic:
                        continue
                    else:
                        src_dst_pairs[model.dtype][master.dtype][0].append(model.grad.data)
                        src_dst_pairs[model.dtype][master.dtype][1].append(master.grad.data)

            assert len(src_dst_pairs[torch.float32][torch.float16][0]) == 0, "The loss scaler is "\
                "being asked to unscale FP32 model gradients into FP16 master gradients. This is "\
                "almost certainly an error."

            for src, dst in product((torch.float16, torch.float32),
                                    (torch.float16, torch.float32)):
                if len(src_dst_pairs[src][dst][0]) > 0:
                    if not LossScaler.warned_unscaling_non_fp32_grad and dst is torch.float16:
                        print("Warning: unscaling grads that are not FP32. "
                              "Unscaling non-fp32 grads may indicate an error. "
                              "When using Amp, you don't need to call .half() on your model.")
                        # Setting this to True unconditionally allows the possibility of an escape
                        # if never-before-seen non-fp32 grads are created in some later iteration.
                        LossScaler.warned_unscaling_non_fp32_grad = True
                    multi_tensor_applier(
                        LossScaler.multi_tensor_scale_cuda,
                        self._overflow_buf,
                        src_dst_pairs[src][dst],
                        1./scale)
        else:
            # Sync the None-ness of model and master params.
            all_same = True
            for model, master in model_master_params:
                if model.grad is None and master.grad is not None:
                    master.grad = None
                if model.grad is not None and master.grad is None:
                    master.grad = torch.empty_like(master)
                if model.grad is not master.grad:
                    all_same = False

            if scale == 1.0 and all_same and not self.dynamic:
                return

            # TODO: Make these lists permanent attributes of self, so they don't need to be created
            # or garbage collected?
            model_grads = [mmp[0].grad.data for mmp in model_master_params if mmp[0].grad is not None]
            master_grads = [mmp[1].grad.data for mmp in model_master_params if mmp[1].grad is not None]

            self.unscale_grads_python(model_grads, master_grads, scale)

        # If the fused kernel is available, we only need one D2H memcopy and sync.
        if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
            self._has_overflow = self._overflow_buf.item()

    # Separate so unscale() can be called more than once before updating.
    def update_scale(self):
        if self._has_overflow and self.dynamic:
            should_skip = True
            self._loss_scale /= 2.
            self._unskipped = 0
@@ -80,13 +167,8 @@ class LossScaler(object):
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len and self.dynamic:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        return should_skip
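# Illustrative sketch (not part of the diff): the clear_overflow_state / unscale /
# update_scale sequence that scale_loss() drives, shown for parameters whose model and
# master copies are the same FP32 tensors (the patch_torch_functions case).
def _example_dynamic_scaler_step(optimizer, params, scaler):
    scaler.clear_overflow_state()
    scaler.unscale(params, params, scaler.loss_scale())
    if scaler.update_scale():
        # Overflow: skip this step; the loss scale has already been halved.
        return False
    optimizer.step()
    return True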
from . import compat from . import compat
from . import utils from . import utils
from ._amp_state import _amp_state
from . import rnn_compat from . import rnn_compat
import functools import functools
...@@ -38,10 +39,16 @@ def cached_cast(mod, fn, cast_fn, handle, ...@@ -38,10 +39,16 @@ def cached_cast(mod, fn, cast_fn, handle,
utils.set_func_save(handle, mod, fn, wrapper) utils.set_func_save(handle, mod, fn, wrapper)
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper` # `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
# is on the new API and I am free to get rid of handle, I can clean this up.
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if not _amp_state.handle.is_active():
            return orig_fn(*args, **kwargs)
        types = utils.collect_fp_tensor_types(args, kwargs)
        if len(types) <= 1:
            return orig_fn(*args, **kwargs)
        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
@@ -66,6 +73,9 @@ def sequence_promote(mod, fn, handle, verbose=False):
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(seq, *args, **kwargs):
        if not _amp_state.handle.is_active():
            return orig_fn(seq, *args, **kwargs)
        types = set([utils.type_string(x) for x in seq])
        if len(types) <= 1:
            return orig_fn(seq, *args, **kwargs)
@@ -87,6 +97,9 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
    @functools.wraps(orig_fn)
    def wrapper(arg0, *args, **kwargs):
        assert compat.is_tensor_like(arg0)
        if not _amp_state.handle.is_active():
            return orig_fn(arg0, *args, **kwargs)
        if utils.type_string(arg0) == 'HalfTensor':
            cast_fn = utils.maybe_half
        elif utils.type_string(arg0) == 'FloatTensor':
@@ -226,6 +239,9 @@ def new_rnn_cast(fn, handle, verbose=False):
        assert len(args) == 9
        assert len(kwargs) == 0
        if not _amp_state.handle.is_active():
            return orig_fn(*args, **kwargs)
        if isinstance(args[6], bool):
            params_idx = 2 # Not PackedSequence case
        else:
...
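The wrappers above all follow the same pattern: if amp is inactive, call the original function untouched; otherwise inspect the tensor arguments and cast as needed. A minimal standalone sketch of the type-promotion case, using plain `torch` instead of apex's internal `utils` helpers (the function and variable names here are illustrative, not apex's):

```python
import functools
import torch

def promote_mixed_args(orig_fn):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        # If a patched op receives a mix of half and float CUDA tensors,
        # promote the half tensors to float before calling the original op.
        types = {a.type() for a in args if torch.is_tensor(a)}
        if types == {'torch.cuda.HalfTensor', 'torch.cuda.FloatTensor'}:
            args = tuple(a.float() if torch.is_tensor(a) else a for a in args)
        return orig_fn(*args, **kwargs)
    return wrapper
```

Installing such a wrapper is then plain monkey-patching, e.g. `torch.cat = promote_mixed_args(torch.cat)`, which is roughly the idea behind `utils.set_func_save` above, with extra bookkeeping so the original function can be restored.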
@@ -4,7 +4,8 @@ from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from ..amp.scaler import LossScaler
from ..multi_tensor_apply import multi_tensor_applier
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm

# TODO: Update overflow check + downscale to use Carl's fused kernel.
@@ -39,7 +40,7 @@ class FP16_Optimizer(object):
        init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
        static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
        dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
        dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`LossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`LossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`LossScaler`'s defaults will be used.
        verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.

    ``init_optimizer`` is expected to have been constructed in the ordinary way.
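For readers coming from the docstring above, a typical end-to-end use of the existing `FP16_Optimizer` wrapper with dynamic loss scaling looks roughly like the following. The model, data, and hyperparameters are placeholders, and a CUDA device plus an apex install are assumed.

```python
import torch
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(1024, 1024).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

inputs = torch.randn(64, 1024, device='cuda', dtype=torch.float16)
targets = torch.randn(64, 1024, device='cuda', dtype=torch.float16)

for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs).float(), targets.float())
    optimizer.backward(loss)   # replaces loss.backward(); applies the loss scale
    optimizer.step()           # skipped internally if a gradient overflow was detected
```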
@@ -154,6 +155,18 @@ class FP16_Optimizer(object):
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        self.all_fp16_params = []
        for group in self.fp16_groups:
            self.all_fp16_params += group

        self.all_fp32_from_fp16_params = []
        for group in self.fp32_from_fp16_groups:
            self.all_fp32_from_fp16_params += group

        self.all_fp32_from_fp32_params = []
        for group in self.fp32_from_fp32_groups:
            self.all_fp32_from_fp32_params += group
        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # alternative way to cast per-param state tensors:
@@ -162,9 +175,9 @@ class FP16_Optimizer(object):
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            if dynamic_loss_args is not None:
                self.loss_scaler = LossScaler("dynamic", **dynamic_loss_args)
            else:
                self.loss_scaler = LossScaler("dynamic")
        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(static_loss_scale)
@@ -174,6 +187,12 @@ class FP16_Optimizer(object):
        self.clip_grad_norm = clip_grad_norm

        # TODO: Centralize exposure and import error checking for the C backend.
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_scale = amp_C.multi_tensor_scale
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

    def maybe_print(self, msg):
        if self.verbose:
            print(msg)
@@ -210,35 +229,44 @@ class FP16_Optimizer(object):
                        param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
                        param.grad.zero_()

    # Should not be used anymore.
    # def _check_overflow(self):
    #     params = []
    #     for group in self.fp16_groups:
    #         for param in group:
    #             params.append(param)
    #     for group in self.fp32_from_fp32_groups:
    #         for param in group:
    #             params.append(param)
    #     self.overflow = self.loss_scaler.has_overflow(params)

    # def _update_scale(self, has_overflow=False):
    #     self.loss_scaler.update_scale(has_overflow)
    def _master_params_to_model_params(self):
        if multi_tensor_applier.available:
            if len(self.all_fp16_params) > 0:
                multi_tensor_applier(
                    self.multi_tensor_scale,
                    self._dummy_overflow_buf,
                    [self.all_fp32_from_fp16_params, self.all_fp16_params],
                    1.0)
        else:
            for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
                master_params_to_model_params(fp16_group, fp32_from_fp16_group)
    # To consider: Integrate distributed with this wrapper by registering a hook on each variable
    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
    # def _model_grads_to_master_grads(self):
    #     for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
    #         model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

    # def _downscale_master(self):
    #     if self.loss_scale != 1.0:
    #         for group in self.optimizer.param_groups:
    #             for param in group['params']:
    #                 if param.grad is not None:
    #                     param.grad.data.mul_(1./self.loss_scale)
    def clip_master_grads(self, max_norm, norm_type=2):
        """
@@ -366,18 +394,23 @@ class FP16_Optimizer(object):
        http://pytorch.org/docs/master/optim.html#optimizer-step-closure
        """
        scale = self.loss_scaler.loss_scale()
        # To consider: Should this be in step(), or update_master_grads? It works either way,
        # but I should make it consistent with the Amp control flow, which updates the scale
        # during backward context manager exit.
        # self._update_scale(self.overflow)
        if self.overflow:
            print("Gradient overflow. Skipping step, reducing " +
                  "loss scale to {}".format(self.loss_scaler.loss_scale()))
            return

        if closure is not None:
            retval = self._step_with_closure(closure)
        else:
            # torch.cuda.nvtx.range_push("pytorch optimizer step")
            retval = self.optimizer.step()
            # torch.cuda.nvtx.range_pop()

        self._master_params_to_model_params()
@@ -409,10 +442,10 @@ class FP16_Optimizer(object):
                # closure() and return the loss.
                temp_loss = closure()
                while(self.overflow):
                    scale = self.loss_scaler.loss_scale()
                    # self._update_scale(self.overflow) # now done at the end of backward
                    print("OVERFLOW within closure! Skipping step, reducing loss scale to {}".format(
                        self.loss_scaler.loss_scale()))
                    temp_loss = closure()
            return temp_loss
@@ -480,22 +513,48 @@ class FP16_Optimizer(object):
        # a loss scale that works. After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
        # discarding the iteration, but probably wouldn't improve overall efficiency.
        scaled_loss = loss.float()*self.loss_scaler.loss_scale()
        scaled_loss.backward(retain_graph=retain_graph)

        if update_master_grads:
            self.update_master_grads()
    def update_master_grads(self):
        # torch.cuda.nvtx.range_push("update_master_grads")
        """
        Copy the ``.grad`` attribute from stored references to fp16 parameters to
        the ``.grad`` attribute of the fp32 master parameters that are directly
        updated by the optimizer. :attr:`update_master_grads` only needs to be called if
        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
        """
        # if self.dynamic_loss_scale:
        #     self._check_overflow()
        #     if self.overflow: return
        # self._model_grads_to_master_grads()
        # self._downscale_master()

        # Use the one-shot multi-tensor apply kernel
        self.loss_scaler.clear_overflow_state()
        if len(self.all_fp16_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp16_params])
            self.loss_scaler.unscale(
                self.all_fp16_params,
                self.all_fp32_from_fp16_params,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp16_params])
        if len(self.all_fp32_from_fp32_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
            self.loss_scaler.unscale(
                self.all_fp32_from_fp32_params,
                self.all_fp32_from_fp32_params,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
        # quit()
        self.overflow = self.loss_scaler.update_scale()
        # torch.cuda.nvtx.range_pop()
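Conceptually, the fused `unscale` calls above replace the old copy-then-downscale path. A rough per-parameter reference of what happens to each gradient (this is not the fused kernel path and not apex's exact `LossScaler.unscale` signature; overflow detection is omitted):

```python
import torch

def unscale_reference(model_params, master_params, loss_scale):
    # Copy each model (possibly fp16) grad into the fp32 master grad and
    # divide out the loss scale. The fused kernel additionally sets an
    # overflow flag when it encounters inf/nan values.
    for model_p, master_p in zip(model_params, master_params):
        if model_p.grad is None:
            continue
        if master_p.grad is None:
            master_p.grad = torch.empty_like(master_p)
        master_p.grad.copy_(model_p.grad)
        master_p.grad.mul_(1.0 / loss_scale)
```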
    def inspect_master_grad_data(self):
        """
@@ -533,10 +592,10 @@ class FP16_Optimizer(object):
    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale()

    def _set_loss_scale(self, value):
        self.loss_scaler._loss_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
...
from .multi_tensor_apply import MultiTensorApply
multi_tensor_applier = MultiTensorApply(2048*32)
import torch

class MultiTensorApply(object):
    available = False
    warned = False

    def __init__(self, chunk_size):
        try:
            import amp_C
            MultiTensorApply.available = True
            self.chunk_size = chunk_size
        except ImportError as err:
            MultiTensorApply.available = False
            MultiTensorApply.import_err = err

    def check_avail(self):
        if MultiTensorApply.available == False:
            raise RuntimeError(
                "Attempted to call MultiTensorApply method, but MultiTensorApply "
                "is not available, possibly because Apex was installed without "
                "--cpp_ext --cuda_ext. Original import error message:",
                MultiTensorApply.import_err)

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        self.check_avail()
        return op(self.chunk_size,
                  noop_flag_buffer,
                  tensor_lists,
                  *args)
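Once apex is built with `--cpp_ext --cuda_ext`, the applier above can drive the fused `amp_C.multi_tensor_scale` kernel directly. A small example of the calling convention used elsewhere in this PR; the tensor shapes and the scale value here are arbitrary, and a CUDA device is assumed.

```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])
src = [torch.randn(1024, device='cuda', dtype=torch.float32) for _ in range(4)]
dst = [torch.empty(1024, device='cuda', dtype=torch.float16) for _ in range(4)]

# Writes dst[i] = src[i] * 0.5 chunk by chunk; sets overflow_buf to 1 if any
# input value is non-finite (the "noop" flag checked by the kernel).
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 0.5)
```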
@@ -7,7 +7,6 @@ try:
    lib = ctypes.cdll.LoadLibrary(None)
    lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
    lib.THCudaHalfTensor_normall.restype = ctypes.c_float

    def fused_norm(input):
        if input.type() == 'torch.cuda.HalfTensor':
            # 16384 is half 2 if you stare at it long enough
@@ -27,7 +26,7 @@ except TypeError as err:
class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer.
    Designed only to wrap apex.optimizers.FusedAdam.
    Refer to the apex.fp16_utils documentation for more information.

    Example::
@@ -97,7 +96,7 @@ class FP16_Optimizer(object):
            if dynamic_loss_args is not None:
                raise SystemError("Do not support dynamic loss scale args for now.")
            self.dynamic_loss_scale = True
            self.cur_scale = 2**16
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2
@@ -180,7 +179,7 @@ class FP16_Optimizer(object):
    def backward(self, loss):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
...
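The two steps listed in the `backward` docstring above amount to scaling the fp32-cast loss before autograd runs, so that the resulting fp16 gradients are less likely to underflow. A minimal sketch (the default scale value is illustrative):

```python
def scaled_backward(loss, loss_scale=2.**16):
    fp32_loss = loss.float()              # 1. fp32_loss = loss.float()
    scaled_loss = fp32_loss * loss_scale  # 2. scaled_loss = fp32_loss*loss_scale
    scaled_loss.backward()                # gradients arrive pre-multiplied by loss_scale
```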
#include <torch/extension.h>

void multi_tensor_scale_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale);

void scale_check_overflow_cuda(
  const at::Tensor& grads,
  float scale,
  const at::Tensor& d_buf,
  const at::Tensor& downscaled_grads);

void scale_check_overflow(
  at::Tensor grads,
  float scale,
  at::Tensor overflow_buf,
  at::Tensor downscaled_grads)
@@ -27,4 +35,6 @@ void scale_check_overflow(at::Tensor grads,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale_check_overflow", &scale_check_overflow, "Fused overflow check + scale for FP32 tensors");
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
}
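The new binding is exposed through the `amp_C` extension module. As a rough sketch of how such a pybind11/CUDA extension is typically declared with PyTorch's build helpers; the package and source file names below are placeholders, not apex's actual setup.py.

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='amp_C_demo',
    ext_modules=[
        CUDAExtension(
            name='amp_C_demo',
            # Placeholder file names: a C++ frontend containing the
            # PYBIND11_MODULE block plus the .cu file that implements
            # multi_tensor_scale_cuda.
            sources=['frontend.cpp', 'multi_tensor_scale_kernel.cu'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)
```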
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <assert.h>
#include <cuda_runtime.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// TODO: Kernel arg size limit may be <4KB for some other cards (e.g. Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template<int n> struct TensorList
{
  void* addresses[n][depth_to_max_tensors[n-1]];
  int sizes[depth_to_max_tensors[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
};

template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
  int chunk_size,
  volatile int* noop_flag,
  T tl,
  U callable,
  ArgTypes... args)
{
  // Hand the chunk information to the user-supplied functor to process however it likes.
  callable(chunk_size, noop_flag, tl, args...);
}
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
  int block_size,
  int chunk_size,
  const at::Tensor& noop_flag,
  const std::vector<std::vector<at::Tensor>>& tensor_lists,
  T callable,
  ArgTypes... args)
{
  AT_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
  int len0 = tensor_lists[0].size();
  AT_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");

  for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
  {
    AT_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
    for(int t = 0; t < tensor_lists[l].size(); t++)
    {
      // TODO: Print which tensor fails.
      AT_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
      AT_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
      AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
    }
  }

  int ntensors = tensor_lists[0].size();

  TensorList<depth> tl;

  auto stream = at::cuda::getCurrentCUDAStream();

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for(int t = 0; t < ntensors; t++)
  {
    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
    for(int d = 0; d < depth; d++)
      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
    loc_tensor_info++;

    int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;

    for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
    {
      // std::cout << chunks_this_tensor << std::endl;
      tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tl.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
                           chunk == chunks_this_tensor - 1);
      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
      bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
      if(tensors_full || blocks_full || last_chunk)
      {
        // using accscalar_t = acc_type<scalar_t, true>;
        multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
          chunk_size,
          noop_flag.data<int>(),
          tl,
          callable,
          args...);
        AT_CUDA_CHECK(cudaGetLastError());

        // Reset. The control flow possibilities here make my brain hurt.
        loc_block_info = 0;
        if(chunk == chunks_this_tensor - 1)
        {
          // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          loc_tensor_info = 0;
        }
        else
        {
          // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          tl.sizes[0] = tl.sizes[loc_tensor_info-1];
          for(int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
          loc_tensor_info = 1;
        }
      }
    }
  }
}
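The bookkeeping above flattens every (tensor, chunk) pair into a flat list of CUDA blocks, launching whenever the per-launch tensor or block limits fill up. A small Python illustration of the chunk math with a hypothetical set of tensor sizes:

```python
# Hypothetical sizes; chunk_size matches the 2048*32 default used by the applier.
chunk_size = 2048 * 32
tensor_sizes = [100_000, 5_000, 70_000]

blocks = []
for t, numel in enumerate(tensor_sizes):
    chunks_this_tensor = (numel + chunk_size - 1) // chunk_size
    for chunk in range(chunks_this_tensor):
        blocks.append((t, chunk))   # each pair becomes one CUDA block's work item

print(blocks)   # [(0, 0), (0, 1), (1, 0), (2, 0), (2, 1)]
```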
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include <assert.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 512
#define ILP 4
template<typename in_t, typename out_t>
struct ScaleFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorList<2>& tl,
    float scale)
  {
    __shared__ int noop_smem;

    if(threadIdx.x == 0)
      noop_smem = *noop_gmem;
    __syncthreads();
    if(noop_smem == 1)
      return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
    in += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
    out += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_vals[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_vals[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
          incoming_vals[ii] = static_cast<float>(in[i]);
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling
      // the write loop, since writes just fire off once their LDGs arrive.
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
          if(isfinite(incoming_vals[ii]))
            out[i] = static_cast<out_t>(incoming_vals[ii]*scale);
          else
            *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      }

      // *noop_gmem = 1 is NOT guaranteed to be seen immediately by thread 0. I wonder if
      // we can rig block-wide and grid-wide short-circuiting with only one syncthreads.
      // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
      if(threadIdx.x == 0)
        noop_smem = *noop_gmem;
      __syncthreads();
      if(noop_smem == 1)
        break;
    }
  }
};
void multi_tensor_scale_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale)
{
  // The output (downscaled) type is always float.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
    "multi_tensor_scale_cuda",
    [&]
    {
      // using accscalar_t = acc_type<scalar_t, true>;
      switch(tensor_lists[1][0].type().scalarType())
      {
        case at::ScalarType::Half:
          multi_tensor_apply<2>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            ScaleFunctor<scalar_t, at::Half>(),
            scale);
          break;
        case at::ScalarType::Float:
          multi_tensor_apply<2>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            ScaleFunctor<scalar_t, float>(),
            scale);
          break;
        default:
          AT_ERROR("multi_tensor_scale_cuda not implemented for output type = ",
                   tensor_lists[1][0].type().toString());
      }
    });

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
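For intuition, here is what `ScaleFunctor` computes, written per tensor in plain PyTorch rather than per chunk in CUDA. This is purely illustrative; the kernel above fuses the same scale-and-check work across many tensors and chunks in a single launch, and the function name here is made up.

```python
import torch

def scale_with_overflow_check(src, dst, scale, noop_flag):
    # noop_flag plays the role of the kernel's noop/overflow flag (an int tensor).
    if noop_flag.item():
        return                      # an earlier tensor already overflowed; do nothing
    vals = src.float()
    if torch.isfinite(vals).all():
        dst.copy_(vals * scale)     # cast back to dst's dtype on copy
    else:
        noop_flag.fill_(1)          # signal overflow so later work can be skipped
```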
@@ -6,14 +6,14 @@
#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256
#define NBLOCKS 160*4
#define ILP 4

// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and their
// gradients should always be fp32).
template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
                                      float* out,
@@ -22,12 +22,12 @@ __global__ void scale_reduce_overflow(in_t* in,
                                      volatile int* overflow_global)
{
  __shared__ int overflow;
  float incoming_vals[4];

  // Non-divergent exit condition for the __syncthreads
  for(int chunk_start = blockIdx.x*blockDim.x*ILP;
      chunk_start < n;
      chunk_start += gridDim.x*blockDim.x*ILP)
  {
    if(threadIdx.x == 0)
      overflow = *overflow_global;
@@ -37,19 +37,27 @@ __global__ void scale_reduce_overflow(in_t* in,
    if(overflow == 1)
      break;

#pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      incoming_vals[ii] = 0;
      int i = chunk_start + threadIdx.x + ii*blockDim.x;
      if(i < n)
        incoming_vals[ii] = static_cast<float>(in[i]);
    }

#pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      int i = chunk_start + threadIdx.x + ii*blockDim.x;
      if(i < n)
        if(isfinite(incoming_vals[ii]))
          out[i] = incoming_vals[ii]*scale;
        else
          *overflow_global = 1; // Blindly fire off a write. These will race but that's ok.
    } // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
  }   // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
}     // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
void scale_check_overflow_cuda
...