"...text-generation-inference.git" did not exist on "04d4765bad5707458955189fbf39e8b485de5cbd"
Unverified commit 59d2f7ac authored by Sudhakar Singh, committed by GitHub

Add unit tests for Fused NovoGrad (#1065)

* Add unit tests for fused-novograd

* Fix: tensors should reside on the same device

* Fix: the CUDA stream should be obtained on the same device on which the tensors reside. Found this while debugging the fused NovoGrad multi-device unit test

* Fix issues mentioned in the review comments
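
For context, here is a minimal sketch (not part of this commit, assuming a machine with at least two CUDA devices) of the pitfall these fixes address: legacy constructors such as torch.cuda.IntTensor and the current CUDA stream are both tied to the current device, which may differ from the device that holds the optimizer's parameters.

import torch

# Sketch only; none of this is apex code. Assumes >= 2 CUDA devices.
params = [torch.nn.Parameter(torch.rand(4, device="cuda:1"))]

with torch.cuda.device(0):
    # Legacy constructors allocate on the current device (cuda:0 here),
    # not on the device that holds the parameters (cuda:1).
    buf_wrong = torch.cuda.IntTensor([0])
    # Pinning the buffer to the params' device avoids the mismatch.
    buf_right = torch.tensor([0], dtype=torch.int, device=params[0].device)

    # Streams are also per-device; kernels that touch cuda:1 tensors should
    # run on a cuda:1 stream (the C++ side uses at::cuda::OptionalCUDAGuard
    # for the same reason).
    stream_wrong = torch.cuda.current_stream()                   # cuda:0 stream
    stream_right = torch.cuda.current_stream(params[0].device)   # cuda:1 stream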
parent a651e2c2
@@ -79,7 +79,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
         if multi_tensor_applier.available:
             import amp_C
             # Skip buffer
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            # Creating the overflow buffer on the same device as the params tensors.
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
             self.multi_tensor_novograd = amp_C.multi_tensor_novograd
         else:
             raise RuntimeError('apex.optimizers.FusedNovoGrad requires cuda extensions')
@@ -158,8 +160,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
             if 'exp_avg_sq' not in group:
                 group['exp_avg_sq'] = [None, None]
                 if group['init_zero']:
-                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16)).contiguous().fill_(0)
-                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32)).contiguous().fill_(0)
+                    # Creating the following parameters on the same device as the params tensors.
+                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
+                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0)
                 else: # init with first step norm, so first blend have no effect
                     if group['norm_type'] == 0:
                         v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]
@@ -169,8 +172,9 @@ class FusedNovoGrad(torch.optim.Optimizer):
                         v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]
                     else:
                         raise RuntimeError('FusedNovoGrad only support l2/inf norm now.')
-                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16)
-                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32)
+                    # Creating the following parameters on the same device as the params tensors.
+                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16, device=self.param_groups[0]["params"][0].device)
+                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32, device=self.param_groups[0]["params"][0].device)
             else:
                 assert(len(g_16) == group['exp_avg_sq'][0].numel())
                 assert(len(g_32) == group['exp_avg_sq'][1].numel())
...
@@ -427,6 +427,11 @@ void multi_tensor_norm_out_cuda(
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
   auto ret = at::empty({1}, output.options());
+  // Adding the following device guard since it happens sometimes that the
+  // tensors are on one device and the cuda stream is on another device, which
+  // results in an ILLEGAL MEM ACCESS error.
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
   auto stream = at::cuda::getCurrentCUDAStream();
   cleanup_v2<<<ntensors, 512, 0, stream>>>(
     output.DATA_PTR<float>(),
...
import torch
from torch.optim import Optimizer
import math
import apex
import unittest
from test_fused_optimizer import TestFusedOptimizer
from itertools import product
class Novograd(Optimizer):
"""
Implements Novograd algorithm.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.95, 0))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging: gradient averaging
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
"""
def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
weight_decay=0, grad_averaging=False, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
amsgrad=amsgrad)
super(Novograd, self).__init__(params, defaults)
def __setstate__(self, state):
super(Novograd, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Sparse gradients are not supported.')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
norm = torch.sum(torch.pow(grad, 2))
if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
grad.div_(denom)
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)
p.data.add_(exp_avg, alpha=-group['lr'])
return loss
class TestFusedNovoGrad(TestFusedOptimizer):
def __init__(self, *args, **kwargs):
super(TestFusedNovoGrad, self).__init__(*args, **kwargs)
# The options for NovoGrad and FusedNovoGrad are very specific if they
# are expected to behave the same.
self.options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
'weight_decay':0, 'grad_averaging':False, 'amsgrad':False}
self.tst_options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
'weight_decay':0, 'grad_averaging':False, 'amsgrad':False,
'bias_correction':False, 'reg_inside_moment':True,
'norm_type':2, 'init_zero':False, 'set_grad_none':True}
self.ref_optim = Novograd
self.fused_optim = apex.optimizers.FusedNovoGrad
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:1", "cuda:0")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
torch.cuda.synchronize()
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
tensors, self.options, self.tst_options
)
for _ in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
unittest.main()
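
A hedged example of running just the new multi-device case programmatically (the module name below assumes the file is saved as test_fused_novograd.py next to test_fused_optimizer.py, as its import suggests, and that at least two GPUs are available):

import unittest

# Hypothetical runner; "test_fused_novograd" is an assumed module name.
from test_fused_novograd import TestFusedNovoGrad

suite = unittest.TestSuite([TestFusedNovoGrad("test_multi_device")])
unittest.TextTestRunner(verbosity=2).run(suite)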
@@ -2,9 +2,11 @@ import unittest
 import os
 import random
+import math
 import torch
 import apex
 from itertools import product
+from torch.optim import Optimizer

 class TestFusedOptimizer(unittest.TestCase):
     def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
@@ -16,7 +18,14 @@ class TestFusedOptimizer(unittest.TestCase):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, options):
+    def gen_param_optim(self, tensors, options, tst_options=None):
+        # Adding this to make backward compatible with existing tests. Just in
+        # case "tst_options" are not provided, it gets a copy of options
+        # which contains the parameters for the reference optimizer
+        if tst_options == None:
+            tst_options = options
+
         ref_param = []
         tst_param = []
         for tensor in tensors:
@@ -24,7 +33,7 @@ class TestFusedOptimizer(unittest.TestCase):
             tst_param.append(torch.nn.Parameter(tensor.clone()))

         ref_optim = self.ref_optim(ref_param, **options)
-        tst_optim = self.fused_optim(tst_param, **options)
+        tst_optim = self.fused_optim(tst_param, **tst_options)

         return (ref_param, tst_param, ref_optim, tst_optim)
@@ -54,9 +63,18 @@ class TestFusedOptimizer(unittest.TestCase):
     def gen_single_type_test(self, param_type=torch.float, device='cuda'):
         nelem = 278011

+        # Some ref and test optimizers may require different set of options.
+        # This is a quick workaround to add that functionality while making
+        # minimum changes in existing code.
+        # If there is no "tst_options" field provided, safe to initialize
+        # the test optimizer with the parameters of reference optimizer.
+        if not hasattr(self, 'tst_options'):
+            self.tst_options = self.options
+
         tensor = torch.rand(nelem, dtype=param_type, device=device)
         ref_param, tst_param, ref_optim, tst_optim = \
-            self.gen_param_optim([tensor], self.options)
+            self.gen_param_optim([tensor], self.options, self.tst_options)

         for i in range(self.iters):
             self.gen_grad(ref_param, tst_param)
@@ -89,7 +107,6 @@ class TestFusedAdam(TestFusedOptimizer):
             with torch.cuda.device(current_dev):
                 self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

-
     @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
@@ -263,8 +280,5 @@ class TestFusedSGD(TestFusedOptimizer):
             with torch.cuda.device(current_dev):
                 self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

 if __name__ == '__main__':
     unittest.main()