Unverified commit 1579b9e3 authored by Carl Case, committed by GitHub

Merge pull request #40 from NVIDIA/amp_tests

amp unit tests
parents 75a865e3 5f5dfa42
@@ -73,7 +73,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
 
     # 0.5) Force-promote for user-annotated functions
     for mod, fn in _USER_PROMOTE_REGISTRY:
-        wrap.promote(mod, fn, verbose)
+        wrap.promote(mod, fn, handle, verbose)
     _USER_PROMOTE_REGISTRY.clear()
 
     # 1) Force-{fp16, fp32} on white- / black-list functions
@@ -107,7 +107,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
                                                                    promote_table):
         for fn in getattr(promote_mod, list_name):
-            promote_fn(promote_mod.MODULE, fn, verbose)
+            promote_fn(promote_mod.MODULE, fn, handle, verbose)
 
     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
@@ -115,27 +115,27 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
                                                                 torch.cuda.HalfTensor],
                                                                promote_table):
             for fn in getattr(tensor_overrides, list_name):
-                promote_fn(cls, fn, verbose)
+                promote_fn(cls, fn, handle, verbose)
 
     # 3) For any in-place version of a blacklist function, error if any input is fp16.
     # NB: this is overly conservative.
     for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
-        wrap.err_if_any_half(torch_overrides.MODULE, fn)
+        wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
 
     # 3.5) For any in-place blacklist method, error if called on fp16 tensor
     for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
-        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, verbose)
+        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, verbose)
+            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)
 
     # 4) For other in-place methods, match the type of self tensor
     for fn in utils.as_inplace(itertools.chain(
             tensor_overrides.FP16_FUNCS,
             tensor_overrides.CASTS)):
-        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, verbose)
+        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, verbose)
-            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, verbose)
+            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
+            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
 
     # 5) Special handling to whitelist RNN cell backend impls.
     for fn in ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']:
@@ -143,7 +143,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
                          handle, try_caching=True, verbose=verbose)
 
     # 5.5) Extra-special handling of RNN backend
-    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
+    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
 
     # And even more special handling of `backward` for fused gru / lstm
     # The `backward` method calls Tensor.sum() (blacklist) internally,
@@ -153,10 +153,14 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
         mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
         wrap.disable_casts(mod, 'backward', handle)
 
-    # 6) Place error+print message on banned functions
-    if not allow_banned:
-        for fn, err_msg in functional_overrides.BANNED_FUNCS:
-            wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)
+    # 6) Place error+print message on banned functions.
+    #    Or, if allow_banned, then cast to FP32.
+    for fn, err_msg in functional_overrides.BANNED_FUNCS:
+        if allow_banned:
+            wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
+                             handle, try_caching=True, verbose=verbose)
+        else:
+            wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
 
     _DECORATOR_HANDLE = handle
     return handle
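From the caller's side, the reworked step 6 gives two behaviors. A minimal usage sketch (assuming apex with this change installed and a CUDA device, as in the new tests below):

from apex import amp

# Default: banned functions such as F.binary_cross_entropy raise
# NotImplementedError when they see fp16 inputs.
handle = amp.init(enabled=True)

# Opt-in alternative: wrap the same functions with a cached cast to fp32
# instead of raising.
# handle = amp.init(enabled=True, allow_banned=True)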
@@ -2,6 +2,7 @@ import contextlib
 import logging
 import warnings
 
+from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 
@@ -12,6 +13,7 @@ class AmpHandle(object):
         self._cache = dict()
         self._default_scaler = LossScaler()
         self._is_active = True
+        self._all_wrappers = []
 
     def is_active(self):
         return self._is_active
@@ -63,6 +65,15 @@ class AmpHandle(object):
     def _clear_cache(self):
         self._cache.clear()
 
+    # Experimental support for saving / restoring uncasted versions of functions
+    def _save_func(self, mod, fn, func):
+        self._all_wrappers.append((mod, fn, func))
+
+    def _deactivate(self):
+        for mod, fn, func in self._all_wrappers:
+            utils.set_func(mod, fn, func)
+        self._all_wrappers = []
+
     @property
     def has_cache(self):
         return self._enable_caching
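The new tests exercise this experimental hook in their setUp/tearDown. A short sketch of the intended flow (again assuming apex with this PR and a CUDA device):

from apex import amp

handle = amp.init(enabled=True)   # installs casting wrappers, recording each original
# ... run fp16-casted ops ...
handle._deactivate()              # restores every saved, uncasted function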
import unittest
import functools as ft
import itertools as it
from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from .utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
for fn, typ in it.product(fns, expected.keys()):
x = torch.randn(input_shape, dtype=typ).requires_grad_()
y = fn(x)
test_case.assertEqual(y.type(), expected[typ])
if test_backward:
y.float().sum().backward()
test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
class TestBasicCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_linear_is_half(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))
def test_conv2d_is_half(self):
m = nn.Conv2d(self.c, self.c, self.k)
f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
def test_softmax_is_float(self):
m = nn.Softmax(dim=1)
f = ft.partial(F.softmax, dim=1)
run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))
def test_group_norm_is_float(self):
m = nn.GroupNorm(num_groups=4, num_channels=self.c)
run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
def test_mse_loss_is_float(self):
shape = (self.b, self.h)
target = torch.randn(shape)
mod = nn.MSELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.mse_loss, target=target)
run_layer_test(self, [m], ALWAYS_FLOAT, shape)
def test_relu_is_match(self):
run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
def test_batch_norm_is_match(self):
m = nn.BatchNorm2d(num_features=self.c)
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=True)
run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
# Test forward-only for BN inference
m.eval()
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=False)
run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
test_backward=False)
class TestBannedMethods(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def bce_common(self, assertion):
shape = (self.b, self.h)
target = torch.rand(shape)
mod = nn.BCELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.binary_cross_entropy, target=target)
for fn in [m, f]:
x = torch.rand(shape, dtype=torch.half)
assertion(fn, x)
def test_bce_raises_by_default(self):
assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
self.bce_common(assertion)
def test_bce_is_float_with_allow_banned(self):
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True)
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
self.bce_common(assertion)
class TestTensorCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_matmul_method_is_half(self):
other = torch.randn(self.h, self.h)
lhs = lambda x: x.matmul(other)
rhs = lambda x: other.matmul(x)
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
def test_matmul_op_is_half(self):
other = torch.randn(self.h, self.h)
lhs = lambda x: x @ other
rhs = lambda x: other @ x
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
def test_pow_method_is_float(self):
fn = lambda x: x.pow(2.)
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
def test_pow_op_is_float(self):
fn = lambda x: x ** 2.
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
def test_cpu_is_float(self):
fn = lambda x: x.cpu()
always_cpu_float = {torch.float: 'torch.FloatTensor',
torch.half: 'torch.FloatTensor'}
run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))
def test_sum_is_float(self):
fn = lambda x: x.sum()
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
class TestDisabledCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=False)
common_init(self)
def test_disabled_linear(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
input_shape = (self.b, self.h)
for fn in [m, f]:
x = torch.randn(input_shape, dtype=torch.float).requires_grad_()
y = fn(x)
self.assertEqual(y.type(), FLOAT)
y.sum().backward()
self.assertEqual(x.grad.type(), FLOAT)
x = torch.randn(input_shape, dtype=torch.half).requires_grad_()
self.assertRaises(RuntimeError, fn, x)
# TODO: maybe more tests on disabled casting?
if __name__ == '__main__':
unittest.main()
import unittest
import itertools as it
from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from .utils import common_init, HALF, FLOAT, DTYPES
class TestPromotion(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
type_pairs = it.product(DTYPES, DTYPES)
for fn, (xtype, ytype) in it.product(fns, type_pairs):
x = torch.randn(input_shape, dtype=xtype).requires_grad_()
x_leaf = x
if x_inplace:
# We need a non-leaf to call in place on
x = x.clone()
y = torch.randn(input_shape, dtype=ytype)
out = fn(x, y)
if x_inplace:
# In place: always match xtype
self.assertEqual(out.type(), x.type())
else:
# Out of place: match widest type
if xtype == torch.float or ytype == torch.float:
self.assertEqual(out.type(), FLOAT)
else:
self.assertEqual(out.type(), HALF)
out.float().sum().backward()
self.assertEqual(x_leaf.grad.dtype, xtype)
def test_atan2_matches_widest(self):
fns = [lambda x, y : torch.atan2(x, y),
lambda x, y : x.atan2(y)]
self.run_binary_promote_test(fns, (self.b,))
def test_mul_matches_widest(self):
fns = [lambda x, y : torch.mul(x, y),
lambda x, y: x.mul(y)]
self.run_binary_promote_test(fns, (self.b,))
def test_cat_matches_widest(self):
shape = self.b
ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
x_float = torch.randn(shape)
out = torch.cat(ys + [x_float])
self.assertEqual(out.type(), FLOAT)
x_half = torch.randn(shape, dtype=torch.half)
out = torch.cat(ys + [x_half])
self.assertEqual(out.type(), HALF)
def test_inplace_exp_is_error_for_half(self):
xs = torch.randn(self.b)
xs.exp_()
self.assertEqual(xs.type(), FLOAT)
xs = torch.randn(self.b, dtype=torch.half)
with self.assertRaises(NotImplementedError):
xs.exp_()
def test_inplace_add_matches_self(self):
fn = lambda x, y: x.add_(y)
self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
if __name__ == '__main__':
unittest.main()
import unittest
from apex import amp
import torch
from torch import nn
from .utils import common_init, HALF
class TestRnnCells(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def run_cell_test(self, cell, state_tuple=False):
shape = (self.b, self.h)
for typ in [torch.float, torch.half]:
xs = [torch.randn(shape, dtype=typ).requires_grad_()
for _ in range(self.t)]
hidden_fn = lambda: torch.zeros(shape, dtype=typ)
if state_tuple:
hidden = (hidden_fn(), hidden_fn())
else:
hidden = hidden_fn()
outputs = []
for i in range(self.t):
hidden = cell(xs[i], hidden)
if state_tuple:
output = hidden[0]
else:
output = hidden
outputs.append(output)
for y in outputs:
self.assertEqual(y.type(), HALF)
outputs[-1].float().sum().backward()
for i, x in enumerate(xs):
self.assertEqual(x.grad.dtype, x.dtype)
def test_rnn_cell_is_half(self):
cell = nn.RNNCell(self.h, self.h)
self.run_cell_test(cell)
def test_gru_cell_is_half(self):
cell = nn.GRUCell(self.h, self.h)
self.run_cell_test(cell)
def test_lstm_cell_is_half(self):
cell = nn.LSTMCell(self.h, self.h)
self.run_cell_test(cell, state_tuple=True)
class TestRnns(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
for typ in [torch.float, torch.half]:
x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
self.b, self.h), dtype=typ)
if state_tuple:
hidden = (hidden_fn(), hidden_fn())
else:
hidden = hidden_fn()
output, _ = rnn(x, hidden)
self.assertEqual(output.type(), HALF)
output[-1, :, :].float().sum().backward()
self.assertEqual(x.grad.dtype, x.dtype)
def test_rnn_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers,
nonlinearity='relu', bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
def test_gru_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers,
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
def test_lstm_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers,
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir, state_tuple=True)
if __name__ == '__main__':
unittest.main()
import torch
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'
DTYPES = [torch.half, torch.float]
ALWAYS_HALF = {torch.float: HALF,
torch.half: HALF}
ALWAYS_FLOAT = {torch.float: FLOAT,
torch.half: FLOAT}
MATCH_INPUT = {torch.float: FLOAT,
torch.half: HALF}
def common_init(test_case):
test_case.h = 64
test_case.b = 16
test_case.c = 16
test_case.k = 3
test_case.t = 10
torch.set_default_tensor_type(torch.cuda.FloatTensor)
@@ -126,6 +126,11 @@ def set_func(mod, fn, new_fn):
     else:
         setattr(mod, fn, new_fn)
 
+def set_func_save(handle, mod, fn, new_fn):
+    cur_fn = get_func(mod, fn)
+    handle._save_func(mod, fn, cur_fn)
+    set_func(mod, fn, new_fn)
+
 # A couple problems get solved here:
 # - The flat_weight buffer is disconnected from autograd graph,
 #   so the fp16 weights need to be derived from the input weights
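A self-contained sketch of the save/restore bookkeeping that set_func_save and AmpHandle._deactivate implement together. ToyHandle is an illustrative stand-in, and it uses plain getattr/setattr rather than the get_func/set_func helpers, which also handle backend dicts:

import torch

class ToyHandle:
    # Mimics the wrapper bookkeeping added to AmpHandle in this PR.
    def __init__(self):
        self._all_wrappers = []

    def _save_func(self, mod, fn, func):
        self._all_wrappers.append((mod, fn, func))

    def _deactivate(self):
        for mod, fn, func in self._all_wrappers:
            setattr(mod, fn, func)
        self._all_wrappers = []

def set_func_save(handle, mod, fn, new_fn):
    # Record the current function before monkey-patching so it can be restored.
    handle._save_func(mod, fn, getattr(mod, fn))
    setattr(mod, fn, new_fn)

handle = ToyHandle()
orig_add = torch.add
set_func_save(handle, torch, 'add', lambda *args, **kwargs: orig_add(*args, **kwargs))
assert torch.add is not orig_add   # wrapper installed
handle._deactivate()
assert torch.add is orig_add       # original restored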
@@ -34,7 +34,7 @@ def cached_cast(mod, fn, cast_fn, handle,
     orig_fn = utils.get_func(mod, fn)
     cast_fn = utils.verbosify(cast_fn, fn, verbose)
     wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
 
 # `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
 def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@@ -54,13 +54,13 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
                              .format(types))
     return wrapper
 
-def promote(mod, fn, verbose=False):
+def promote(mod, fn, handle, verbose=False):
     orig_fn = utils.get_func(mod, fn)
     maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
     wrapper = make_promote_wrapper(orig_fn, maybe_float)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
 
-def sequence_promote(mod, fn, verbose=False):
+def sequence_promote(mod, fn, handle, verbose=False):
     orig_fn = utils.get_func(mod, fn)
     maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
     @functools.wraps(orig_fn)
@@ -76,9 +76,9 @@ def sequence_promote(mod, fn, verbose=False):
             # TODO: other mixed-type cases aren't due to amp.
             #       Just pass through?
             return orig_fn(seq, *args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
 
-def promote_match_arg0(mod, fn, verbose=False):
+def promote_match_arg0(mod, fn, handle, verbose=False):
     if not utils.has_func(mod, fn):
         return
@@ -95,9 +95,9 @@ def promote_match_arg0(mod, fn, verbose=False):
         cast_fn = utils.verbosify(cast_fn, fn, verbose)
         new_args = utils.casted_args(cast_fn, args, kwargs)
         return orig_fn(arg0, *new_args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
 
-def err_if_any_half(mod, fn, custom_err_msg=None):
+def err_if_any_half(mod, fn, handle, custom_err_msg=None):
     if not utils.has_func(mod, fn):
         return
@@ -113,9 +113,9 @@ def err_if_any_half(mod, fn, custom_err_msg=None):
                                '{} with fp16 arguments.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
 
-def err_if_arg0_half(mod, fn, verbose=False):
+def err_if_arg0_half(mod, fn, handle, verbose=False):
     if not utils.has_func(mod, fn):
         return
@@ -130,7 +130,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
         cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
         new_args = utils.casted_args(cast_fn, args, kwargs)
         return orig_fn(arg0, *new_args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
 
 # Current RNN approach:
 # - Wrap top-level `RNN` function in thnn backend
@@ -140,7 +140,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
 # - We interpose on the factory function to:
 #   1) Interpose on the actual forward function and put in casts
 #   2) Insert an fp16 `flat_weight` if necessary
-def rnn_cast(backend, fn, verbose=False):
+def rnn_cast(backend, fn, handle, verbose=False):
     orig_rnn = utils.get_func(backend, fn)
     @functools.wraps(orig_rnn)
     def rnn_wrapper(*args, **kwargs):
@@ -203,7 +203,7 @@ def rnn_cast(backend, fn, verbose=False):
             return forward(*new_args, **fkwargs)
         return fwd_wrapper
-    utils.set_func(backend, fn, rnn_wrapper)
+    utils.set_func_save(handle, backend, fn, rnn_wrapper)
 
 def disable_casts(mod, fn, handle):
     if not utils.has_func(mod, fn):
@@ -214,4 +214,4 @@ def disable_casts(mod, fn, handle):
     def wrapper(*args, **kwargs):
         with handle._disable_casts():
             return orig_fn(*args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
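The rnn_cast comments above describe interposing on a factory: the thnn RNN entry point returns a forward function, so both the factory and the forward it produces get wrapped. A stripped-down illustration of that pattern with toy names (not the apex implementation, and without the flat_weight handling):

import functools

def wrap_factory(factory, transform_args):
    # Wrap a factory so the forward function it returns also transforms its inputs.
    @functools.wraps(factory)
    def factory_wrapper(*args, **kwargs):
        forward = factory(*args, **kwargs)
        @functools.wraps(forward)
        def fwd_wrapper(*fargs, **fkwargs):
            return forward(*transform_args(fargs), **fkwargs)
        return fwd_wrapper
    return factory_wrapper

# Tiny example: a "factory" whose forward sums its inputs; the wrapper
# converts every argument to float before the call.
make_forward = lambda: (lambda *xs: sum(xs))
wrapped = wrap_factory(make_forward, transform_args=lambda xs: [float(x) for x in xs])
assert wrapped()(1, 2) == 3.0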