Commit d6db91a4 authored by Michael Carilli

Updating latest amp changes to use new C++ backend

parents 89564b69 db6ae13a
apex.egg-info
dist
build
docs/build
*~
@@ -41,7 +41,7 @@ top-level README for more on installation.

## Usage and Getting Started

In the common case, using amp requires adding two lines of code (and
an import). The first enables amp, so that it can hook into all the
relevant PyTorch functions. The second tells it where backpropagation
occurs so that it can properly scale the loss and clear internal

@@ -50,20 +50,25 @@ per-iteration state.
#### 1. Enable amp
```python
from apex import amp
amp_handle = amp.init()
```

`amp.init()` takes three (optional) arguments. The most useful is
`enabled` (default=True), which simplifies command-line arguments. If
False, then everything amp does will be a zero-overhead pass-through
-- i.e., your code will run as-is.

For the other two options, the defaults are _highly_ recommended. The
first, `enable_caching` (default=True), indicates whether amp should
cache fp16 casts of model parameters on a per-iteration basis. This
prevents things like RNN cells used inside a loop from casting their
weight matrices over and over. The second, `verbose` (default=False)
toggles whether to print out every cast that occurs. Useful for
debugging, mostly.
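For example, a minimal sketch of driving `enabled` from a command-line flag (the `--fp16` flag name here is illustrative, not something amp defines):

```python
import argparse
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true')  # illustrative flag name
args = parser.parse_args()

# With enabled=False, amp becomes a zero-overhead pass-through,
# so the same script also runs unchanged in fp32.
amp_handle = amp.init(enabled=args.fp16)
```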
#### 2. Wrap backpropagation

Nearly all PyTorch training scripts have a loop that looks like:
```python
# ... do a bunch of stuff to compute a loss

@@ -91,9 +96,86 @@ you will not get automatic loss scaling, nor is it safe to

`enable_caching`. (Power user note: you can manually clear the cache
after each optimizer step with `amp_handle._clear_cache()`.)
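For reference, here is a minimal sketch of the wrapped loop this step describes, using the `scale_loss` context manager that `handle.py` defines in this commit (the loss computation and optimizer are placeholders):

```python
# ... do a bunch of stuff to compute a loss
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
optimizer.zero_grad()  # or however the surrounding script clears gradients
```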
## Multiple Optimizers or Backward Passes
Step (2) from the previous section works when you have one PyTorch
optimizer and a single `loss.backward()` for each iteration. Some
models are more complex with:
- Multiple optimizer objects (over different parameters)
- Multiple backward passes for each iteration, taking advantage of
PyTorch's gradient accumulation
To work with such models, amp requires you to explicitly wrap each
optimizer and indicate if it will have more than one backward pass
per-iteration.
#### Explicitly wrapping optimizers
If you have more than one optimizer, then you must explicitly wrap
each. (You can also do so with a single optimizer.) First, wrap the
optimizer after initializing amp:
```python
optimizer = # ... some optimizer
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)
```
Second, use `optimizer.scale_loss(...)` to indicate where backprop
occurs:
```python
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ...
```
In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
sugar for first wrapping the optimizer and then calling
`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
multi-optimizer case, you must wrap each optimizer individually.
#### Handling multiple backward passes
PyTorch accumulates parameter gradients between calls to
`zero_grad()`, so it is possible to perform multiple backward passes
before making a parameter update:
```python
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
loss1.backward()
# ...
loss2 = ComputeLoss2(model)
loss2.backward()
# ...
optimizer.step() # has gradient contributions from both backward passes
```
The amp optimizer wrapper supports an additional argument `num_loss`
to work with code like this:
```python
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
# ...
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
with optimizer.scale_loss(loss1) as scaled_loss:
    scaled_loss.backward()
# ...
loss2 = ComputeLoss2(model)
with optimizer.scale_loss(loss2) as scaled_loss:
    scaled_loss.backward()
# ...
optimizer.step()
```
## Annotating User Functions

Nearly all PyTorch user code needs nothing more than the two steps
above to use amp. After all, custom layers are built out of simpler
PyTorch components, and amp already can see those.

@@ -103,27 +185,62 @@ cell called a "forgetful recurrent unit" that calls directly into a

CUDA backend:
```python
from backend import FRUBackend

def fru(input, hidden, weight, bias):
    # call to CUDA code
    FRUBackend(input, hidden, weight, bias)
```

In this case, it is possible to get a runtime type mismatch. For
example, you might have `input` in fp16, and `weight` in fp32, and amp
doesn't have the visibility to insert an appropriate cast.

amp exposes two ways to handle "invisible" backend code: function
annotations and explicit registration.
#### Function annotation
The first way to handle backend code is a set of function annotations:
- `@amp.half_function`
- `@amp.float_function`
- `@amp.promote_function`
These correspond to:
- Cast all arguments to fp16
- Cast all arguments to fp32
- If there are any type mismatches, cast everything to the widest type
In our example, we believe that the FRU unit is fp16-safe and will get
performance gains from casting its arguments to fp16, so we write:
```python
@amp.half_function
def fru(input, hidden, weight, bias):
    # ...
```
#### Explicit registration
The other way to handle backend code is with explicit function
registration:
- `amp.register_half_function(module, function_name)`
- `amp.register_float_function(module, function_name)`
- `amp.register_promote_function(module, function_name)`
When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.init()`.
For our FRU unit, we can register the backend function directly:
```python
import backend

amp.register_half_function(backend, 'FRUBackend')
amp.init()
```
from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function

from . import compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides

import functools
import itertools
import torch
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
# Shared machinery for the decorator forms below: at call time, defer to the
# active amp handle, or fall through to the original function if amp is
# disabled or not yet initialized.
def _decorator_helper(orig_fn, cast_fn, wrap_fn):
    def wrapper(*args, **kwargs):
        handle = _DECORATOR_HANDLE
        if handle is None or not handle.is_active():
            return orig_fn(*args, **kwargs)
        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
                                        handle.verbose)
        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
    return wrapper

# Decorator form
def half_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
    return _decorator_helper(fn, utils.maybe_half, wrap_fn)

def float_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)

def promote_function(fn):
    wrap_fn = functools.partial(wrap.make_promote_wrapper)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)

# Registry form
def register_half_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))

def register_float_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))

def register_promote_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_PROMOTE_REGISTRY.add((module, name))
# Top-level function to insert _all_ the hooks.
def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
    global _DECORATOR_HANDLE

    if not enabled:
        handle = NoOpHandle()
        _DECORATOR_HANDLE = handle
        return handle

    handle = AmpHandle(enable_caching, verbose)

    # 0) Force-{fp16, fp32} for user-annotated functions
    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
        try_caching = (cast_fn == utils.maybe_half)
        wrap.cached_cast(mod, fn, cast_fn, handle,
                         try_caching, verbose)
    _USER_CAST_REGISTRY.clear()

    # 0.5) Force-promote for user-annotated functions
    for mod, fn in _USER_PROMOTE_REGISTRY:
        wrap.promote(mod, fn, verbose)
    _USER_PROMOTE_REGISTRY.clear()

    # 1) Force-{fp16, fp32} on white- / black-list functions
    override_modules = [functional_overrides,

@@ -101,4 +145,10 @@ def enable(enable_caching=True, verbose=False):

    # 5.5) Extra-special handling of RNN backend
    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)

    # 6) Place error+print message on banned functions
    if not allow_banned:
        for fn, err_msg in functional_overrides.BANNED_FUNCS:
            wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)

    _DECORATOR_HANDLE = handle
    return handle
@@ -2,54 +2,54 @@ import contextlib

import logging
import warnings

from .opt import OptimWrapper
from .scaler import LossScaler


class AmpHandle(object):
    def __init__(self, enable_caching=True, verbose=False):
        self._enable_caching = enable_caching
        self._verbose = verbose
        self._cache = dict()
        self._default_scaler = LossScaler()

    def is_active(self):
        return True

    def wrap_optimizer(self, optimizer, num_loss=1):
        self._default_scaler = None
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        if not self.is_active():
            yield loss
            return

        if self._default_scaler is None:
            raise RuntimeError(
                'After calling `handle.wrap_optimizer()`, you must explicitly ' +
                'use `optimizer.scale_loss(loss)`.')

        # TODO: this code block is duplicated here and `opt.py`. Unify.
        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper

        loss_scale = self._default_scaler.loss_scale()
        yield loss * loss_scale
        loss.backward = loss_backward

        should_skip = self._default_scaler.unscale_and_update(
            optimizer.param_groups, loss_scale)
        if should_skip:
            optimizer_step = optimizer.step
            def skip_step():
                logging.info('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step

        self._clear_cache()

@@ -63,3 +63,30 @@ class AmpHandle(object):

    @property
    def cache(self):
        return self._cache

    def remove_cache(self, param):
        if self.has_cache and param in self.cache:
            del self.cache[param]

    @property
    def verbose(self):
        return self._verbose


class NoOpHandle(object):
    def is_active(self):
        return False

    def wrap_optimizer(self, optimizer, num_loss=1):
        return OptimWrapper(optimizer, self, num_loss)

    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        yield loss

    @property
    def has_cache(self):
        return False

    @property
    def verbose(self):
        return False
@@ -42,7 +42,6 @@ FP32_FUNCS = [

    # Loss functions
    # TODO: which of these can be fp16?
    'poisson_nll_loss',
    'cosine_embedding_loss',
    'cross_entropy',

@@ -60,3 +59,15 @@ FP32_FUNCS = [

    'soft_margin_loss',
    'triplet_margin_loss'
]

BANNED_FUNCS = [
    ('binary_cross_entropy',
     ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
      "It requires that the output of the previous function be already a FloatTensor. \n\n"
      "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
      " torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
      "that is compatible with amp.\nAnother option is to add\n"
      " amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
      "If you _really_ know what you are doing, you can disable this warning by passing "
      "allow_banned=True to `amp.init()`."))
]
import contextlib
import logging
import warnings

from .scaler import LossScaler, iter_params

import numpy as np


class OptimWrapper(object):
    def __init__(self, optimizer, amp_handle, num_loss):
        self._optimizer = optimizer
        self._amp_handle = amp_handle
        self._num_loss = num_loss
        self._loss_idx = 0
        self._skip_next = [False] * num_loss
        self._loss_scaler = [LossScaler() for _ in range(num_loss)]

    @contextlib.contextmanager
    def scale_loss(self, loss):
        if not self._amp_handle.is_active():
            yield loss
            return

        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper

        # When there are multiple losses per-optimizer, we need
        # to save out the current grad accumulation, since we won't be
        # able to unscale this particular loss once the grads are
        # all mixed together.
        cached_grads = []
        if self._loss_idx > 0:
            for p in iter_params(self._optimizer.param_groups):
                if p.grad is not None:
                    cached_grads.append(p.grad.data.detach().clone())
                else:
                    cached_grads.append(None)
            self._optimizer.zero_grad()

        loss_scale = self._cur_loss_scaler().loss_scale()
        yield loss * loss_scale
        loss.backward = loss_backward

        self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
            self._optimizer.param_groups, loss_scale)
        self._loss_idx += 1

        if len(cached_grads) > 0:
            for p, cached_grad in zip(iter_params(self._optimizer.param_groups),
                                      cached_grads):
                if cached_grad is not None:
                    p.grad.data.add_(cached_grad)
            cached_grads = []

    def _cur_loss_scaler(self):
        assert 0 <= self._loss_idx < self._num_loss
        return self._loss_scaler[self._loss_idx]

    def step(self, closure=None):
        if not self._amp_handle.is_active():
            return self._optimizer.step(closure=closure)

        self._loss_idx = 0

        for group in self._optimizer.param_groups:
            for p in group['params']:
                self._amp_handle.remove_cache(p)

        if closure is not None:
            raise NotImplementedError(
                'The `closure` argument is unsupported by the amp ' +
                'optimizer wrapper.')
        if any(self._skip_next):
            logging.info('Gradient overflow, skipping update')
            self._skip_next = [False] * self._num_loss
        else:
            return self._optimizer.step(closure=closure)

    # Forward any attribute lookups
    def __getattr__(self, attr):
        return getattr(self._optimizer, attr)

    # Forward all torch.optim.Optimizer methods
    def __getstate__(self):
        return self._optimizer.__getstate__()

    def __setstate__(self):
        return self._optimizer.__setstate__()

    def __repr__(self):
        return self._optimizer.__repr__()

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict):
        return self._optimizer.load_state_dict(state_dict)

    def zero_grad(self):
        return self._optimizer.zero_grad()

    def add_param_group(self, param_group):
        return self._optimizer.add_param_group(param_group)
import torch
from apex_C import scale_check_overflow


class LossScaler(object):
    def __init__(self):
        self._loss_scale = 2.**16
        self._max_loss_scale = 2.**24
        self._scale_seq_len = 2000
        self._unskipped = 0
        self._overflow_buf = torch.cuda.ByteTensor(1024,)

    def loss_scale(self):
        return self._loss_scale

    def unscale_and_update(self, param_groups, scale):
        self._overflow_buf.zero_()
        for p in iter_params(param_groups):
            if p.grad is not None:
                scale_check_overflow(p.grad.data,
                                     1. / scale,
                                     self._overflow_buf)
        if self._overflow_buf.any():
            should_skip = True
            self._loss_scale /= 2.
            self._unskipped = 0
        else:
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        return should_skip


def iter_params(param_groups):
    for group in param_groups:
        for p in group['params']:
            yield p
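As an aside (not part of this diff), here is a minimal sketch of the dynamic loss-scaling cycle that `AmpHandle` and `OptimWrapper` drive through `LossScaler` once per iteration; `optimizer` and `loss` are placeholders:

```python
scaler = LossScaler()

scale = scaler.loss_scale()   # current scale, starting at 2.**16
(loss * scale).backward()     # backprop the scaled loss

# Unscale the gradients and update the scale; returns True when an
# overflow was detected, in which case the step should be skipped.
should_skip = scaler.unscale_and_update(optimizer.param_groups, scale)
if not should_skip:
    optimizer.step()
optimizer.zero_grad()
```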
@@ -85,7 +85,15 @@ def cached_cast(cast_fn, x, cache):

    if is_nested(x):
        return type(x)([cached_cast(y) for y in x])
    if x in cache:
        cached_x = cache[x]
        # During eval, it's possible to end up caching casted weights
        # with requires_grad == False. This is then a problem when they
        # get reused on the next train iter. So we ensure that cached
        # weights have the same requires_grad flag as the most recent request.
        if x.requires_grad != cached_x.requires_grad:
            cached_x.requires_grad_(x.requires_grad)
        return cache[x]
    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x
@@ -5,13 +5,8 @@ import functools

import torch


def make_cast_wrapper(orig_fn, cast_fn, handle,
                      try_caching=False):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if try_caching and handle.has_cache:

@@ -26,18 +21,27 @@ def cached_cast(mod, fn, cast_fn, handle,

                                     args,
                                     kwargs)
        return orig_fn(*new_args, **kwargs)
    return wrapper

def cached_cast(mod, fn, cast_fn, handle,
                try_caching=False, verbose=False):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    cast_fn = utils.verbosify(cast_fn, fn, verbose)
    wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
    utils.set_func(mod, fn, wrapper)

# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)
        if len(types) <= 1:
            return orig_fn(*args, **kwargs)
        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
            new_args = utils.casted_args(cast_fn,
                                         args,
                                         kwargs)
            return orig_fn(*new_args, **kwargs)

@@ -45,8 +49,14 @@ def promote(mod, fn, verbose=False):

            raise NotImplementedError('Do not know how to handle ' +
                                      'these types to promote: {}'
                                      .format(types))
    return wrapper

def promote(mod, fn, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    wrapper = make_promote_wrapper(orig_fn, maybe_float)
    utils.set_func(mod, fn, wrapper)

def sequence_promote(mod, fn, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)

@@ -84,7 +94,7 @@ def promote_match_arg0(mod, fn, verbose=False):

        return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func(mod, fn, wrapper)

def err_if_any_half(mod, fn, custom_err_msg=None):
    if not utils.has_func(mod, fn):
        return

@@ -93,8 +103,11 @@ def err_if_any_half(mod, fn):

    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)
        if 'HalfTensor' in types:
            if custom_err_msg:
                raise NotImplementedError(custom_err_msg)
            else:
                raise NotImplementedError('Cannot call in-place function ' +
                                          '{} with fp16 arguments.'.format(fn))
        else:
            return orig_fn(*args, **kwargs)
    utils.set_func(mod, fn, wrapper)