Commit ec431d33 authored by Carl Case
Browse files

WIP: better annotation / user function registry support

parent 614b11ff
from .amp import build, register_half, register_float, register_promote from .amp import build, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
...@@ -2,37 +2,67 @@ from . import compat, utils, wrap ...@@ -2,37 +2,67 @@ from . import compat, utils, wrap
from .handle import AmpHandle, NoOpHandle from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides from .lists import functional_overrides, torch_overrides, tensor_overrides
import inspect import functools
import itertools import itertools
import torch import torch
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set() _USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set() _USER_PROMOTE_REGISTRY = set()
# Can be used as a @decorator directly on the fn def _decorator_helper(orig_fn, cast_fn, wrap_fn):
# or called w/ arg by user before `build()` def wrapper(*args, **kwargs):
def register_half(fn): handle = _DECORATOR_HANDLE
mod = inspect.getmodule(fn) if handle is None or not handle.is_active():
_USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_half)) return orig_fn(*args, **kwargs)
return fn inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
handle.verbose)
def register_float(fn): return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
mod = inspect.getmodule(fn) return wrapper
_USER_CAST_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
return fn # Decorator form
def half_function(fn):
def register_promote(fn): wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
mod = inspect.getmodule(fn) return _decorator_helper(fn, utils.maybe_half, wrap_fn)
def float_function(fn):
    """Decorator form: wrap ``fn`` using the float cast.

    Builds a wrapper factory from ``wrap.make_cast_wrapper`` with
    ``try_caching=False`` and hands it, together with ``utils.maybe_float``,
    to ``_decorator_helper``.

    Args:
        fn: The function being decorated.

    Returns:
        Whatever ``_decorator_helper`` returns for ``fn``.
    """
    uncached_cast_wrapper = functools.partial(
        wrap.make_cast_wrapper,
        try_caching=False,
    )
    return _decorator_helper(fn, utils.maybe_float, uncached_cast_wrapper)
def promote_function(fn):
    """Decorator form: wrap ``fn`` using the promote wrapper.

    Passes ``utils.maybe_float`` as the cast function and a
    ``functools.partial`` of ``wrap.make_promote_wrapper`` (no arguments
    bound) as the wrapper factory to ``_decorator_helper``.

    Args:
        fn: The function being decorated.

    Returns:
        Whatever ``_decorator_helper`` returns for ``fn``.
    """
    promote_wrapper_factory = functools.partial(wrap.make_promote_wrapper)
    return _decorator_helper(fn, utils.maybe_float, promote_wrapper_factory)
# Registry form
def register_half_function(module, name):
    """Registry form: record ``module.name`` with the half cast.

    Adds ``(module, name, utils.maybe_half)`` to ``_USER_CAST_REGISTRY``.

    Args:
        module: Object that owns the function (checked with ``hasattr``).
        name: Attribute name of the function on ``module``.

    Raises:
        ValueError: If ``module`` has no attribute called ``name``.
    """
    if hasattr(module, name):
        _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
        return
    message = 'No function named {} in module {}.'.format(name, module)
    raise ValueError(message)
def register_float_function(module, name):
    """Registry form: record ``module.name`` with the float cast.

    Adds ``(module, name, utils.maybe_float)`` to ``_USER_CAST_REGISTRY``.

    Args:
        module: Object that owns the function (checked with ``hasattr``).
        name: Attribute name of the function on ``module``.

    Raises:
        ValueError: If ``module`` has no attribute called ``name``.
    """
    if hasattr(module, name):
        _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
        return
    message = 'No function named {} in module {}.'.format(name, module)
    raise ValueError(message)
def register_promote_function(module, name):
    """Registry form: record ``module.name`` for type promotion.

    Adds the ``(module, name)`` pair to ``_USER_PROMOTE_REGISTRY``.

    Args:
        module: Object that owns the function (checked with ``hasattr``).
        name: Attribute name of the function on ``module``.

    Raises:
        ValueError: If ``module`` has no attribute called ``name``.
    """
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    # Register this function's own arguments. `mod` and `fn` are not defined
    # in this scope, so using them here would raise NameError on every
    # successful registration.
    _USER_PROMOTE_REGISTRY.add((module, name))
# Top-level function to insert _all_ the hooks. # Top-level function to insert _all_ the hooks.
def build(enabled=True, enable_caching=True, verbose=False): def build(enabled=True, enable_caching=True, verbose=False):
global _DECORATOR_HANDLE
if not enabled: if not enabled:
return NoOpHandle() handle = NoOpHandle()
_DECORATOR_HANDLE = handle
return handle
handle = AmpHandle(enable_caching) handle = AmpHandle(enable_caching, verbose)
# 0) Force-{fp16, fp32} for user-annotated functions # 0) Force-{fp16, fp32} for user-annotated functions
for mod, fn, cast_fn in _USER_CAST_REGISTRY: for mod, fn, cast_fn in _USER_CAST_REGISTRY:
...@@ -115,4 +145,5 @@ def build(enabled=True, enable_caching=True, verbose=False): ...@@ -115,4 +145,5 @@ def build(enabled=True, enable_caching=True, verbose=False):
# 5.5) Extra-special handling of RNN backend # 5.5) Extra-special handling of RNN backend
wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose) wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
_DECORATOR_HANDLE = handle
return handle return handle
...@@ -6,8 +6,9 @@ from .opt import OptimWrapper ...@@ -6,8 +6,9 @@ from .opt import OptimWrapper
from .scaler import LossScaler from .scaler import LossScaler
class AmpHandle(object): class AmpHandle(object):
def __init__(self, enable_caching=True): def __init__(self, enable_caching=True, verbose=False):
self._enable_caching = enable_caching self._enable_caching = enable_caching
self._verbose = verbose
self._cache = dict() self._cache = dict()
self._default_scaler = LossScaler() self._default_scaler = LossScaler()
...@@ -67,6 +68,10 @@ class AmpHandle(object): ...@@ -67,6 +68,10 @@ class AmpHandle(object):
if self.has_cache and param in self.cache: if self.has_cache and param in self.cache:
del self.cache[param] del self.cache[param]
@property
def verbose(self):
return self._verbose
class NoOpHandle(object): class NoOpHandle(object):
def is_active(self): def is_active(self):
return False return False
...@@ -81,3 +86,7 @@ class NoOpHandle(object): ...@@ -81,3 +86,7 @@ class NoOpHandle(object):
@property @property
def has_cache(self): def has_cache(self):
return False return False
@property
def verbose(self):
return False
...@@ -5,13 +5,8 @@ import functools ...@@ -5,13 +5,8 @@ import functools
import torch import torch
def cached_cast(mod, fn, cast_fn, handle, def make_cast_wrapper(orig_fn, cast_fn, handle,
try_caching=False, verbose=False): try_caching=False):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
@functools.wraps(orig_fn) @functools.wraps(orig_fn)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if try_caching and handle.has_cache: if try_caching and handle.has_cache:
...@@ -26,18 +21,27 @@ def cached_cast(mod, fn, cast_fn, handle, ...@@ -26,18 +21,27 @@ def cached_cast(mod, fn, cast_fn, handle,
args, args,
kwargs) kwargs)
return orig_fn(*new_args, **kwargs) return orig_fn(*new_args, **kwargs)
utils.set_func(mod, fn, wrapper) return wrapper
def cached_cast(mod, fn, cast_fn, handle,
try_caching=False, verbose=False):
if not utils.has_func(mod, fn):
return
def promote(mod, fn, verbose=False):
orig_fn = utils.get_func(mod, fn) orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose) cast_fn = utils.verbosify(cast_fn, fn, verbose)
wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
utils.set_func(mod, fn, wrapper)
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@functools.wraps(orig_fn) @functools.wraps(orig_fn)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
types = utils.collect_fp_tensor_types(args, kwargs) types = utils.collect_fp_tensor_types(args, kwargs)
if len(types) <= 1: if len(types) <= 1:
return orig_fn(*args, **kwargs) return orig_fn(*args, **kwargs)
elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']): elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
new_args = utils.casted_args(maybe_float, new_args = utils.casted_args(cast_fn,
args, args,
kwargs) kwargs)
return orig_fn(*new_args, **kwargs) return orig_fn(*new_args, **kwargs)
...@@ -45,8 +49,14 @@ def promote(mod, fn, verbose=False): ...@@ -45,8 +49,14 @@ def promote(mod, fn, verbose=False):
raise NotImplementedError('Do not know how to handle ' + raise NotImplementedError('Do not know how to handle ' +
'these types to promote: {}' 'these types to promote: {}'
.format(types)) .format(types))
return wrapper
def promote(mod, fn, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
wrapper = make_promote_wrapper(orig_fn, maybe_float)
utils.set_func(mod, fn, wrapper) utils.set_func(mod, fn, wrapper)
def sequence_promote(mod, fn, verbose=False): def sequence_promote(mod, fn, verbose=False):
orig_fn = utils.get_func(mod, fn) orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose) maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment