Commit 39928327 authored by Carl Case's avatar Carl Case
Browse files

tests on banned methods

parent 22920fe0
...@@ -153,9 +153,13 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False): ...@@ -153,9 +153,13 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type) mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
wrap.disable_casts(mod, 'backward', handle) wrap.disable_casts(mod, 'backward', handle)
# 6) Place error+print message on banned functions # 6) Place error+print message on banned functions.
if not allow_banned: # Or, if allow_banned, then cast to FP32.
for fn, err_msg in functional_overrides.BANNED_FUNCS: for fn, err_msg in functional_overrides.BANNED_FUNCS:
if allow_banned:
wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
handle, try_caching=True, verbose=verbose)
else:
wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg) wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
_DECORATOR_HANDLE = handle _DECORATOR_HANDLE = handle
......
...@@ -71,15 +71,33 @@ class TestBasicCasts(unittest.TestCase): ...@@ -71,15 +71,33 @@ class TestBasicCasts(unittest.TestCase):
run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h), run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
test_backward=False) test_backward=False)
def test_bce_raises(self): class TestBannedMethods(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def bce_common(self, assertion):
shape = (self.b, self.h) shape = (self.b, self.h)
target = torch.randn(shape) target = torch.rand(shape)
mod = nn.BCELoss() mod = nn.BCELoss()
m = lambda x: mod(x, target) m = lambda x: mod(x, target)
f = ft.partial(F.binary_cross_entropy, target=target) f = ft.partial(F.binary_cross_entropy, target=target)
for fn in [m, f]: for fn in [m, f]:
x = torch.randn(shape, dtype=torch.half) x = torch.rand(shape, dtype=torch.half)
self.assertRaises(NotImplementedError, fn, x) assertion(fn, x)
def test_bce_raises_by_default(self):
assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
self.bce_common(assertion)
def test_bce_is_float_with_allow_banned(self):
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True)
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
self.bce_common(assertion)
class TestTensorCasts(unittest.TestCase): class TestTensorCasts(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment