Commit d6db91a4 authored by Michael Carilli

Updating latest amp changes to use new C++ backend

parents 89564b69 db6ae13a
apex.egg-info
dist
build
docs/build
*~
......@@ -41,7 +41,7 @@ top-level README for more on installation.
## Usage and Getting Started
In the common case, using amp requires adding two lines of code (and
an import). The first enables amp, so that it can hook into all the
relevant PyTorch functions. The second tells it where backpropagation
occurs so that it can properly scale the loss and clear internal
......@@ -50,20 +50,25 @@ per-iteration state.
#### 1. Enable amp
```python
from apex import amp
amp_handle = amp.init()
```
`amp.init()` takes three (optional) arguments. The most useful is
`enabled` (default=True), which simplifies command-line arguments. If
False, then everything amp does will be a zero-overhead pass-through
-- i.e., your code will run as-is.
For the other two options, the defaults are _highly_ recommended. The
first, `enable_caching` (default=True), indicates whether amp should
cache fp16 casts of model parameters on a per-iteration basis. This
prevents things like RNN cells used inside a loop from casting their
weight matrices over and over. The second, `verbose` (default=False)
toggles whether to print out every cast that occurs. Useful for
debugging, mostly.
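
For example, here is a minimal sketch of how `enabled` pairs with a
command-line flag (the `--fp16` and `--amp-verbose` arguments are
hypothetical, not part of amp):
```python
import argparse
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true')         # hypothetical flag
parser.add_argument('--amp-verbose', action='store_true')  # hypothetical flag
args = parser.parse_args()

# With enabled=False, everything amp does is a zero-overhead pass-through,
# so the same training script runs with or without mixed precision.
amp_handle = amp.init(enabled=args.fp16, verbose=args.amp_verbose)
```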
#### 2. Wrap backpropagation
Nearly all PyTorch training scripts have a loop that looks like:
```python
# ... do a bunch of stuff to compute a loss
......@@ -91,9 +96,86 @@ you will not get automatic loss scaling, nor is it safe to
`enable_caching`. (Power user note: you can manually clear the cache
after each optimizer step with `amp_handle._clear_cache()`.)
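
As a concrete (hedged) sketch of that power-user path, assuming `model`,
`loss_fn`, `inputs`, `targets`, and `optimizer` already exist:
```python
# No scale_loss wrapper here, so there is no automatic loss scaling;
# only do this if your model trains fine without it.
loss = loss_fn(model(inputs), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Power-user step: drop cached fp16 parameter casts so the next
# iteration re-casts the freshly updated fp32 weights.
amp_handle._clear_cache()
```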
## Multiple Optimizers or Backward Passes
Step (2) from the previous section works when you have one PyTorch
optimizer and a single `loss.backward()` for each iteration. Some
models are more complex with:
- Multiple optimizer objects (over different parameters)
- Multiple backward passes for each iteration, taking advantage of
PyTorch's gradient accumulation
To work with such models, amp requires you to explicitly wrap each
optimizer and indicate if it will have more than one backward pass
per-iteration.
#### Explicitly wrapping optimizers
If you have more than one optimizer, then you must explicitly wrap
each. (You can also do so with a single optimizer.) First, wrap the
optimizer after initializing amp:
```python
optimizer = # ... some optimizer
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)
```
Second, use `optimizer.scale_loss(...)` to indicate where backprop
occurs:
```python
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ...
```
In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
sugar for first wrapping the optimizer and then calling
`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
multi-optimizer case, you must wrap each optimizer individually.
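
For example, a sketch of the multi-optimizer pattern (the
generator/discriminator split, learning rates, and loss names below are
illustrative assumptions, not part of the amp API):
```python
import torch
from apex import amp

amp_handle = amp.init()

# Wrap each optimizer individually.
gen_optimizer = amp_handle.wrap_optimizer(
    torch.optim.Adam(generator.parameters(), lr=1e-4))
dis_optimizer = amp_handle.wrap_optimizer(
    torch.optim.Adam(discriminator.parameters(), lr=1e-4))

# Each wrapped optimizer scales and unscales its own loss.
with gen_optimizer.scale_loss(gen_loss) as scaled_loss:
    scaled_loss.backward()
gen_optimizer.step()

with dis_optimizer.scale_loss(dis_loss) as scaled_loss:
    scaled_loss.backward()
dis_optimizer.step()
```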
#### Handling multiple backward passes
PyTorch accumulates parameter gradients between calls to
`zero_grad()`, so it is possible to perform multiple backward passes
before making a parameter update:
```python
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
loss1.backward()
# ...
loss2 = ComputeLoss2(model)
loss2.backward()
# ...
optimizer.step() # has gradient contributions from both backward passes
```
The amp optimizer wrapper supports an additional argument `num_loss`
to work with code like this:
```python
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
# ...
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
with optimizer.scale_loss(loss1) as scaled_loss:
    scaled_loss.backward()
# ...
loss2 = ComputeLoss2(model)
with optimizer.scale_loss(loss2) as scaled_loss:
    scaled_loss.backward()
# ...
optimizer.step()
```
## Annotating User Functions
Nearly all PyTorch user code needs nothing more than the two steps
above to use amp. After all, custom layers are built out of simpler
PyTorch components, and amp already can see those.
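
For instance (an illustrative sketch, not taken from the amp docs), a custom
layer composed only of standard PyTorch ops needs no annotation, because amp
has already patched the functions it calls:
```python
import torch

class GatedLinear(torch.nn.Module):
    """Toy custom layer built only from ops amp already patches."""
    def __init__(self, dim):
        super(GatedLinear, self).__init__()
        self.lin = torch.nn.Linear(dim, dim)

    def forward(self, x):
        # nn.Linear (which calls F.linear) and torch.sigmoid are both
        # covered by amp's wrappers, so nothing extra is needed here.
        return self.lin(x) * torch.sigmoid(x)
```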
......@@ -103,27 +185,62 @@ cell called a "forgetful recurrent unit" that calls directly into a
CUDA backend:
```python
from backend import FRUBackend
def fru(input, hidden, weight, bias):
    # call to CUDA code
    FRUBackend(input, hidden, weight, bias)
```
In this case, it is possible to get a runtime type mismatch. For
example, you might have `input` in fp16, and `weight` in fp32, and amp
doesn't have the visibility to insert an appropriate cast.
amp exposes two ways to handle "invisible" backend code: function
annotations and explicit registration.
#### Function annotation
The first way to handle backend code is a set of function annotations:
- `@amp.half_function`
- `@amp.float_function`
- `@amp.promote_function`
These correspond to:
- Cast all arguments to fp16
- Cast all arguments to fp32
- If there are any type mismatches, cast everything to the widest type
In our example, we believe that the FRU unit is fp16-safe and will get
performance gains from casting its arguments to fp16, so we write:
```python
@amp.half_function
def fru(input, hidden, weight, bias):
    # ...
```
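
The other annotations work the same way. For instance, a hypothetical fused
distance kernel that may receive mixed fp16/fp32 inputs could be promoted
instead (this example is an assumption for illustration, not part of apex):
```python
@amp.promote_function
def fused_distance(x, y):
    # On a type mismatch, amp casts both arguments to fp32 (the widest
    # type) before this function runs.
    return (x - y).pow(2).sum(dim=-1)
```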
#### Explicit registration
The other way to handle backend code is with explicit function
registration:
- `amp.register_half_function(module, function_name)`
- `amp.register_float_function(module, function_name)`
- `amp.register_promote_function(module, function_name)`
When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.init()`.
For our FRU unit, we can register the backend function directly:
```python
import backend

amp.register_half_function(backend, 'FRUBackend')
amp.init()
```
from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function
from . import compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
import inspect
import functools
import itertools
import torch
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
def _decorator_helper(orig_fn, cast_fn, wrap_fn):
    def wrapper(*args, **kwargs):
        handle = _DECORATOR_HANDLE
        if handle is None or not handle.is_active():
            return orig_fn(*args, **kwargs)
        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
                                        handle.verbose)
        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
    return wrapper
# Decorator form
def half_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
    return _decorator_helper(fn, utils.maybe_half, wrap_fn)

def float_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)

def promote_function(fn):
    wrap_fn = functools.partial(wrap.make_promote_wrapper)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)

# Registry form
def register_half_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))

def register_float_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))

def register_promote_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_PROMOTE_REGISTRY.add((module, name))
# Top-level function to insert _all_ the hooks.
def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
    global _DECORATOR_HANDLE

    if not enabled:
        handle = NoOpHandle()
        _DECORATOR_HANDLE = handle
        return handle

    handle = AmpHandle(enable_caching, verbose)

    # 0) Force-{fp16, fp32} for user-annotated functions
    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
        try_caching = (cast_fn == utils.maybe_half)
        wrap.cached_cast(mod, fn, cast_fn, handle,
                         try_caching, verbose)
    _USER_CAST_REGISTRY.clear()

    # 0.5) Force-promote for user-annotated functions
    for mod, fn in _USER_PROMOTE_REGISTRY:
        wrap.promote(mod, fn, verbose)
    _USER_PROMOTE_REGISTRY.clear()

    # 1) Force-{fp16, fp32} on white- / black-list functions
    override_modules = [functional_overrides,
......@@ -101,4 +145,10 @@ def enable(enable_caching=True, verbose=False):
    # 5.5) Extra-special handling of RNN backend
    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)

    # 6) Place error+print message on banned functions
    if not allow_banned:
        for fn, err_msg in functional_overrides.BANNED_FUNCS:
            wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)

    _DECORATOR_HANDLE = handle
    return handle
......@@ -2,54 +2,54 @@ import contextlib
import logging
import warnings
import torch
from .opt import OptimWrapper
from .scaler import LossScaler
class AmpHandle(object):
    def __init__(self, enable_caching=True, verbose=False):
        self._enable_caching = enable_caching
        self._verbose = verbose
        self._cache = dict()
        self._default_scaler = LossScaler()
    def is_active(self):
        return True

    def wrap_optimizer(self, optimizer, num_loss=1):
        self._default_scaler = None
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        if not self.is_active():
            yield loss
            return

        if self._default_scaler is None:
            raise RuntimeError(
                'After calling `handle.wrap_optimizer()`, you must explicitly ' +
                'use `optimizer.scale_loss(loss)`.')

        # TODO: this code block is duplicated here and `opt.py`. Unify.
        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper
        loss_scale = self._default_scaler.loss_scale()
        yield loss * loss_scale

        loss.backward = loss_backward

        should_skip = self._default_scaler.unscale_and_update(
            optimizer.param_groups, loss_scale)
        if should_skip:
            optimizer_step = optimizer.step
            def skip_step():
                logging.info('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step

        self._clear_cache()
......@@ -63,3 +63,30 @@ class AmpHandle(object):
    @property
    def cache(self):
        return self._cache

    def remove_cache(self, param):
        if self.has_cache and param in self.cache:
            del self.cache[param]

    @property
    def verbose(self):
        return self._verbose

class NoOpHandle(object):
    def is_active(self):
        return False

    def wrap_optimizer(self, optimizer, num_loss=1):
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        yield loss

    @property
    def has_cache(self):
        return False

    @property
    def verbose(self):
        return False
......@@ -42,7 +42,6 @@ FP32_FUNCS = [
    # Loss functions
    # TODO: which of these can be fp16?
    'poisson_nll_loss',
    'cosine_embedding_loss',
    'cross_entropy',
......@@ -60,3 +59,15 @@ FP32_FUNCS = [
    'soft_margin_loss',
    'triplet_margin_loss'
]
BANNED_FUNCS = [
    ('binary_cross_entropy',
     ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
      "It requires that the output of the previous function be already a FloatTensor. \n\n"
      "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
      " torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
      "that is compatible with amp.\nAnother option is to add\n"
      " amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
      "If you _really_ know what you are doing, you can disable this warning by passing "
      "allow_banned=True to `amp.init()`."))
]
import contextlib
import logging
import warnings
from .scaler import LossScaler, iter_params
import numpy as np
class OptimWrapper(object):
    def __init__(self, optimizer, amp_handle, num_loss):
        self._optimizer = optimizer
        self._amp_handle = amp_handle
        self._num_loss = num_loss
        self._loss_idx = 0
        self._skip_next = [False] * num_loss
        self._loss_scaler = [LossScaler() for _ in range(num_loss)]
    @contextlib.contextmanager
    def scale_loss(self, loss):
        if not self._amp_handle.is_active():
            yield loss
            return

        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper
        # When there are multiple losses per-optimizer, we need
        # to save out the current grad accumulation, since we won't be
        # able to unscale this particular loss once the grads are
        # all mixed together.
        cached_grads = []
        if self._loss_idx > 0:
            for p in iter_params(self._optimizer.param_groups):
                if p.grad is not None:
                    cached_grads.append(p.grad.data.detach().clone())
                else:
                    cached_grads.append(None)
            self._optimizer.zero_grad()
        loss_scale = self._cur_loss_scaler().loss_scale()
        yield loss * loss_scale

        loss.backward = loss_backward

        self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
            self._optimizer.param_groups, loss_scale)
        self._loss_idx += 1

        if len(cached_grads) > 0:
            for p, cached_grad in zip(iter_params(self._optimizer.param_groups),
                                      cached_grads):
                if cached_grad is not None:
                    p.grad.data.add_(cached_grad)
            cached_grads = []
    def _cur_loss_scaler(self):
        assert 0 <= self._loss_idx < self._num_loss
        return self._loss_scaler[self._loss_idx]
    def step(self, closure=None):
        if not self._amp_handle.is_active():
            return self._optimizer.step(closure=closure)

        self._loss_idx = 0

        for group in self._optimizer.param_groups:
            for p in group['params']:
                self._amp_handle.remove_cache(p)

        if closure is not None:
            raise NotImplementedError(
                'The `closure` argument is unsupported by the amp ' +
                'optimizer wrapper.')
        if any(self._skip_next):
            logging.info('Gradient overflow, skipping update')
            self._skip_next = [False] * self._num_loss
        else:
            return self._optimizer.step(closure=closure)
    # Forward any attribute lookups
    def __getattr__(self, attr):
        return getattr(self._optimizer, attr)

    # Forward all torch.optim.Optimizer methods
    def __getstate__(self):
        return self._optimizer.__getstate__()

    def __setstate__(self):
        return self._optimizer.__setstate__()

    def __repr__(self):
        return self._optimizer.__repr__()

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict):
        return self._optimizer.load_state_dict(state_dict)

    def zero_grad(self):
        return self._optimizer.zero_grad()

    def add_param_group(self, param_group):
        return self._optimizer.add_param_group(param_group)
import torch
from apex_C import scale_check_overflow
class LossScaler(object):
    def __init__(self):
        self._loss_scale = 2.**16
        self._max_loss_scale = 2.**24
        self._scale_seq_len = 2000
        self._unskipped = 0
        self._overflow_buf = torch.cuda.ByteTensor(1024,)
    def loss_scale(self):
        return self._loss_scale

    def unscale_and_update(self, param_groups, scale):
        self._overflow_buf.zero_()
        for p in iter_params(param_groups):
            if p.grad is not None:
                scale_check_overflow(p.grad.data,
                                     1. / scale,
                                     self._overflow_buf)

        if self._overflow_buf.any():
            should_skip = True
            self._loss_scale /= 2.
            self._unskipped = 0
        else:
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        return should_skip
def iter_params(param_groups):
    for group in param_groups:
        for p in group['params']:
            yield p
......@@ -85,7 +85,15 @@ def cached_cast(cast_fn, x, cache):
    if is_nested(x):
        return type(x)([cached_cast(y) for y in x])
    if x in cache:
        cached_x = cache[x]
        # During eval, it's possible to end up caching casted weights
        # with requires_grad == False. This is then a problem when they
        # get reused on the next train iter. So we ensure that cached
        # weights have the same requires_grad flag as the most recent request.
        if x.requires_grad != cached_x.requires_grad:
            cached_x.requires_grad_(x.requires_grad)
        return cache[x]

    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x
......@@ -5,13 +5,8 @@ import functools
import torch
def make_cast_wrapper(orig_fn, cast_fn, handle,
                      try_caching=False):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if try_caching and handle.has_cache:
......@@ -26,18 +21,27 @@ def cached_cast(mod, fn, cast_fn, handle,
            args,
            kwargs)
        return orig_fn(*new_args, **kwargs)
    return wrapper
def cached_cast(mod, fn, cast_fn, handle,
                try_caching=False, verbose=False):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    cast_fn = utils.verbosify(cast_fn, fn, verbose)
    wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
    utils.set_func(mod, fn, wrapper)
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)

        if len(types) <= 1:
            return orig_fn(*args, **kwargs)
        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
            new_args = utils.casted_args(cast_fn,
                                         args,
                                         kwargs)
            return orig_fn(*new_args, **kwargs)
......@@ -45,8 +49,14 @@ def promote(mod, fn, verbose=False):
            raise NotImplementedError('Do not know how to handle ' +
                                      'these types to promote: {}'
                                      .format(types))
    return wrapper

def promote(mod, fn, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    wrapper = make_promote_wrapper(orig_fn, maybe_float)
    utils.set_func(mod, fn, wrapper)

def sequence_promote(mod, fn, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
......@@ -84,7 +94,7 @@ def promote_match_arg0(mod, fn, verbose=False):
        return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func(mod, fn, wrapper)

def err_if_any_half(mod, fn, custom_err_msg=None):
    if not utils.has_func(mod, fn):
        return
......@@ -93,8 +103,11 @@ def err_if_any_half(mod, fn):
    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)

        if 'HalfTensor' in types:
            if custom_err_msg:
                raise NotImplementedError(custom_err_msg)
            else:
                raise NotImplementedError('Cannot call in-place function ' +
                                          '{} with fp16 arguments.'.format(fn))
        else:
            return orig_fn(*args, **kwargs)
    utils.set_func(mod, fn, wrapper)