Commit 22920fe0 authored by Carl Case's avatar Carl Case
Browse files

add more promotion testing

parent 81c788f0
...@@ -17,18 +17,27 @@ class TestPromotion(unittest.TestCase): ...@@ -17,18 +17,27 @@ class TestPromotion(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.handle._deactivate() self.handle._deactivate()
def run_binary_promote_test(self, fns, input_shape): def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
type_pairs = it.product(DTYPES, DTYPES) type_pairs = it.product(DTYPES, DTYPES)
for fn, (xtype, ytype) in it.product(fns, type_pairs): for fn, (xtype, ytype) in it.product(fns, type_pairs):
x = torch.randn(input_shape, dtype=xtype).requires_grad_() x = torch.randn(input_shape, dtype=xtype).requires_grad_()
x_leaf = x
if x_inplace:
# We need a non-leaf to call in place on
x = x.clone()
y = torch.randn(input_shape, dtype=ytype) y = torch.randn(input_shape, dtype=ytype)
out = fn(x, y) out = fn(x, y)
if xtype == torch.float or ytype == torch.float: if x_inplace:
self.assertEqual(out.type(), FLOAT) # In place: always match xtype
self.assertEqual(out.type(), x.type())
else: else:
self.assertEqual(out.type(), HALF) # Out of place: match widest type
if xtype == torch.float or ytype == torch.float:
self.assertEqual(out.type(), FLOAT)
else:
self.assertEqual(out.type(), HALF)
out.float().sum().backward() out.float().sum().backward()
self.assertEqual(x.grad.dtype, xtype) self.assertEqual(x_leaf.grad.dtype, xtype)
def test_atan2_matches_widest(self): def test_atan2_matches_widest(self):
fns = [lambda x, y : torch.atan2(x, y), fns = [lambda x, y : torch.atan2(x, y),
...@@ -50,9 +59,17 @@ class TestPromotion(unittest.TestCase): ...@@ -50,9 +59,17 @@ class TestPromotion(unittest.TestCase):
out = torch.cat(ys + [x_half]) out = torch.cat(ys + [x_half])
self.assertEqual(out.type(), HALF) self.assertEqual(out.type(), HALF)
# TODOs: def test_inplace_exp_is_error_for_half(self):
# In-place methods on fp16 are errors for fp32 xs = torch.randn(self.b)
# In-place methods match type of self tensor xs.exp_()
self.assertEqual(xs.type(), FLOAT)
xs = torch.randn(self.b, dtype=torch.half)
with self.assertRaises(NotImplementedError):
xs.exp_()
def test_inplace_add_matches_self(self):
fn = lambda x, y: x.add_(y)
self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment