Commit 9cc74429 authored by Carl Case

Optimizer wrapper; loss scaling class; no-op handle; start multi-loss

parent ea93767d
apex.egg-info
dist
build
docs/build
\ No newline at end of file
docs/build
*~
\ No newline at end of file
from .amp import enable, register_half, register_float
from .amp import build, register_half, register_float, register_promote
from . import compat, utils, wrap
from .handle import AmpHandle
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
import inspect
@@ -7,30 +7,44 @@ import itertools
import torch
_USER_REGISTRY = set()
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
# Can be used as a @decorator directly on the fn
# or called w/ arg by user before `enable()`
# or called w/ arg by user before `build()`
def register_half(fn):
    mod = inspect.getmodule(fn)
    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_half))
    _USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_half))
    return fn

def register_float(fn):
    mod = inspect.getmodule(fn)
    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
    _USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
    return fn

def register_promote(fn):
    mod = inspect.getmodule(fn)
    _USER_PROMOTE_REGISTRY.add((mod, fn.__name__))
    return fn

# Top-level function to insert _all_ the hooks.
def enable(enable_caching=True, verbose=False):
def build(enabled=True, enable_caching=True, verbose=False):
    if not enabled:
        return NoOpHandle()

    handle = AmpHandle(enable_caching)

    # 0) Force-{fp16, fp32} for user-annotated functions
    for mod, fn, cast_fn in _USER_REGISTRY:
    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
        try_caching = (cast_fn == utils.maybe_half)
        wrap.cached_cast(mod, fn, cast_fn, handle,
                         try_caching, verbose)
    _USER_REGISTRY.clear()
    _USER_CAST_REGISTRY.clear()

    # 0.5) Force-promote for user-annotated functions
    for mod, fn in _USER_PROMOTE_REGISTRY:
        wrap.promote(mod, fn, verbose)
    _USER_PROMOTE_REGISTRY.clear()

    # 1) Force-{fp16, fp32} on white- / black-list functions
    override_modules = [functional_overrides,
......
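For reference, the registration decorators and the new build() entry point (which replaces enable() and hands back a NoOpHandle when disabled) would be used roughly as in the sketch below. This is a hedged usage sketch, not part of the commit; my_fp16_op and my_mixed_op are hypothetical user functions.

from apex import amp

@amp.register_half        # run this user function in fp16 (with cached casts)
def my_fp16_op(x):
    return x.matmul(x.t())

@amp.register_promote     # promote mixed fp16/fp32 inputs to the wider type
def my_mixed_op(x, y):
    return x + y

# Registration must happen before build(); build(enabled=False) returns a NoOpHandle.
handle = amp.build(enabled=True, enable_caching=True, verbose=False)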
@@ -2,54 +2,53 @@ import contextlib
import logging
import warnings
import torch
from ._C import scale_lib
from .opt import OptimWrapper
from .scaler import LossScaler
class AmpHandle(object):
    def __init__(self, enable_caching=True):
        self._enable_caching = enable_caching
        self._cache = dict()
        self._loss_scale = 2.**16
        self._max_loss_scale = 2.**24
        self._scale_seq_len = 2000
        self._unskipped = 0
        self._overflow_buf = torch.cuda.ByteTensor(1024,)
        self._default_scaler = LossScaler()

    def is_active(self):
        return True

    def wrap_optimizer(self, optimizer, num_loss=1):
        self._default_scaler = None
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        if not self.is_active():
            yield loss
            return

        if self._default_scaler is None:
            raise RuntimeError(
                'After calling `handle.wrap_optimizer()`, you must explicitly ' +
                'use `optimizer.scale_loss(loss)`.')

        # TODO: this code block is duplicated here and `opt.py`. Unify.
        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper

        yield loss * self._loss_scale
        loss_scale = self._default_scaler.loss_scale()
        yield loss * loss_scale

        loss.backward = loss_backward

        self._overflow_buf.zero_()
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    scale_lib.scale_check_overflow(p.grad.data,
                                                   1. / self._loss_scale,
                                                   self._overflow_buf)
        if self._overflow_buf.any():
            self._loss_scale /= 2.
        should_skip = self._default_scaler.unscale_and_update(
            optimizer.param_groups, loss_scale)
        if should_skip:
            optimizer_step = optimizer.step
            def skip_step():
                logging.info('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step
            self._unskipped = 0
        else:
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        self._clear_cache()
@@ -63,3 +62,22 @@ class AmpHandle(object):
    @property
    def cache(self):
        return self._cache

    def remove_cache(self, param):
        if self.has_cache and param in self.cache:
            del self.cache[param]


class NoOpHandle(object):
    def is_active(self):
        return False

    def wrap_optimizer(self, optimizer, num_loss=1):
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        yield loss

    @property
    def has_cache(self):
        return False
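Putting the handle together, a single-loss training step against the default scaler might look like the sketch below. The names model, optimizer, criterion, and loader are placeholders; this is a usage sketch, not part of the commit.

handle = amp.build()

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    # scale_loss yields loss * loss_scale; on exit it unscales the grads,
    # updates the scale, and patches optimizer.step() to skip on overflow.
    with handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()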
import contextlib
import warnings
from .scaler import LossScaler
import numpy as np
class OptimWrapper(object):
    def __init__(self, optimizer, amp_handle, num_loss):
        self._optimizer = optimizer
        self._amp_handle = amp_handle
        self._num_loss = num_loss
        self._loss_idx = 0
        self._skip_next = [False] * num_loss
        self._loss_scaler = [LossScaler() for _ in range(num_loss)]

    @contextlib.contextmanager
    def scale_loss(self, loss):
        if not self._amp_handle.is_active():
            yield loss
            return

        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper

        # if loss_idx > 0:
        #     save out current grads to buffers
        #     keep some group caches
        #     .detach().clone()
        #     zero grads
        loss_scale = self._cur_loss_scaler().loss_scale()
        print('Loss scale (log): {}'.format(np.log2(loss_scale)))
        yield loss * loss_scale
        loss.backward = loss_backward

        self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
            self._optimizer.param_groups, loss_scale)
        print('GOT SKIP NEXT: {}'.format(self._skip_next[self._loss_idx]))
        self._loss_idx += 1

        # if loss_idx > 0:
        #     += saved state into grads

    def _cur_loss_scaler(self):
        assert 0 <= self._loss_idx < self._num_loss
        return self._loss_scaler[self._loss_idx]

    def step(self, closure=None):
        if not self._amp_handle.is_active():
            return self._optimizer.step(closure=closure)

        self._loss_idx = 0

        for group in self._optimizer.param_groups:
            for p in group['params']:
                self._amp_handle.remove_cache(p)

        if closure is not None:
            raise NotImplementedError(
                'The `closure` argument is unsupported by the amp ' +
                'optimizer wrapper.')

        if any(self._skip_next):
            self._skip_next = [False] * self._num_loss
            print('SKIP')
        else:
            return self._optimizer.step(closure=closure)

    # Forward any attribute lookups
    def __getattr__(self, attr):
        return getattr(self._optimizer, attr)

    # Forward all torch.optim.Optimizer methods
    def __getstate__(self):
        return self._optimizer.__getstate__()

    def __setstate__(self):
        return self._optimizer.__setstate__()

    def __repr__(self):
        return self._optimizer.__repr__()

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict):
        return self._optimizer.load_state_dict(state_dict)

    def zero_grad(self):
        return self._optimizer.zero_grad()

    def add_param_group(self, param_group):
        return self._optimizer.add_param_group(param_group)
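The in-progress multi-loss path gives each loss its own LossScaler inside the wrapper; the intended usage would look roughly like the sketch below. loss_a, loss_b, compute_losses, and the surrounding objects are hypothetical placeholders, not part of the commit.

handle = amp.build()
optimizer = handle.wrap_optimizer(optimizer, num_loss=2)   # returns an OptimWrapper

for inputs, targets in loader:
    optimizer.zero_grad()
    loss_a, loss_b = compute_losses(model, inputs, targets)  # hypothetical helper
    with optimizer.scale_loss(loss_a) as scaled_a:
        scaled_a.backward(retain_graph=True)   # keep the graph if the losses share it
    with optimizer.scale_loss(loss_b) as scaled_b:
        scaled_b.backward()
    optimizer.step()   # skipped internally if any per-loss scaler saw an overflow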
import torch
from ._C import scale_lib
class LossScaler(object):
    def __init__(self):
        self._loss_scale = 2.**16
        self._max_loss_scale = 2.**24
        self._scale_seq_len = 2000
        self._unskipped = 0
        self._overflow_buf = torch.cuda.ByteTensor(1024,)

    def loss_scale(self):
        return self._loss_scale

    def unscale_and_update(self, param_groups, scale):
        self._overflow_buf.zero_()
        for group in param_groups:
            for p in group['params']:
                if p.grad is not None:
                    scale_lib.scale_check_overflow(p.grad.data,
                                                   1. / scale,
                                                   self._overflow_buf)
        if self._overflow_buf.any():
            should_skip = True
            self._loss_scale /= 2.
            self._unskipped = 0
        else:
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        return should_skip
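For intuition, the same dynamic-scaling policy can be emulated in plain PyTorch without the fused scale_lib.scale_check_overflow kernel. This toy re-implementation (ToyLossScaler, using torch.isfinite for the overflow check) is only illustrative and is not part of the commit.

import torch

class ToyLossScaler(object):
    def __init__(self, init_scale=2.**16, max_scale=2.**24, scale_seq_len=2000):
        self._loss_scale = init_scale
        self._max_loss_scale = max_scale
        self._scale_seq_len = scale_seq_len
        self._unskipped = 0

    def loss_scale(self):
        return self._loss_scale

    def unscale_and_update(self, param_groups, scale):
        overflow = False
        for group in param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.mul_(1. / scale)               # unscale in place
                    if not torch.isfinite(p.grad.data).all():  # inf/nan means overflow
                        overflow = True
        if overflow:
            self._loss_scale /= 2.     # back off after an overflow
            self._unskipped = 0
        else:
            self._unskipped += 1
        if self._unskipped == self._scale_seq_len:
            # grow again after scale_seq_len clean iterations, up to the cap
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0
        return overflow                # True means the optimizer step should be skipped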