Commit 81c788f0 authored by Carl Case

WIP: promotion tests

parent 2e69d933
@@ -8,58 +8,44 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-HALF = 'torch.cuda.HalfTensor'
-FLOAT = 'torch.cuda.FloatTensor'
-
-ALWAYS_HALF = {torch.float: HALF,
-               torch.half: HALF}
-ALWAYS_FLOAT = {torch.float: FLOAT,
-                torch.half: FLOAT}
-MATCH_INPUT = {torch.float: FLOAT,
-               torch.half: HALF}
-
-def _common_init(test_case):
-    test_case.h = 64
-    test_case.b = 16
-    test_case.c = 16
-    test_case.k = 3
-    torch.set_default_tensor_type(torch.cuda.FloatTensor)
+from .utils import common_init, HALF, FLOAT,\
+    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+
+def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
+    for fn, typ in it.product(fns, expected.keys()):
+        x = torch.randn(input_shape, dtype=typ).requires_grad_()
+        y = fn(x)
+        test_case.assertEqual(y.type(), expected[typ])
+        if test_backward:
+            y.float().sum().backward()
+            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
 
 class TestBasicCasts(unittest.TestCase):
     def setUp(self):
         self.handle = amp.init(enabled=True)
-        _common_init(self)
+        common_init(self)
 
     def tearDown(self):
         self.handle._deactivate()
 
-    def run_layer_test(self, fns, expected, input_shape, test_backward=True):
-        for fn, typ in it.product(fns, expected.keys()):
-            x = torch.randn(input_shape, dtype=typ).requires_grad_()
-            y = fn(x)
-            self.assertEqual(y.type(), expected[typ])
-            if test_backward:
-                y.float().sum().backward()
-                self.assertEqual(x.grad.type(), MATCH_INPUT[typ])
-
     def test_linear_is_half(self):
         m = nn.Linear(self.h, self.h)
         f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
-        self.run_layer_test([m, f], ALWAYS_HALF, (self.b, self.h))
+        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))
 
     def test_conv2d_is_half(self):
         m = nn.Conv2d(self.c, self.c, self.k)
         f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
-        self.run_layer_test([m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
 
     def test_softmax_is_float(self):
         m = nn.Softmax(dim=1)
         f = ft.partial(F.softmax, dim=1)
-        self.run_layer_test([m, f], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))
 
     def test_group_norm_is_float(self):
         m = nn.GroupNorm(num_groups=4, num_channels=self.c)
-        self.run_layer_test([m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
 
     def test_mse_loss_is_float(self):
         shape = (self.b, self.h)
@@ -67,27 +53,76 @@ class TestBasicCasts(unittest.TestCase):
         mod = nn.MSELoss()
         m = lambda x: mod(x, target)
         f = ft.partial(F.mse_loss, target=target)
-        self.run_layer_test([m], ALWAYS_FLOAT, shape)
+        run_layer_test(self, [m], ALWAYS_FLOAT, shape)
 
     def test_relu_is_match(self):
-        self.run_layer_test([nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
+        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
 
     def test_batch_norm_is_match(self):
         m = nn.BatchNorm2d(num_features=self.c)
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=True)
-        self.run_layer_test([m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
 
         # Test forward-only for BN inference
         m.eval()
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=False)
-        self.run_layer_test([m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h), test_backward=False)
+        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
+                       test_backward=False)
+
+    def test_bce_raises(self):
+        shape = (self.b, self.h)
+        target = torch.randn(shape)
+        mod = nn.BCELoss()
+        m = lambda x: mod(x, target)
+        f = ft.partial(F.binary_cross_entropy, target=target)
+        for fn in [m, f]:
+            x = torch.randn(shape, dtype=torch.half)
+            self.assertRaises(NotImplementedError, fn, x)
+
+class TestTensorCasts(unittest.TestCase):
+    def setUp(self):
+        self.handle = amp.init(enabled=True)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_matmul_method_is_half(self):
+        other = torch.randn(self.h, self.h)
+        lhs = lambda x: x.matmul(other)
+        rhs = lambda x: other.matmul(x)
+        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+
+    def test_matmul_op_is_half(self):
+        other = torch.randn(self.h, self.h)
+        lhs = lambda x: x @ other
+        rhs = lambda x: other @ x
+        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+
+    def test_pow_method_is_float(self):
+        fn = lambda x: x.pow(2.)
+        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+
+    def test_pow_op_is_float(self):
+        fn = lambda x: x ** 2.
+        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+
+    def test_cpu_is_float(self):
+        fn = lambda x: x.cpu()
+        always_cpu_float = {torch.float: 'torch.FloatTensor',
+                            torch.half: 'torch.FloatTensor'}
+        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))
+
+    def test_sum_is_float(self):
+        fn = lambda x: x.sum()
+        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
 
 class TestDisabledCasts(unittest.TestCase):
     def setUp(self):
         self.handle = amp.init(enabled=False)
-        _common_init(self)
+        common_init(self)
 
     def test_disabled_linear(self):
         m = nn.Linear(self.h, self.h)
...
import unittest

import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from .utils import common_init, HALF, FLOAT, DTYPES

class TestPromotion(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def run_binary_promote_test(self, fns, input_shape):
        type_pairs = it.product(DTYPES, DTYPES)
        for fn, (xtype, ytype) in it.product(fns, type_pairs):
            x = torch.randn(input_shape, dtype=xtype).requires_grad_()
            y = torch.randn(input_shape, dtype=ytype)
            out = fn(x, y)
            if xtype == torch.float or ytype == torch.float:
                self.assertEqual(out.type(), FLOAT)
            else:
                self.assertEqual(out.type(), HALF)
            out.float().sum().backward()
            self.assertEqual(x.grad.dtype, xtype)

    def test_atan2_matches_widest(self):
        fns = [lambda x, y: torch.atan2(x, y),
               lambda x, y: x.atan2(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_mul_matches_widest(self):
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_cat_matches_widest(self):
        shape = self.b
        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)

        x_half = torch.randn(shape, dtype=torch.half)
        out = torch.cat(ys + [x_half])
        self.assertEqual(out.type(), HALF)

# TODOs:
# In-place methods on fp16 are errors for fp32
# In-place methods match type of self tensor

if __name__ == '__main__':
    unittest.main()
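The two TODO items above describe in-place promotion behavior that this WIP commit does not yet cover. A minimal sketch of what such tests might look like, assuming the same amp semantics and test layout as the promotion tests above; the class name TestInplacePromotion and the exact exception raised in the error case are assumptions, not part of this commit:

# Hypothetical sketch only, not part of this commit.
import unittest
import torch
from apex import amp
from .utils import common_init, FLOAT

class TestInplacePromotion(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_inplace_matches_self_type(self):
        # In-place ops keep the type of the `self` tensor: an fp32 tensor
        # updated in place with an fp16 operand stays fp32.
        x_float = torch.randn(self.b)
        y_half = torch.randn(self.b, dtype=torch.half)
        x_float.mul_(y_half)
        self.assertEqual(x_float.type(), FLOAT)

    def test_inplace_fp16_with_fp32_raises(self):
        # Promotion would need an fp32 result, but the output buffer is fp16,
        # so this update is expected to be an error. Which exception amp
        # raises here is an assumption.
        x_half = torch.randn(self.b, dtype=torch.half)
        y_float = torch.randn(self.b)
        with self.assertRaises(Exception):
            x_half.mul_(y_float)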
import torch

HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'

DTYPES = [torch.half, torch.float]

ALWAYS_HALF = {torch.float: HALF,
               torch.half: HALF}
ALWAYS_FLOAT = {torch.float: FLOAT,
                torch.half: FLOAT}
MATCH_INPUT = {torch.float: FLOAT,
               torch.half: HALF}

def common_init(test_case):
    test_case.h = 64
    test_case.b = 16
    test_case.c = 16
    test_case.k = 3
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
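For context on why the expected type strings above are the full CUDA tensor types: common_init makes the CUDA float tensor the default, so an un-annotated torch.randn allocates a torch.cuda.FloatTensor, while passing dtype=torch.half yields a torch.cuda.HalfTensor. A minimal illustration, not part of this commit (requires a CUDA-capable machine):

# Illustration only, not part of this commit.
import torch

torch.set_default_tensor_type(torch.cuda.FloatTensor)
x = torch.randn(16)                    # default dtype/device -> CUDA float
y = torch.randn(16, dtype=torch.half)  # dtype override, still on the GPU
assert x.type() == 'torch.cuda.FloatTensor'
assert y.type() == 'torch.cuda.HalfTensor'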