Commit ec431d33 authored by Carl Case

WIP: better annotation / user function registry support

parent 614b11ff
apex/amp/__init__.py
-from .amp import build, register_half, register_float, register_promote
+from .amp import build, half_function, float_function, promote_function,\
+    register_half_function, register_float_function, register_promote_function
apex/amp/amp.py
@@ -2,37 +2,67 @@ from . import compat, utils, wrap
 from .handle import AmpHandle, NoOpHandle
 from .lists import functional_overrides, torch_overrides, tensor_overrides
 
-import inspect
+import functools
+import itertools
 import torch
 
+_DECORATOR_HANDLE = None
 _USER_CAST_REGISTRY = set()
 _USER_PROMOTE_REGISTRY = set()
 
-# Can be used as a @decorator directly on the fn
-# or called w/ arg by user before `build()`
-def register_half(fn):
-    mod = inspect.getmodule(fn)
-    _USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_half))
-    return fn
-
-def register_float(fn):
-    mod = inspect.getmodule(fn)
-    _USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
-    return fn
-
-def register_promote(fn):
-    mod = inspect.getmodule(fn)
-    _USER_PROMOTE_REGISTRY.add((mod, fn.__name__))
-    return fn
+def _decorator_helper(orig_fn, cast_fn, wrap_fn):
+    def wrapper(*args, **kwargs):
+        handle = _DECORATOR_HANDLE
+        if handle is None or not handle.is_active():
+            return orig_fn(*args, **kwargs)
+        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
+                                        handle.verbose)
+        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
+    return wrapper
+
+# Decorator form
+def half_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
+    return _decorator_helper(fn, utils.maybe_half, wrap_fn)
+
+def float_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
+    return _decorator_helper(fn, utils.maybe_float, wrap_fn)
+
+def promote_function(fn):
+    wrap_fn = functools.partial(wrap.make_promote_wrapper)
+    return _decorator_helper(fn, utils.maybe_float, wrap_fn)
+
+# Registry form
+def register_half_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
+
+def register_float_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
+
+def register_promote_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_PROMOTE_REGISTRY.add((module, name))
 
 # Top-level function to insert _all_ the hooks.
 def build(enabled=True, enable_caching=True, verbose=False):
+    global _DECORATOR_HANDLE
+
     if not enabled:
-        return NoOpHandle()
+        handle = NoOpHandle()
+        _DECORATOR_HANDLE = handle
+        return handle
 
-    handle = AmpHandle(enable_caching)
+    handle = AmpHandle(enable_caching, verbose)
 
     # 0) Force-{fp16, fp32} for user-annotated functions
     for mod, fn, cast_fn in _USER_CAST_REGISTRY:
@@ -115,4 +145,5 @@ def build(enabled=True, enable_caching=True, verbose=False):
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
 
+    _DECORATOR_HANDLE = handle
     return handle
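
For orientation, a minimal usage sketch of the two annotation styles this commit wires up (the module `my_kernels` and both function bodies are hypothetical; `amp` is this package, re-exported through the `__init__.py` change above):

import torch
from apex import amp

# Decorator form: annotate a function you own. The wrapper stays inert
# until build() publishes a live handle through _DECORATOR_HANDLE.
@amp.half_function
def fused_attention(q, k, v):  # hypothetical user kernel
    return torch.softmax(q @ k.t(), dim=-1) @ v

# Registry form: annotate a function you cannot edit, by (module, name).
# build() monkey-patches it in step 0.
import my_kernels  # hypothetical third-party module
amp.register_float_function(my_kernels, 'stable_norm')

handle = amp.build(enable_caching=True)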
apex/amp/handle.py
@@ -6,8 +6,9 @@ from .opt import OptimWrapper
 from .scaler import LossScaler
 
 class AmpHandle(object):
-    def __init__(self, enable_caching=True):
+    def __init__(self, enable_caching=True, verbose=False):
         self._enable_caching = enable_caching
+        self._verbose = verbose
         self._cache = dict()
         self._default_scaler = LossScaler()
@@ -67,6 +68,10 @@ class AmpHandle(object):
         if self.has_cache and param in self.cache:
             del self.cache[param]
 
+    @property
+    def verbose(self):
+        return self._verbose
+
 class NoOpHandle(object):
     def is_active(self):
         return False
@@ -81,3 +86,7 @@ class NoOpHandle(object):
     @property
     def has_cache(self):
         return False
+
+    @property
+    def verbose(self):
+        return False
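
The `verbose` value cached on the handle feeds `utils.verbosify`, which this diff calls but does not show. A rough sketch of the assumed behavior (the log format is invented; only the (cast_fn, fn_name, verbose) signature is taken from the call sites above):

import functools

def verbosify(cast_fn, fn_name, verbose):
    # Assumed: return the cast function unchanged unless logging is on.
    if not verbose:
        return cast_fn

    @functools.wraps(cast_fn)
    def wrapper(*args, **kwargs):
        print('amp: casting args of {}'.format(fn_name))  # invented format
        return cast_fn(*args, **kwargs)
    return wrapper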
apex/amp/wrap.py
@@ -5,13 +5,8 @@ import functools
 
 import torch
 
-def cached_cast(mod, fn, cast_fn, handle,
-                try_caching=False, verbose=False):
-    if not utils.has_func(mod, fn):
-        return
-
-    orig_fn = utils.get_func(mod, fn)
-    cast_fn = utils.verbosify(cast_fn, fn, verbose)
-
+def make_cast_wrapper(orig_fn, cast_fn, handle,
+                      try_caching=False):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         if try_caching and handle.has_cache:
@@ -26,18 +21,27 @@ def cached_cast(mod, fn, cast_fn, handle,
                                          args,
                                          kwargs)
             return orig_fn(*new_args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    return wrapper
+
+def cached_cast(mod, fn, cast_fn, handle,
+                try_caching=False, verbose=False):
+    if not utils.has_func(mod, fn):
+        return
+
+    orig_fn = utils.get_func(mod, fn)
+    cast_fn = utils.verbosify(cast_fn, fn, verbose)
+    wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
+    utils.set_func(mod, fn, wrapper)
 
-def promote(mod, fn, verbose=False):
-    orig_fn = utils.get_func(mod, fn)
-    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
+# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
+# and `make_promote_wrapper` consistent
+def make_promote_wrapper(orig_fn, cast_fn, handle=None):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)
 
         if len(types) <= 1:
             return orig_fn(*args, **kwargs)
         elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
-            new_args = utils.casted_args(maybe_float,
+            new_args = utils.casted_args(cast_fn,
                                          args,
                                          kwargs)
             return orig_fn(*new_args, **kwargs)
@@ -45,8 +49,14 @@ def promote(mod, fn, verbose=False):
         else:
             raise NotImplementedError('Do not know how to handle ' +
                                       'these types to promote: {}'
                                       .format(types))
+    return wrapper
+
+def promote(mod, fn, verbose=False):
+    orig_fn = utils.get_func(mod, fn)
+    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
+    wrapper = make_promote_wrapper(orig_fn, maybe_float)
+    utils.set_func(mod, fn, wrapper)
 
 def sequence_promote(mod, fn, verbose=False):
     orig_fn = utils.get_func(mod, fn)
     maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
......
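
Net effect of the wrap.py refactor: each patcher is split into a pure wrapper factory (`make_cast_wrapper`, `make_promote_wrapper`) and a thin fetch-wrap-set patcher (`cached_cast`, `promote`), so the new decorators in amp.py can reuse the factories without touching any module. A self-contained sketch of that pattern, with `getattr`/`setattr` standing in for `utils.get_func`/`utils.set_func`:

import functools
import math

def make_wrapper(orig_fn, transform):
    # Pure factory, like make_cast_wrapper: closes over the original
    # function and returns a new callable; no module is modified here.
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        return orig_fn(*(transform(a) for a in args), **kwargs)
    return wrapper

def patch(mod, name, transform):
    # Thin patcher, like cached_cast/promote: fetch, wrap, set back.
    if not hasattr(mod, name):
        return
    setattr(mod, name, make_wrapper(getattr(mod, name), transform))

patch(math, 'sqrt', abs)  # math.sqrt now routes through the wrapper
print(math.sqrt(-4))      # 2.0: abs() is applied to the arg first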