Commit 85549903 authored by rohithkrn

enable bfloat16 for optimizers

parent 5cfdc014
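For context, a minimal usage sketch of what this commit enables: the apex fused optimizers now accept bfloat16 parameters alongside fp16 and fp32. The example below is not part of the commit; it assumes a CUDA GPU with bfloat16 support, apex built with its fused CUDA extensions, and an illustrative model and learning rate.

```python
import torch
import apex

# Hypothetical model; any module with bfloat16 parameters on a CUDA device would do.
model = torch.nn.Linear(1024, 1024).cuda().to(torch.bfloat16)
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=5e-4)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().sum()   # toy loss, cast up to float32 for a stable scalar
loss.backward()                 # gradients are produced in bfloat16
optimizer.step()                # now routed through the fused 16-bit path
```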
@@ -91,7 +91,7 @@ class FusedAdagrad(torch.optim.Optimizer):
                 if len(state) == 0:
                     # Exponential moving average of gradient values
                     state['sum'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     h_16.append(state['sum'])
@@ -100,7 +100,7 @@ class FusedAdagrad(torch.optim.Optimizer):
                     p_32.append(p.data)
                     h_32.append(state['sum'])
                 else:
-                    raise RuntimeError('FusedAdagrad only support fp16 and fp32.')
+                    raise RuntimeError('FusedAdagrad only support fp16, bfloat16 and fp32.')
             if(len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_adagrad,
...
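The two hunks above extend the existing dtype bucketing in FusedAdagrad.step(): parameters are split into a 16-bit group (previously fp16 only, now fp16 or bfloat16) and an fp32 group, and each non-empty group is handed to the fused multi-tensor kernel in a single call. A condensed, stand-alone sketch of that pattern follows; the helper name and loop structure are illustrative, not the actual apex implementation.

```python
import torch

def bucket_params_by_dtype(params, state_sums):
    # Illustrative helper, not apex code: mirrors the grouping done in FusedAdagrad.step().
    g_16, p_16, h_16 = [], [], []   # fp16 and, after this commit, bfloat16
    g_32, p_32, h_32 = [], [], []   # fp32
    for p, h in zip(params, state_sums):
        if p.dtype in {torch.float16, torch.bfloat16}:
            g_16.append(p.grad.data); p_16.append(p.data); h_16.append(h)
        elif p.dtype == torch.float32:
            g_32.append(p.grad.data); p_32.append(p.data); h_32.append(h)
        else:
            raise RuntimeError('FusedAdagrad only support fp16, bfloat16 and fp32.')
    # Each non-empty bucket is then passed to multi_tensor_applier in one fused call.
    return (g_16, p_16, h_16), (g_32, p_32, h_32)
```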
@@ -130,7 +130,7 @@ class FusedAdam(torch.optim.Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -141,7 +141,7 @@ class FusedAdam(torch.optim.Optimizer):
                     m_32.append(state['exp_avg'])
                     v_32.append(state['exp_avg_sq'])
                 else:
-                    raise RuntimeError('FusedAdam only support fp16 and fp32.')
+                    raise RuntimeError('FusedAdam only support fp16, bfloat16 and fp32.')
             if(len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_adam,
...
@@ -130,7 +130,7 @@ class FusedLAMB(torch.optim.Optimizer):
                     # Exponential moving average of gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -141,7 +141,7 @@ class FusedLAMB(torch.optim.Optimizer):
                     m_32.append(state['exp_avg'])
                     v_32.append(state['exp_avg_sq'])
                 else:
-                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+                    raise RuntimeError('FusedLAMB only support fp16, bfloat16 and fp32.')
             if(len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
...
@@ -142,7 +142,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
                     # Exponential moving average of gradient values
                     state['exp_avg'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -151,7 +151,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
                     p_32.append(p.data)
                     m_32.append(state['exp_avg'])
                 else:
-                    raise RuntimeError('FusedNovoGrad only support fp16 and fp32.')
+                    raise RuntimeError('FusedNovoGrad only support fp16, bfloat16 and fp32.')
             # we store per weight norm as one tensor for one group/precision combination
             # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types
...
@@ -690,7 +690,7 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
+    DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "layer_norm_cuda_kernel",
       using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostApplyLayerNorm(
         output->DATA_PTR<scalar_t_0>(),
@@ -793,7 +793,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "cuComputeGradInput",
       using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostLayerNormGradient(
         dout->DATA_PTR<scalar_t_0>(),
...
@@ -90,7 +90,7 @@ void multi_tensor_adagrad_cuda(
   using namespace at;
   // Assume single type across p,g,h now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "adagrad",
       multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             AdagradFunctor<scalar_t_0>(), epsilon, lr,
...
@@ -14,22 +14,28 @@ class TestFusedAdagrad(unittest.TestCase):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, adagrad_option):
+    def gen_param_optim(self, tensors, adagrad_option, apex_only=False):
         ref_param = []
         tst_param = []
         for tensor in tensors:
-            ref_param.append(torch.nn.Parameter(tensor.clone()))
+            if apex_only:
+                ref_param.append(torch.nn.Parameter(tensor.clone().float()))
+            else:
+                ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))

-        ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option)
+        if apex_only:
+            ref_optim = apex.optimizers.FusedAdagrad(ref_param, **adagrad_option)
+        else:
+            ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option)
         tst_optim = apex.optimizers.FusedAdagrad(tst_param, **adagrad_option)

         return (ref_param, tst_param, ref_optim, tst_optim)

-    def gen_grad(self, ref_param, tst_param):
+    def gen_grad(self, ref_param, tst_param, apex_only=False):
         for p_ref, p_tst in zip(ref_param, tst_param):
-            p_ref.grad = torch.rand_like(p_ref)
-            p_tst.grad = p_ref.grad
+            p_tst.grad = torch.rand_like(p_tst)
+            p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad

     def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
         half_grads = []
@@ -38,9 +44,11 @@ class TestFusedAdagrad(unittest.TestCase):
             p_ref.grad = half_grads[-1].float() / scale
         return half_grads

-    def get_max_diff(self, ref_param, tst_param):
+    def get_max_diff(self, ref_param, tst_param, apex_only=False):
         max_abs_diff = max_rel_diff = 0
         for p_ref, p_tst in zip(ref_param, tst_param):
+            if apex_only:
+                p_tst = p_tst.float()
             max_abs_diff_p = (p_ref - p_tst).abs().max().item()
             max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
@@ -51,22 +59,23 @@ class TestFusedAdagrad(unittest.TestCase):
         return max_abs_diff, max_rel_diff

-    def gen_single_type_test(self, param_type=torch.float):
+    def gen_single_type_test(self, param_type=torch.float, apex_only=False):
         nelem = 278011
         adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}

         tensor = torch.rand(nelem, dtype=param_type, device="cuda")
         ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
-            [tensor], adagrad_option
+            [tensor], adagrad_option, apex_only=apex_only
         )

         for _ in range(self.iters):
-            self.gen_grad(ref_param, tst_param)
+            self.gen_grad(ref_param, tst_param, apex_only=apex_only)
             ref_optim.step()
             tst_optim.step()
-            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only)
             self.assertLessEqual(max_abs_diff, self.max_abs_diff)
-            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+            if not apex_only:
+                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

     def test_float(self):
@@ -76,6 +85,14 @@ class TestFusedAdagrad(unittest.TestCase):
     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)

+    # Compares bfloat16 computation against float32 as gold standard.
+    # Uses apex optimizers(controlled by apex_only flag) for both types.
+    # Doesn't use upstream optimizer like other tests as they seem to be
+    # numerically unstable for half types(see skip note for test above).
+    def test_bfloat16(self):
+        self.max_abs_diff = 1e-2
+        self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)
+
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
         adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}
...
@@ -15,22 +15,28 @@ class TestFusedAdam(unittest.TestCase):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, adam_option):
+    def gen_param_optim(self, tensors, adam_option, apex_only=False):
         ref_param = []
         tst_param = []
         for tensor in tensors:
-            ref_param.append(torch.nn.Parameter(tensor.clone()))
+            if apex_only:
+                ref_param.append(torch.nn.Parameter(tensor.clone().float()))
+            else:
+                ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))

-        ref_optim = torch.optim.Adam(ref_param, **adam_option)
+        if apex_only:
+            ref_optim = apex.optimizers.FusedAdam(ref_param, **adam_option)
+        else:
+            ref_optim = torch.optim.Adam(ref_param, **adam_option)
         tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)

         return (ref_param, tst_param, ref_optim, tst_optim)

-    def gen_grad(self, ref_param, tst_param):
+    def gen_grad(self, ref_param, tst_param, apex_only=False):
         for p_ref, p_tst in zip(ref_param, tst_param):
-            p_ref.grad = torch.rand_like(p_ref)
-            p_tst.grad = p_ref.grad
+            p_tst.grad = torch.rand_like(p_tst)
+            p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad

     def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
         half_grads = []
@@ -39,9 +45,11 @@ class TestFusedAdam(unittest.TestCase):
             p_ref.grad = half_grads[-1].float() / scale
         return half_grads

-    def get_max_diff(self, ref_param, tst_param):
+    def get_max_diff(self, ref_param, tst_param, apex_only=False):
         max_abs_diff = max_rel_diff = 0
         for p_ref, p_tst in zip(ref_param, tst_param):
+            if apex_only:
+                p_tst = p_tst.float()
             max_abs_diff_p = (p_ref - p_tst).abs().max().item()
             max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
@@ -50,22 +58,23 @@ class TestFusedAdam(unittest.TestCase):
         return max_abs_diff, max_rel_diff

-    def gen_single_type_test(self, param_type=torch.float):
+    def gen_single_type_test(self, param_type=torch.float, apex_only=False):
         nelem = 278011
         adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
                        'weight_decay':0, 'amsgrad':False}

         tensor = torch.rand(nelem, dtype=param_type, device='cuda')
         ref_param, tst_param, ref_optim, tst_optim = \
-            self.gen_param_optim([tensor], adam_option)
+            self.gen_param_optim([tensor], adam_option, apex_only=apex_only)

         for i in range(self.iters):
-            self.gen_grad(ref_param, tst_param)
+            self.gen_grad(ref_param, tst_param, apex_only=apex_only)
             ref_optim.step()
             tst_optim.step()
-            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only)
             self.assertLessEqual(max_abs_diff, self.max_abs_diff)
-            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+            if not apex_only:
+                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

     def test_float(self):
@@ -74,6 +83,14 @@ class TestFusedAdam(unittest.TestCase):
     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)

+    # Compares bfloat16 computation against float32 as gold standard.
+    # Uses apex optimizers(controlled by apex_only flag) for both types.
+    # Doesn't use upstream optimizer like other tests as they seem to be
+    # numerically unstable for half types
+    def test_bfloat16(self):
+        self.max_abs_diff = 1e-2
+        self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)
+
     @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
...
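To summarize the test changes, here is a condensed, stand-alone sketch of the new apex_only bfloat16 flow: the float32 reference copy is driven by the same apex fused optimizer, gradients are generated on the bfloat16 side and mirrored to the reference as float32, and only the absolute-difference bound is asserted. It assumes a CUDA device and an apex build with the fused optimizers; the iteration count is illustrative.

```python
import torch
import apex

adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}
tensor = torch.rand(278011, dtype=torch.bfloat16, device="cuda")

ref_p = torch.nn.Parameter(tensor.clone().float())   # float32 gold standard
tst_p = torch.nn.Parameter(tensor.clone())           # bfloat16 under test
ref_opt = apex.optimizers.FusedAdagrad([ref_p], **adagrad_option)
tst_opt = apex.optimizers.FusedAdagrad([tst_p], **adagrad_option)

for _ in range(7):  # illustrative iteration count
    tst_p.grad = torch.rand_like(tst_p)
    ref_p.grad = tst_p.grad.detach().float()
    ref_opt.step()
    tst_opt.step()
    # Only the absolute bound is checked for bf16; the relative bound is skipped.
    assert (ref_p - tst_p.float()).abs().max().item() <= 1e-2
```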