"vscode:/vscode.git/clone" did not exist on "3783f3f4b51ac8f17b0a2e36f2ba06184c45ee47"
Commit e733e78c authored by Carl Case, committed by Michael Carilli

Initial support for automatic mixed precision

parent a3059288
from . import amp
from . import RNN
from . import reparameterization
from . import fp16_utils
......
# amp: Automatic Mixed Precision
amp is an experimental tool to enable mixed precision training in
PyTorch with _extreme_ simplicity and overall numerical safety. It
does so by employing a whitelist / blacklist model:
- Any function on the whitelist casts its input arguments to
fp16. These are functions like `torch.conv2d` that can take
advantage of TensorCore execution.
- Any function on the blacklist casts its input arguments to
fp32. These are functions like `torch.exp` or loss functions that
have trouble with the numerical properties of fp16.
- Any other function passes along its input types to its outputs. Care
is taken so that multi-argument functions or methods
(e.g. `torch.Tensor.__add__`) can handle mixed-type inputs: they
simply promote all inputs to the widest type of any input (see the
sketch after this list).
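For example, here is an illustrative sketch (not part of the original
docs) of what promotion looks like once amp is enabled; the tensor
constructors are just one convenient way to get fp16 and fp32 CUDA
tensors:
```python
import torch
from apex import amp

amp_handle = amp.enable()

a = torch.cuda.HalfTensor(8).fill_(1.0)   # fp16 input
b = torch.cuda.FloatTensor(8).fill_(2.0)  # fp32 input

# `__add__` is on the promote list, so amp casts `a` up to fp32
# and the result `c` is an fp32 tensor.
c = a + b
```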
The PyTorch hooks that enable the necessary casts are at the low-level
functional interface to PyTorch, so even custom layers will work with
amp, so long as they are built out of PyTorch functions and methods.
In particular, amp hooks into all of the following:
- Functions in the top-level `torch` namespace
- Functions in the `torch.nn.functional` namespace
- Methods on `Tensor` objects (GPU only, fp16 and fp32)
- Custom support for RNNs, even though they have no direct functional
interface:
- Recurrent cells: `torch.nn.{RNNCell, LSTMCell, GRUCell}`
- Recurrent layers: `torch.nn.{RNN, LSTM, GRU}`
In a few limited cases, amp needs help finding custom user-defined
functions that use low-level PyTorch features. In those cases, a
simple annotation is sufficient; this is described below.
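To make that concrete, here is a small illustrative sketch (the layer
name and shapes are invented): a custom module built only from PyTorch
functions picks up the casts automatically.
```python
import torch
import torch.nn.functional as F

class FancyProjection(torch.nn.Module):
    """Hypothetical custom layer: amp sees every op it uses."""
    def __init__(self, dim):
        super(FancyProjection, self).__init__()
        self.weight = torch.nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        # `F.linear` is whitelisted (cast to fp16 for TensorCore use);
        # `torch.exp` is blacklisted (cast to fp32 for numerical safety).
        return torch.exp(F.linear(x, self.weight))
```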
## Installation and Requirements
amp is developed on Python 3.6 and PyTorch 0.4. It takes care to be
backwards-compatible with PyTorch 0.3, but users are _highly_
encouraged to upgrade.
amp is installed during normal apex installation, so refer to the
top-level README for more on installation.
## Usage and Getting Started
In the normal case, using amp requires adding two lines of code (and
an import). The first enables amp, so that it can hook into all the
relevant PyTorch functions. The second tells it where backpropagation
occurs so that it can properly scale the loss and clear internal
per-iteration state.
#### 1. Enable amp
```python
from apex import amp
amp_handle = amp.enable()
```
`amp.enable()` takes two arguments, and the defaults are _highly_
recommended. The first, `enable_caching` (default=True), indicates
whether amp should cache fp16 casts of model parameters on a
per-iteration basis. This prevents things like RNN cells used inside a
loop from casting their weight matrices over and over. The second,
`verbose` (default=False), toggles whether to print out every cast that
occurs; this is mostly useful for debugging.
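For example, to keep caching enabled (the default) but log every cast
while debugging, you might call (illustrative only):
```python
from apex import amp

# Cache fp16 parameter casts, and print each Float->Half / Half->Float cast.
amp_handle = amp.enable(enable_caching=True, verbose=True)
```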
#### 2. Wrap backpropagation
Nearly all PyTorch training scripts have a loop that looks like:
```python
# ... do a bunch of stuff to compute a loss
loss.backward()
optimizer.step()
# ...finish the iteration
```
To use amp, you need only tell it where backprop occurs:
```python
# ... same as before
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ... same as before
```
This context manager allows amp to:
1. Use automatic loss scaling to make the best use of the fp16 dynamic range
2. Clear its cache of casted parameters before the next optimizer step

Note that it is _possible_ to use amp without step 2, in which case you
will not get automatic loss scaling, nor is it safe to
`enable_caching`. (Power user note: you can manually clear the cache
after each optimizer step with `amp_handle._clear_cache()`.)
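For completeness, a minimal sketch of that power-user path, assuming a
hypothetical `model`, `optimizer`, and data `loader`: without
`scale_loss` there is no loss scaling, and the cache must be cleared by
hand each iteration.
```python
amp_handle = amp.enable(enable_caching=True)

for inputs, targets in loader:        # hypothetical data loader
    optimizer.zero_grad()
    loss = model(inputs, targets)     # hypothetical model returning a loss
    loss.backward()                   # no automatic loss scaling in this path
    optimizer.step()
    amp_handle._clear_cache()         # required when caching is enabled
```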
## Annotating User Functions
Nearly all PyTorch user code needs nothing more than steps one and two
above to use amp. After all, custom layers are built out of simpler
PyTorch components, and amp already can see those.
However, any custom C++ or CUDA code is outside of amp's (default)
view of things. For example, suppose I implemented a new recurrent
cell called a "forgetful recurrent unit" that calls directly into a
CUDA backend:
```python
def fru(input, hidden, weight, bias):
    # ... call to CUDA code
```
amp exposes two functions to handle this case: `register_half` and
`register_float`. These add the given function to the whitelist or
blacklist, respectively. You can use them as a decorator:
```python
@amp.register_half
def fru(input, hidden, weight, bias):
    # ...
```
or as a library call:
```python
from apex import amp
amp.register_half(custom_module.fru)
amp.enable()
```
Note that the function must be registered before the call to
`amp.enable()`. The library call makes this simple. If the function is
annotated, then you must ensure its module is loaded before the call
to `amp.enable()`. Furthermore, this does not (yet) work with class
methods, only free functions.
from torch.utils.ffi import _wrap_function
from ._scale_lib import lib as _lib, ffi as _ffi
__all__ = []
def _import_symbols(locals):
    for symbol in dir(_lib):
        fn = getattr(_lib, symbol)
        if callable(fn):
            locals[symbol] = _wrap_function(fn, _ffi)
        else:
            locals[symbol] = fn
        __all__.append(symbol)
_import_symbols(locals())
from .amp import enable, register_half, register_float
VERSION = (0, 1, 0)
__version__ = '.'.join(map(str, VERSION))
from . import compat, utils, wrap
from .handle import AmpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
import inspect
import itertools
import torch
_USER_REGISTRY = set()
# Can be used as a @decorator directly on the fn
# or called w/ arg by user before `enable()`
def register_half(fn):
    mod = inspect.getmodule(fn)
    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_half))
    return fn

def register_float(fn):
    mod = inspect.getmodule(fn)
    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
    return fn

# Top-level function to insert _all_ the hooks.
def enable(enable_caching=True, verbose=False):
    handle = AmpHandle(enable_caching)

    # 0) Force-{fp16, fp32} for user-annotated functions
    for mod, fn, cast_fn in _USER_REGISTRY:
        try_caching = (cast_fn == utils.maybe_half)
        wrap.cached_cast(mod, fn, cast_fn, handle,
                         try_caching, verbose)
    _USER_REGISTRY.clear()

    # 1) Force-{fp16, fp32} on white- / black-list functions
    override_modules = [functional_overrides,
                        torch_overrides,
                        tensor_overrides]
    cast_table = [('FP16_FUNCS', utils.maybe_half),
                  ('FP32_FUNCS', utils.maybe_float)]
    for module, (list_name, cast_fn) in itertools.product(override_modules,
                                                          cast_table):
        for fn in getattr(module, list_name):
            try_caching = (cast_fn == utils.maybe_half)
            wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
                             try_caching, verbose)

    # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
    #      methods on FloatTensor, since they're distinct types.
    if compat.tensor_is_float_tensor():
        for fn in tensor_overrides.FP16_FUNCS:
            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
                             handle, try_caching=True, verbose=verbose)
        for fn in tensor_overrides.FP32_FUNCS:
            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
                             handle, try_caching=False, verbose=verbose)

    # 2) Enable type-promotion on multi-arg functions and methods.
    #    NB: special handling for sequence fns (e.g. `torch.cat`).
    promote_modules = [torch_overrides, tensor_overrides]
    promote_table = [('CASTS', wrap.promote),
                     ('SEQUENCE_CASTS', wrap.sequence_promote)]
    for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
                                                                  promote_table):
        for fn in getattr(promote_mod, list_name):
            promote_fn(promote_mod.MODULE, fn, verbose)
    # 2.5) Pre-0.4, add the promotion wrappers directly to the HalfTensor
    #      and FloatTensor types, since they're distinct.
    if compat.tensor_is_float_tensor():
        for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
                                                               torch.cuda.HalfTensor],
                                                              promote_table):
            for fn in getattr(tensor_overrides, list_name):
                promote_fn(cls, fn, verbose)
    # 3) For any in-place version of a blacklist function, error if any input is fp16.
    #    NB: this is overly conservative.
    for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
        wrap.err_if_any_half(torch_overrides.MODULE, fn)

    # 3.5) For any in-place blacklist method, error if called on fp16 tensor
    for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, verbose)
        if compat.tensor_is_float_tensor():
            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, verbose)

    # 4) For other in-place methods, match the type of self tensor
    for fn in utils.as_inplace(itertools.chain(
            tensor_overrides.FP16_FUNCS,
            tensor_overrides.CASTS)):
        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, verbose)
        if compat.tensor_is_float_tensor():
            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, verbose)
            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, verbose)

    # 5) Special handling to whitelist RNN cell backend impls.
    for fn in ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']:
        wrap.cached_cast(torch.nn.backends.thnn.backend, fn, utils.maybe_half,
                         handle, try_caching=True, verbose=verbose)

    # 5.5) Extra-special handling of RNN backend
    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)

    return handle
import torch
# True for post-0.4, when Variables/Tensors merged.
def variable_is_tensor():
    v = torch.autograd.Variable()
    return isinstance(v, torch.Tensor)

# False for post-0.4
def tensor_is_float_tensor():
    x = torch.Tensor()
    return type(x) == torch.FloatTensor

# Akin to `torch.is_tensor`, but returns True for Variable
# objects in pre-0.4.
def is_tensor_like(x):
    return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)

# Wraps `torch.is_floating_point` if present, otherwise checks
# the suffix of `x.type()`.
def is_floating_point(x):
    if hasattr(torch, 'is_floating_point'):
        return torch.is_floating_point(x)
    try:
        torch_type = x.type()
        return torch_type.endswith('FloatTensor') or \
            torch_type.endswith('HalfTensor') or \
            torch_type.endswith('DoubleTensor')
    except AttributeError:
        return False

def scalar_python_val(x):
    if hasattr(x, 'item'):
        return x.item()
    else:
        if isinstance(x, torch.autograd.Variable):
            return x.data[0]
        else:
            return x[0]
import contextlib
import logging
import warnings
import torch
from ._C import scale_lib
class AmpHandle(object):
    def __init__(self, enable_caching=True):
        self._enable_caching = enable_caching
        self._cache = dict()
        self._loss_scale = 2.**16
        self._max_loss_scale = 2.**24
        self._scale_seq_len = 2000
        self._unskipped = 0
        self._overflow_buf = torch.cuda.ByteTensor(1024,)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        # Warn if the user calls .backward() on the raw (unscaled) loss
        # while inside this block.
        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper

        yield loss * self._loss_scale
        loss.backward = loss_backward

        # Unscale gradients and check for overflow (inf / NaN).
        self._overflow_buf.zero_()
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    scale_lib.scale_check_overflow(p.grad.data,
                                                   1. / self._loss_scale,
                                                   self._overflow_buf)
        if self._overflow_buf.any():
            # On overflow, halve the loss scale and replace the next
            # optimizer.step() with a skip.
            self._loss_scale /= 2.
            optimizer_step = optimizer.step
            def skip_step():
                logging.info('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step
            self._unskipped = 0
        else:
            self._unskipped += 1

        # After `_scale_seq_len` consecutive clean steps, try doubling
        # the loss scale (capped at `_max_loss_scale`).
        if self._unskipped == self._scale_seq_len:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        self._clear_cache()

    def _clear_cache(self):
        self._cache.clear()

    @property
    def has_cache(self):
        return self._enable_caching

    @property
    def cache(self):
        return self._cache
# TODO: think about the following two. They do weird things.
# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
# - torch.nn.utils.weight_norm
# Notes:
# F.instance_norm uses batch_norm internally. Which correctly handles
# fp16 in/out with fp32 weights. So we shouldn't do anything for
# either of these.
# F.normalize calls `input.norm()` internally, so it's redundant, but
# kept here in case impl. changes.
# F.cosine_similarity is same: calls `x.norm()` internally.
import torch.nn.functional
MODULE = torch.nn.functional
FP16_FUNCS = [
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc', # Undocumented / maybe new?
'linear',
]
FP32_FUNCS = [
# Pointwise
'softplus',
'softmin',
'log_softmax',
'softmax',
# Normalization
'layer_norm',
'group_norm',
'local_response_norm',
'normalize',
'cosine_similarity',
# Loss functions
# TODO: which of these can be fp16?
'binary_cross_entropy',
'poisson_nll_loss',
'cosine_embedding_loss',
'cross_entropy',
'hinge_embedding_loss',
'kl_div',
'l1_loss',
'mse_loss',
'margin_ranking_loss',
'multilabel_margin_loss',
'multilabel_soft_margin_loss',
'multi_margin_loss',
'nll_loss',
'binary_cross_entropy_with_logits',
'smooth_l1_loss',
'soft_margin_loss',
'triplet_margin_loss'
]
from .. import compat
from . import torch_overrides
import importlib
import torch
if compat.variable_is_tensor():
    MODULE = torch.Tensor
else:
    MODULE = torch.autograd.Variable
FP16_FUNCS = [
'__matmul__',
]
FP32_FUNCS = [
'__ipow__',
'__pow__',
'__rpow__',
# Cast to fp32 before transfer to CPU
'cpu',
]
CASTS = [
'__add__',
'__div__',
'__eq__',
'__ge__',
'__gt__',
'__iadd__',
'__idiv__',
'__imul__',
'__isub__',
'__itruediv__',
'__le__',
'__lt__',
'__mul__',
'__ne__',
'__radd__',
'__rdiv__',
'__rmul__',
'__rsub__',
'__rtruediv__',
'__sub__',
'__truediv__',
]
# None of these, but here to make code cleaner.
SEQUENCE_CASTS = []
# We need to grab all the methods from torch_overrides and add them to
# the Tensor lists as well, as almost all methods are duplicated
# between `torch` and `torch.Tensor` (and check with `hasattr`,
# because a few random ones aren't defined on Tensor)
_self_mod = importlib.import_module(__name__)
for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
    lst = getattr(_self_mod, attrname)
    for fn in getattr(torch_overrides, attrname):
        if hasattr(MODULE, fn):
            lst.append(fn)
import torch
MODULE = torch
FP16_FUNCS = [
# Math
# TODO: why are these in top-level torch namespace?
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc',
# BLAS
'addmm',
'addmv',
'addr',
'matmul',
'mm',
'mv',
]
# TODO: ban in-place versions of these in fp16
FP32_FUNCS = [
# Pointwise
'acos',
'asin',
'cosh',
'erfinv',
'exp',
'expm1',
'log',
'log10',
'log2',
'reciprocal',
'rsqrt',
'sinh',
'tan',
# Other math
'pow',
# Reduction
'cumprod',
'cumsum',
'dist',
'mean',
'norm',
'prod',
'std',
'sum',
'var',
# Special reduction-like BLAS
'addbmm',
'baddbmm',
'bmm',
# Misc
'renorm'
]
# Multi-tensor fns that may need type promotion
CASTS = [
# Multi-tensor math
'addcdiv',
'addcmul',
'atan2',
'cross',
# Element-wise _or_ tensor-wise math
'add',
'div',
'mul',
# Comparison
'eq',
'equal',
'ge',
'gt',
'le',
'lt',
'ne'
]
# Will possibly need to promote *all* elements of `seq`
SEQUENCE_CASTS = [
'cat', # torch.cat(seq, dim=0, out=None)
'stack' # torch.stack(seq, dim=0, out=None)
]
#include <THC/THC.h>
#include "scale_kernel.h"
extern THCState *state;
void scale_check_overflow(THCudaTensor *grads,
                          float scale,
                          THCudaByteTensor *overflow_buf) {
  size_t num_elems = THCudaTensor_nElement(state, grads);
  float *d_grads = THCudaTensor_data(state, grads);

  size_t buf_elems = THCudaByteTensor_nElement(state, overflow_buf);
  uint8_t *d_overflow_buf = THCudaByteTensor_data(state, overflow_buf);

  scale_check_overflow_kernel(state, d_grads, num_elems, scale,
                              d_overflow_buf, buf_elems);
}
void scale_check_overflow(THCudaTensor *grads,
float scale,
THCudaByteTensor *overflow_buf);
#ifndef SCALE_KERNEL_H
#define SCALE_KERNEL_H
#include <THC/THC.h>
#ifdef __cplusplus
extern "C" {
#endif
void scale_check_overflow_kernel(THCState *state,
float *d_grads, size_t n, float scale,
uint8_t *d_buf, size_t buf_n);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // SCALE_KERNEL_H
from . import compat
import functools
import itertools
import torch
def is_fp_tensor(x):
    if is_nested(x):
        # Fast-fail version of all(is_fp_tensor)
        for y in x:
            if not is_fp_tensor(y):
                return False
        return True
    return compat.is_tensor_like(x) and compat.is_floating_point(x)

def is_nested(x):
    return isinstance(x, tuple) or isinstance(x, list)

def should_cache(x):
    if is_nested(x):
        # Fast-fail version of all(should_cache)
        for y in x:
            if not should_cache(y):
                return False
        return True
    return isinstance(x, torch.nn.parameter.Parameter) and \
        type_string(x) == 'FloatTensor'

def collect_fp_tensor_types(args, kwargs):
    def collect_types(x, types):
        if is_nested(x):
            for y in x:
                collect_types(y, types)
        else:
            types.add(type_string(x))

    all_args = itertools.chain(args, kwargs.values())
    types = set()
    for x in all_args:
        if is_fp_tensor(x):
            collect_types(x, types)
    return types

def type_string(x):
    return x.type().split('.')[-1]
def maybe_half(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_half(y) for y in x])

    if type_string(x) == 'HalfTensor':
        return x
    else:
        if verbose:
            print('Float->Half ({})'.format(name))
        return x.half()

def maybe_float(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_float(y) for y in x])

    if type_string(x) == 'FloatTensor':
        return x
    else:
        if verbose:
            print('Half->Float ({})'.format(name))
        return x.float()
# NB: returns casted `args`; mutates `kwargs` in-place
def casted_args(cast_fn, args, kwargs):
    new_args = []
    for x in args:
        if is_fp_tensor(x):
            new_args.append(cast_fn(x))
        else:
            new_args.append(x)
    for k in kwargs:
        val = kwargs[k]
        if is_fp_tensor(val):
            kwargs[k] = cast_fn(val)
    return new_args
def cached_cast(cast_fn, x, cache):
    if is_nested(x):
        return type(x)([cached_cast(cast_fn, y, cache) for y in x])
    if x in cache:
        return cache[x]

    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x
def verbosify(cast_fn, fn_name, verbose):
    if verbose:
        return functools.partial(cast_fn, name=fn_name, verbose=verbose)
    else:
        return cast_fn

def as_inplace(fns):
    for x in fns:
        yield x + '_'

def has_func(mod, fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        return fn in mod.function_classes
    else:
        return hasattr(mod, fn)

def get_func(mod, fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        return mod.function_classes[fn]
    else:
        return getattr(mod, fn)

def set_func(mod, fn, new_fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        mod.function_classes[fn] = new_fn
    else:
        setattr(mod, fn, new_fn)
# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
# so the fp16 weights need to be derived from the input weights
# to this forward call, not the flat buffer.
# - The ordering of weights in the flat buffer is...idiosyncratic.
# First problem is solved with combination of set_ (to set up
# correct storage) and copy_ (so the fp16 weight derives from the
# fp32 one in autograd).
# Second is solved by doing ptr arithmetic on the fp32 weights
# to derive the correct offset.
#
# TODO: maybe this should actually use
# `torch._cudnn_rnn_flatten_weight`? But then I need to call
# on first iter and cache the right offsets. Ugh.
def synthesize_flattened_rnn_weights(fp32_weights,
                                     fp16_flat_tensor,
                                     rnn_fn='',
                                     verbose=False):
    fp16_weights = []
    fp32_base_ptr = fp32_weights[0][0].data_ptr()
    for layer_weights in fp32_weights:
        fp16_layer_weights = []
        for w_fp32 in layer_weights:
            w_fp16 = w_fp32.new().half()
            offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
            w_fp16.set_(fp16_flat_tensor.storage(),
                        offset,
                        w_fp32.shape)
            w_fp16.copy_(w_fp32)
            if verbose:
                print('Float->Half ({})'.format(rnn_fn))
            fp16_layer_weights.append(w_fp16)
        fp16_weights.append(fp16_layer_weights)

    return fp16_weights
from . import compat
from . import utils
import functools
import torch
def cached_cast(mod, fn, cast_fn, handle,
                try_caching=False, verbose=False):
    if not utils.has_func(mod, fn):
        # Should happen only pre-0.4
        assert not compat.variable_is_tensor()
        return

    orig_fn = utils.get_func(mod, fn)
    cast_fn = utils.verbosify(cast_fn, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if try_caching and handle.has_cache:
            args = list(args)
            for i in range(len(args)):
                if utils.should_cache(args[i]):
                    args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
            for k in kwargs:
                if utils.should_cache(kwargs[k]):
                    kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
        new_args = utils.casted_args(cast_fn,
                                     args,
                                     kwargs)
        return orig_fn(*new_args, **kwargs)
    utils.set_func(mod, fn, wrapper)
def promote(mod, fn, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)
        if len(types) <= 1:
            return orig_fn(*args, **kwargs)
        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
            new_args = utils.casted_args(maybe_float,
                                         args,
                                         kwargs)
            return orig_fn(*new_args, **kwargs)
        else:
            raise NotImplementedError('Do not know how to handle ' +
                                      'these types to promote: {}'
                                      .format(types))
    utils.set_func(mod, fn, wrapper)

def sequence_promote(mod, fn, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(seq, *args, **kwargs):
        types = set([utils.type_string(x) for x in seq])
        if len(types) <= 1:
            return orig_fn(seq, *args, **kwargs)
        elif types == set(['HalfTensor', 'FloatTensor']):
            cast_seq = utils.casted_args(maybe_float,
                                         seq, {})
            return orig_fn(cast_seq, *args, **kwargs)
        else:
            # TODO: other mixed-type cases aren't due to autohalf.
            # Just pass through?
            return orig_fn(seq, *args, **kwargs)
    utils.set_func(mod, fn, wrapper)
def promote_match_arg0(mod, fn, verbose=False):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(arg0, *args, **kwargs):
        assert compat.is_tensor_like(arg0)
        if utils.type_string(arg0) == 'HalfTensor':
            cast_fn = utils.maybe_half
        elif utils.type_string(arg0) == 'FloatTensor':
            cast_fn = utils.maybe_float
        else:
            return orig_fn(arg0, *args, **kwargs)

        cast_fn = utils.verbosify(cast_fn, fn, verbose)
        new_args = utils.casted_args(cast_fn, args, kwargs)
        return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func(mod, fn, wrapper)

def err_if_any_half(mod, fn):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)
        if 'HalfTensor' in types:
            raise NotImplementedError('Cannot call in-place function ' +
                                      '{} with fp16 arguments.'.format(fn))
        else:
            return orig_fn(*args, **kwargs)
    utils.set_func(mod, fn, wrapper)

def err_if_arg0_half(mod, fn, verbose=False):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(arg0, *args, **kwargs):
        assert compat.is_tensor_like(arg0)
        if utils.type_string(arg0) == 'HalfTensor':
            raise NotImplementedError('Cannot call in-place method ' +
                                      '{} on fp16 Tensors.'.format(fn))
        else:
            cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
            new_args = utils.casted_args(cast_fn, args, kwargs)
            return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func(mod, fn, wrapper)
# Current RNN approach:
# - Wrap top-level `RNN` function in thnn backend
# - Will call into either CudnnRNN or AutogradRNN
# - Each of these are factory functions that return a per-iter
# `forward` function
# - We interpose on the factory function to:
# 1) Interpose on the actual forward function and put in casts
# 2) Insert an fp16 `flat_weight` if necessary
def rnn_cast(backend, fn, verbose=False):
    orig_rnn = utils.get_func(backend, fn)
    @functools.wraps(orig_rnn)
    def rnn_wrapper(*args, **kwargs):
        flat_weight = kwargs.get('flat_weight')
        if flat_weight is not None:
            # We replace `flat_weight` with an uninitialized fp16
            # Tensor. The "actual" weight tensors (provided in `forward`)
            # will then be set up as ptrs into the buffer and have the
            # corresponding fp32 values copied in.
            # We need to call `copy` on the "actual" weights so that the
            # autograd graph correctly backprops from the wgrads computed
            # inside cuDNN (on fp16 weights) into the fp32 weights.
            assert utils.type_string(flat_weight) == 'FloatTensor'
            if compat.tensor_is_float_tensor():
                # Pre-0.4. A little slower, since it zeros out memory.
                flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
            else:
                flat_weight_fp16 = torch.empty_like(flat_weight,
                                                    dtype=torch.float16)
            kwargs['flat_weight'] = flat_weight_fp16
        else:
            flat_weight_fp16 = None

        forward = orig_rnn(*args, **kwargs)
        @functools.wraps(forward)
        def fwd_wrapper(*fargs, **fkwargs):
            assert len(fargs) == 3 or len(fargs) == 4
            inputs, weights, hiddens = fargs[:3]
            assert utils.is_fp_tensor(inputs)
            assert isinstance(weights, list)
            cast_fn = utils.verbosify(utils.maybe_half,
                                      fn,
                                      verbose)

            new_args = []

            # 0) Inputs
            new_args.append(cast_fn(inputs))

            # 1) Weights
            if flat_weight_fp16 is not None:
                fp16_weights = utils.synthesize_flattened_rnn_weights(
                    weights, flat_weight_fp16, fn, verbose)
            else:
                fp16_weights = [[cast_fn(w) for w in layer]
                                for layer in weights]
            new_args.append(fp16_weights)
            # 2) Hiddens: either a tuple (for LSTM) or a single tensor
            if isinstance(hiddens, tuple):
                new_args.append(tuple(cast_fn(x) for x in hiddens))
            elif utils.is_fp_tensor(hiddens):
                new_args.append(cast_fn(hiddens))
            else:
                # Hiddens can, in principle, be `None` -- pass through
                new_args.append(hiddens)
            # 3) Batch sizes (0.4 or later only)
            if len(fargs) == 4:
                new_args.append(fargs[3])

            return forward(*new_args, **fkwargs)
        return fwd_wrapper
    utils.set_func(backend, fn, rnn_wrapper)
# This file contains the cffi-extension call to build the custom
# kernel used by amp.
# For mysterious reasons, it needs to live at the top-level directory.
# TODO: remove this when we move to cpp-extension.
import os
import torch
from torch.utils.ffi import create_extension
assert torch.cuda.is_available()
abs_path = os.path.dirname(os.path.realpath(__file__))
sources = ['apex/amp/src/scale_cuda.c']
headers = ['apex/amp/src/scale_cuda.h']
defines = [('WITH_CUDA', None)]
with_cuda = True
extra_objects = [os.path.join(abs_path, 'build/scale_kernel.o')]
# When running `python build_cffi.py` directly, set package=False. But
# if it's used with `cffi_modules` in setup.py, then set package=True.
package = (__name__ != '__main__')
extension = create_extension(
    'apex.amp._C.scale_lib',
    package=package,
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects
)

if __name__ == '__main__':
    extension.build()
#include "scale_kernel.h"
#include <assert.h>
#define BLOCK_SIZE 1024
#define MAX_BLOCKS 1024
#ifdef __cplusplus
extern "C" {
#endif
__global__
void scale_reduce_overflow(float *in, size_t n, float scale,
                           uint8_t *overflow_out) {
  __shared__ uint8_t cta_overflow[BLOCK_SIZE];
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  uint8_t my_overflow = 0;
  for (int i = tid * 4; i < n; i += stride * 4) {
    if (i < (n - 3)) {
      float4 f4 = ((float4*)in)[i / 4];
      if (isfinite(f4.x)) {
        f4.x *= scale;
      } else {
        my_overflow = 1;
      }
      if (isfinite(f4.y)) {
        f4.y *= scale;
      } else {
        my_overflow = 1;
      }
      if (isfinite(f4.z)) {
        f4.z *= scale;
      } else {
        my_overflow = 1;
      }
      if (isfinite(f4.w)) {
        f4.w *= scale;
      } else {
        my_overflow = 1;
      }
      ((float4*)in)[i / 4] = f4;
    } else {
      for (; i < n; ++i) {
        if (isfinite(in[i])) {
          in[i] *= scale;
        } else {
          my_overflow = 1;
        }
      }
    }
  }

  int tIdx = threadIdx.x;
  cta_overflow[tIdx] = my_overflow;
  __syncthreads();
  int participating = BLOCK_SIZE / 2;
  while (participating > 0) {
    if (tIdx < participating) {
      cta_overflow[tIdx] = max(cta_overflow[tIdx],
                               cta_overflow[tIdx + participating]);
    }
    participating /= 2;
    __syncthreads();
  }
  if (tIdx == 0) {
    overflow_out[blockIdx.x] = max(cta_overflow[0],
                                   overflow_out[blockIdx.x]);
  }
}

void scale_check_overflow_kernel(THCState *state,
                                 float *d_grads, size_t n, float scale,
                                 uint8_t *d_buf, size_t buf_n) {
  int num_blks = min((int(n) + BLOCK_SIZE - 1) / BLOCK_SIZE,
                     MAX_BLOCKS);
  assert(buf_n >= num_blks);
  cudaStream_t cur_stream = THCState_getCurrentStream(state);
  scale_reduce_overflow<<<num_blks, BLOCK_SIZE, 0, cur_stream>>>(
      d_grads, n, scale, d_buf);
}
#ifdef __cplusplus
} // extern "C"
#endif