Unverified Commit b2da92fc authored by Peng's avatar Peng Committed by GitHub
Browse files

Merge pull request #5 from rohithkrn/apex_amp_bfp16

Introduce new optimization levels for BFloat16 training
parents 65490af6 e1267a9a
......@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// 5. bfp16, bfp16, bfp16, No
// 6. bfp16, fp32, fp32, Yes
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
......@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
wd_after_momentum,
scale);
}
// Case 5. bfp16, bfp16, bfp16, No
if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::BFloat16 &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, at::BFloat16, at::BFloat16>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 6. bfp16, fp32, fp32, Yes
else if(grad_type == at::ScalarType::BFloat16 &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, at::BFloat16, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
......
......@@ -105,6 +105,66 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// TODO: We might have come up with an optimal set of dispatch macros by
// changing the signature to have an integer suffix of number of types
// to dispatch for as defined in upstream (e.g AT_DISPATCH_FLOATING_TYPES_AND2)
// Refactor once all the extension ops are enabled.
#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
......
......@@ -14,11 +14,11 @@ from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
class MyModel(torch.nn.Module):
def __init__(self, unique):
def __init__(self, unique, dtype=torch.float16):
super(MyModel, self).__init__()
self.weight0 = Parameter(unique +
torch.arange(2, device='cuda', dtype=torch.float32))
self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))
self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=dtype))
@staticmethod
def ops(input, weight0, weight1):
......@@ -51,11 +51,15 @@ class TestAddParamGroup(unittest.TestCase):
optimizer.zero_grad()
def test_add_param_group(self):
for opt_level in ("O0", "O1", "O2", "O3"):
for opt_level in ("O0", "O1", "O2", "O3", "O4", "O5"):
for zero_before_add in (True, False):
for try_accumulation in (True, False):
model0 = MyModel(1)
model1 = MyModel(2)
if opt_level in {"O4", "O5"}:
model0 = MyModel(1, torch.bfloat16)
model1 = MyModel(2, torch.bfloat16)
else:
model0 = MyModel(1)
model1 = MyModel(2)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
......@@ -89,8 +93,12 @@ class TestAddParamGroup(unittest.TestCase):
[param.data.clone() for param in model1.parameters()]
for how_to_zero in "none", "model", "optimizer":
model0 = MyModel(1)
model1 = MyModel(2)
if opt_level in {"O4", "O5"}:
model0 = MyModel(1, torch.bfloat16)
model1 = MyModel(2, torch.bfloat16)
else:
model0 = MyModel(1)
model1 = MyModel(2)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
......@@ -139,6 +147,9 @@ class TestAddParamGroup(unittest.TestCase):
[param.data.clone() for param in model1.parameters()]
for reference, final in zip(reference_params, final_params):
# TODO: remove the conversion once allclose supports bfloat16 type.
if final.dtype == torch.bfloat16:
final = final.float()
self.assertTrue(torch.allclose(reference.to(final.dtype), final),
"opt_level = {}, how_to_zero = {}, zero_before_add = {}".format(
opt_level, how_to_zero, zero_before_add))
......
......@@ -9,7 +9,7 @@ from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
for fn, typ in it.product(fns, expected.keys()):
......@@ -20,124 +20,233 @@ def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
y.float().sum().backward()
test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
class TestBasicCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_linear_is_half(self):
class _TestBasicCasts(unittest.TestCase):
def _test_linear(self, expected):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))
run_layer_test(self, [m, f], expected, (self.b, self.h))
def test_conv2d_is_half(self):
def _test_conv2d(self, expected):
m = nn.Conv2d(self.c, self.c, self.k)
f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h))
def test_softmax_is_float(self):
def _test_softmax(self, expected):
m = nn.Softmax(dim=1)
f = ft.partial(F.softmax, dim=1)
run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))
run_layer_test(self, [m, f], expected, (self.b, self.h))
def test_group_norm_is_float(self):
def _test_group_norm(self, expected):
m = nn.GroupNorm(num_groups=4, num_channels=self.c)
run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))
def test_mse_loss_is_float(self):
def _test_mse_loss(self, expected):
shape = (self.b, self.h)
target = torch.randn(shape)
mod = nn.MSELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.mse_loss, target=target)
run_layer_test(self, [m], ALWAYS_FLOAT, shape)
run_layer_test(self, [m], expected, shape)
def test_relu_is_match(self):
run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
def _test_relu(self, expected):
run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h))
def test_batch_norm_is_match(self):
def _test_batch_norm(self, expected):
m = nn.BatchNorm2d(num_features=self.c)
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=True)
run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))
# Test forward-only for BN inference
m.eval()
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=False)
run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h),
test_backward=False)
class TestBasicCastsHalf(_TestBasicCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_linear_is_half(self):
self._test_linear(ALWAYS_HALF)
def test_conv2d_is_half(self):
self._test_conv2d(ALWAYS_HALF)
def test_softmax_is_float(self):
self._test_softmax(ALWAYS_FLOAT)
def test_group_norm_is_float(self):
self._test_group_norm(ALWAYS_FLOAT)
def test_mse_loss_is_float(self):
self._test_mse_loss(ALWAYS_FLOAT)
def test_relu_is_match(self):
self._test_relu(MATCH_INPUT)
def test_batch_norm_is_match(self):
self._test_batch_norm(MATCH_INPUT)
class TestBasicCastsBFloat16(_TestBasicCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_linear_is_bfloat16(self):
self._test_linear(ALWAYS_BFLOAT16)
def test_conv2d_is_bfloat16(self):
self._test_conv2d(ALWAYS_BFLOAT16)
def test_softmax_is_float(self):
self._test_softmax(ALWAYS_FLOAT)
def test_group_norm_is_float(self):
self._test_group_norm(ALWAYS_FLOAT)
def test_mse_loss_is_float(self):
self._test_mse_loss(ALWAYS_FLOAT)
def test_relu_is_match(self):
self._test_relu(MATCH_INPUT)
def test_batch_norm_is_match(self):
self._test_batch_norm(MATCH_INPUT)
class TestBannedMethods(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def bce_common(self, assertion):
def bce_common(self, assertion, dtype=torch.half):
shape = (self.b, self.h)
target = torch.rand(shape)
mod = nn.BCELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.binary_cross_entropy, target=target)
for fn in [m, f]:
x = torch.rand(shape, dtype=torch.half)
x = torch.rand(shape, dtype=dtype)
assertion(fn, x)
def test_bce_raises_by_default(self):
assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
self.bce_common(assertion)
self.bce_common(assertion, dtype=torch.half)
# handle with bfloat16 as patch_type
self.handle._deactivate()
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
self.bce_common(assertion, dtype=torch.bfloat16)
def test_bce_is_float_with_allow_banned(self):
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True)
self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half)
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
self.bce_common(assertion)
self.bce_common(assertion, dtype=torch.half)
class TestTensorCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
# handle with bfloat16 as patch_type
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16)
self.bce_common(assertion, dtype=torch.bfloat16)
def test_matmul_method_is_half(self):
class _TestTensorCasts(unittest.TestCase):
def _test_matmul_method(self, expected):
other = torch.randn(self.h, self.h)
lhs = lambda x: x.matmul(other)
rhs = lambda x: other.matmul(x)
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))
def test_matmul_op_is_half(self):
def _test_matmul_op(self, expected):
other = torch.randn(self.h, self.h)
lhs = lambda x: x @ other
rhs = lambda x: other @ x
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))
def test_pow_method_is_float(self):
def _test_pow_method(self, expected):
fn = lambda x: x.pow(2.)
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
run_layer_test(self, [fn], expected, (self.b, self.h))
def test_pow_op_is_float(self):
def _test_pow_op(self, expected):
fn = lambda x: x ** 2.
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
run_layer_test(self, [fn], expected, (self.b, self.h))
def test_cpu_is_float(self):
def _test_cpu(self, expected):
fn = lambda x: x.cpu()
run_layer_test(self, [fn], expected, (self.b, self.h))
def _test_sum(self, expected):
fn = lambda x: x.sum()
run_layer_test(self, [fn], expected, (self.b, self.h))
# TODO: maybe more tests on disabled casting?
class TestTensorCastsHalf(_TestTensorCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_matmul_method_is_half(self):
self._test_matmul_method(ALWAYS_HALF)
def test_matmul_op_is_half(self):
self._test_matmul_op(ALWAYS_HALF)
def test_pow_method_is_float(self):
self._test_pow_method(ALWAYS_FLOAT)
def test_pow_op_is_float(self):
self._test_pow_op(ALWAYS_FLOAT)
def test_cpu_is_float(self):
always_cpu_float = {torch.float: 'torch.FloatTensor',
torch.half: 'torch.FloatTensor'}
run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))
self._test_cpu(always_cpu_float)
def test_sum_is_float(self):
fn = lambda x: x.sum()
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
self._test_sum(ALWAYS_FLOAT)
class TestTensorCastsBFloat16(_TestTensorCasts):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_matmul_method_is_bfloat16(self):
self._test_matmul_method(ALWAYS_BFLOAT16)
def test_matmul_op_is_bfloat16(self):
self._test_matmul_op(ALWAYS_BFLOAT16)
def test_pow_method_is_float(self):
self._test_pow_method(ALWAYS_FLOAT)
def test_pow_op_is_float(self):
self._test_pow_op(ALWAYS_FLOAT)
def test_cpu_is_float(self):
always_cpu_float = {torch.float: 'torch.FloatTensor',
torch.bfloat16: 'torch.FloatTensor'}
self._test_cpu(always_cpu_float)
def test_sum_is_float(self):
self._test_sum(ALWAYS_FLOAT)
# TODO: maybe more tests on disabled casting?
if __name__ == '__main__':
unittest.main()
......@@ -67,12 +67,12 @@ class TestCache(unittest.TestCase):
def tearDown(self):
pass
def train_eval_train_test(self, module, t):
def train_eval_train_test(self, module, t, opt_level):
model = module(t).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
_amp_state.allow_incoming_model_not_fp32 = True
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0)
_amp_state.allow_incoming_model_not_fp32 = False
def training_step():
......@@ -93,6 +93,8 @@ class TestCache(unittest.TestCase):
# but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
if model.weight.grad.type() == "torch.cuda.HalfTensor":
self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor":
self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
elif model.weight.grad.type() == "torch.cuda.FloatTensor":
self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
else:
......@@ -115,22 +117,41 @@ class TestCache(unittest.TestCase):
# I could easily have these as a set of for loops in a single test,
# instead of going for granularity.
def test_whitelist_module_fp16_weight(self):
self.train_eval_train_test(WhitelistModule, torch.float16)
self.train_eval_train_test(WhitelistModule, torch.float16, "O1")
def test_whitelist_module_fp32_weight(self):
self.train_eval_train_test(WhitelistModule, torch.float32)
self.train_eval_train_test(WhitelistModule, torch.float32, "O1")
def test_blacklist_module_fp16_weight(self):
self.train_eval_train_test(BlacklistModule, torch.float16)
self.train_eval_train_test(BlacklistModule, torch.float16, "O1")
def test_blacklist_module_fp32_weight(self):
self.train_eval_train_test(BlacklistModule, torch.float32)
self.train_eval_train_test(BlacklistModule, torch.float32, "O1")
def test_promote_module_fp16_weight(self):
self.train_eval_train_test(PromoteModule, torch.float16)
self.train_eval_train_test(PromoteModule, torch.float16, "O1")
def test_promote_module_fp32_weight(self):
self.train_eval_train_test(PromoteModule, torch.float32, "O1")
# opt_level = O4
def test_whitelist_module_bfp16_weight(self):
self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4")
def test_whitelist_module_fp32_weight(self):
self.train_eval_train_test(WhitelistModule, torch.float32, "O4")
def test_blacklist_module_bfp16_weight(self):
self.train_eval_train_test(BlacklistModule, torch.bfloat16, "O4")
def test_blacklist_module_fp32_weight(self):
self.train_eval_train_test(BlacklistModule, torch.float32, "O4")
def test_promote_module_bfp16_weight(self):
self.train_eval_train_test(PromoteModule, torch.bfloat16, "O4")
def test_promote_module_fp32_weight(self):
self.train_eval_train_test(PromoteModule, torch.float32)
self.train_eval_train_test(PromoteModule, torch.float32, "O4")
if __name__ == '__main__':
......
......@@ -28,7 +28,7 @@ class MyModel(torch.nn.Module):
class TestCheckpointing(unittest.TestCase):
def setUp(self):
self.initial_lr = 1e-3
self.test_opt_levels = ("O0", "O1", "O2", "O3")
self.test_opt_levels = ("O0", "O1", "O2", "O3", "O4", "O5")
def seed(self):
torch.manual_seed(2809)
......@@ -236,6 +236,7 @@ class TestCheckpointing(unittest.TestCase):
state_dict = model.state_dict()
for key in state_dict:
self.assertFalse('Half' in state_dict[key].type())
self.assertFalse('BFloat16' in state_dict[key].type())
# Check, if model is still trainable
# Create dummy data
......
......@@ -69,7 +69,10 @@ class TestMultiTensorAxpby(unittest.TestCase):
applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)
self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
# TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
if out_type == torch.bfloat16:
out_list = [out.float() for out in out_list]
self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]),
msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
x_type, y_type, out_type, inplace))
self.assertTrue(self.overflow_buf.item() == 0,
......@@ -119,9 +122,9 @@ class TestMultiTensorAxpby(unittest.TestCase):
for sizea, sizeb in input_size_pairs:
for applier in appliers:
for repeat in repeat_tensors:
for x_type in (torch.float32, torch.float16):
for y_type in (torch.float32, torch.float16):
for out_type in (torch.float32, torch.float16):
for x_type in (torch.float32, torch.float16, torch.bfloat16):
for y_type in (torch.float32, torch.float16, torch.bfloat16):
for out_type in (torch.float32, torch.float16, torch.bfloat16):
for inplace in (True, False):
if inplace is True and (y_type is not out_type):
continue
......
......@@ -49,7 +49,10 @@ class TestMultiTensorScale(unittest.TestCase):
applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
# TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
if out_type == torch.bfloat16:
out_list = [out.float() for out in out_list]
self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]))
self.assertTrue(self.overflow_buf.item() == 0)
def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
......@@ -106,8 +109,8 @@ class TestMultiTensorScale(unittest.TestCase):
for sizea, sizeb in input_size_pairs:
for applier in appliers:
for repeat in repeat_tensors:
for in_type in (torch.float32, torch.float16):
for out_type in (torch.float32, torch.float16):
for in_type in (torch.float32, torch.float16, torch.bfloat16):
for out_type in (torch.float32, torch.float16, torch.bfloat16):
for inplace in (True, False):
if inplace is True and (out_type is not in_type):
continue
......
......@@ -7,18 +7,18 @@ import torch
from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT, DTYPES
from utils import common_init, HALF, FLOAT, DTYPES, DTYPES2, MATCH_INPUT
class TestPromotion(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
type_pairs = it.product(DTYPES, DTYPES)
class _TestPromotion(unittest.TestCase):
def run_binary_promote_test(self, fns, input_shape, lp_type, x_inplace=False):
if lp_type == torch.half:
dtypes = DTYPES
elif lp_type == torch.bfloat16:
dtypes = DTYPES2
else:
raise RuntimeError("Creating test class with invalid low_precision type. \
Supported types are torch.half and torch.bfloat16")
type_pairs = it.product(dtypes, dtypes)
for fn, (xtype, ytype) in it.product(fns, type_pairs):
x = torch.randn(input_shape, dtype=xtype).requires_grad_()
x_leaf = x
......@@ -35,41 +35,78 @@ class TestPromotion(unittest.TestCase):
if xtype == torch.float or ytype == torch.float:
self.assertEqual(out.type(), FLOAT)
else:
self.assertEqual(out.type(), HALF)
self.assertEqual(out.type(), MATCH_INPUT[lp_type])
out.float().sum().backward()
self.assertEqual(x_leaf.grad.dtype, xtype)
def _test_cat_matches_widest(self, lp_type):
shape = self.b
ys = [torch.randn(shape, dtype=lp_type) for _ in range(5)]
x_float = torch.randn(shape)
out = torch.cat(ys + [x_float])
self.assertEqual(out.type(), FLOAT)
x_lp = torch.randn(shape, dtype=lp_type)
out = torch.cat(ys + [x_lp])
self.assertEqual(out.type(), MATCH_INPUT[lp_type])
def _test_inplace_exp_is_error_for_lp(self, lp_type):
xs = torch.randn(self.b)
xs.exp_()
self.assertEqual(xs.type(), FLOAT)
xs = torch.randn(self.b, dtype=lp_type)
with self.assertRaises(NotImplementedError):
xs.exp_()
class TestPromotionHalf(_TestPromotion):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.half)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_atan2_matches_widest(self):
fns = [lambda x, y : torch.atan2(x, y),
lambda x, y : x.atan2(y)]
self.run_binary_promote_test(fns, (self.b,))
self.run_binary_promote_test(fns, (self.b,), torch.half)
def test_mul_matches_widest(self):
fns = [lambda x, y : torch.mul(x, y),
lambda x, y: x.mul(y)]
self.run_binary_promote_test(fns, (self.b,))
self.run_binary_promote_test(fns, (self.b,), torch.half)
def test_cat_matches_widest(self):
shape = self.b
ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
x_float = torch.randn(shape)
out = torch.cat(ys + [x_float])
self.assertEqual(out.type(), FLOAT)
x_half = torch.randn(shape, dtype=torch.half)
out = torch.cat(ys + [x_half])
self.assertEqual(out.type(), HALF)
self._test_cat_matches_widest(torch.half)
def test_inplace_exp_is_error_for_half(self):
xs = torch.randn(self.b)
xs.exp_()
self.assertEqual(xs.type(), FLOAT)
xs = torch.randn(self.b, dtype=torch.half)
with self.assertRaises(NotImplementedError):
xs.exp_()
self._test_inplace_exp_is_error_for_lp(torch.half)
def test_inplace_add_matches_self(self):
fn = lambda x, y: x.add_(y)
self.run_binary_promote_test([fn], (self.b,), torch.half, x_inplace=True)
class TestPromotionBFloat16(_TestPromotion):
def setUp(self):
self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_mul_matches_widest(self):
fns = [lambda x, y : torch.mul(x, y),
lambda x, y: x.mul(y)]
self.run_binary_promote_test(fns, (self.b,), torch.bfloat16)
def test_cat_matches_widest(self):
self._test_cat_matches_widest(torch.bfloat16)
def test_inplace_exp_is_error_for_bfloat16(self):
self._test_inplace_exp_is_error_for_lp(torch.bfloat16)
def test_inplace_add_matches_self(self):
fn = lambda x, y: x.add_(y)
self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
self.run_binary_promote_test([fn], (self.b,), torch.bfloat16, x_inplace=True)
if __name__ == '__main__':
unittest.main()
......@@ -2,15 +2,21 @@ import torch
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'
BFLOAT16 = 'torch.cuda.BFloat16Tensor'
DTYPES = [torch.half, torch.float]
DTYPES2 = [torch.bfloat16, torch.float]
ALWAYS_HALF = {torch.float: HALF,
torch.half: HALF}
ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16,
torch.float: BFLOAT16}
ALWAYS_FLOAT = {torch.float: FLOAT,
torch.half: FLOAT}
MATCH_INPUT = {torch.float: FLOAT,
torch.half: HALF}
torch.half: HALF,
torch.bfloat16: BFLOAT16}
def common_init(test_case):
test_case.h = 64
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment