Commit 37a1c121 authored by Deyu Fu

add multi-precision support for novograd, clean import

parent 8599b854
 import torch
 from apex.multi_tensor_apply import multi_tensor_applier
-from amp_C import multi_tensor_novograd

 class NovoGrad(torch.optim.Optimizer):
@@ -54,8 +53,15 @@ class NovoGrad(torch.optim.Optimizer):
                         grad_averaging=grad_averaging, norm_type=norm_type,
                         init_zero=init_zero)
         super(NovoGrad, self).__init__(params, defaults)
+        if multi_tensor_applier.available:
+            import amp_C
+            # Skip buffer
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_novograd = amp_C.multi_tensor_novograd
+        else:
+            raise RuntimeError('apex.optimizers.NovoGrad requires cuda extensions')
+
         self.moment_mode = 0 if reg_inside_moment else 1
-        self.dummy_overflow_buf = torch.cuda.IntTensor([0])
         self.set_grad_none = set_grad_none

     def zero_grad(self):
@@ -90,7 +96,8 @@ class NovoGrad(torch.optim.Optimizer):
                 group['step'] = 1

             # create lists for multi-tensor apply
-            p_list, g_list, m1_list = [], [], []
+            g_16, p_16, m_16 = [], [], []
+            g_32, p_32, m_32 = [], [], []

             for p in group['params']:
                 if p.grad is None:
@@ -104,39 +111,69 @@ class NovoGrad(torch.optim.Optimizer):
                     # Exponential moving average of gradient values
                     state['exp_avg'] = torch.zeros_like(p.data)

-                p_list.append(p.data)
-                g_list.append(p.grad.data)
-                m1_list.append(state['exp_avg'])
+                if p.dtype == torch.float16:
+                    g_16.append(p.grad.data)
+                    p_16.append(p.data)
+                    m_16.append(state['exp_avg'])
+                elif p.dtype == torch.float32:
+                    g_32.append(p.grad.data)
+                    p_32.append(p.data)
+                    m_32.append(state['exp_avg'])
+                else:
+                    raise RuntimeError('NovoGrad only support fp16 and fp32.')

-            # we will store per weight norm as one tensor for a group
-            # different rom optim.Adam, we store norm here(not ^2) so we can unify 2 norm type
+            # we store per weight norm as one tensor for one group/precision combination
+            # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types
            if 'exp_avg_sq' not in group:
+                group['exp_avg_sq'] = [None, None]
                if group['init_zero']:
-                    group['exp_avg_sq'] = torch.cuda.FloatTensor(len(g_list)).contiguous().fill_(0)
+                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16)).contiguous().fill_(0)
+                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32)).contiguous().fill_(0)
                else: # init with first step norm, so first blend have no effect
                    if group['norm_type'] == 0:
-                        m2 = [torch.max(torch.abs(g)).item() for g in g_list]
+                        v_16 = [torch.max(torch.abs(g)).item() for g in g_16]
+                        v_32 = [torch.max(torch.abs(g)).item() for g in g_32]
                    elif group['norm_type'] == 2:
-                        m2 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_list]
+                        v_16 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_16]
+                        v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32]
                    else:
                        raise RuntimeError('NovoGrad only support l2/inf norm now.')
-                    group['exp_avg_sq'] = torch.cuda.FloatTensor(m2)
+                    group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16)
+                    group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32)
            else:
-                assert(len(g_list) == group['exp_avg_sq'].numel())
+                assert(len(g_16) == group['exp_avg_sq'][0].numel())
+                assert(len(g_32) == group['exp_avg_sq'][1].numel())

-            multi_tensor_applier(multi_tensor_novograd,
-                                 self.dummy_overflow_buf,
-                                 [g_list, p_list, m1_list],
-                                 group['exp_avg_sq'],
-                                 group['lr'],
-                                 beta1,
-                                 beta2,
-                                 group['eps'],
-                                 group['step'],
-                                 bias_correction,
-                                 group['weight_decay'],
-                                 grad_averaging,
-                                 self.moment_mode,
-                                 group['norm_type'])
+            if(len(g_16) > 0):
+                multi_tensor_applier(self.multi_tensor_novograd,
+                                     self._dummy_overflow_buf,
+                                     [g_16, p_16, m_16],
+                                     group['exp_avg_sq'][0],
+                                     group['lr'],
+                                     beta1,
+                                     beta2,
+                                     group['eps'],
+                                     group['step'],
+                                     bias_correction,
+                                     group['weight_decay'],
+                                     grad_averaging,
+                                     self.moment_mode,
+                                     group['norm_type'])
+            if(len(g_32) > 0):
+                multi_tensor_applier(self.multi_tensor_novograd,
+                                     self._dummy_overflow_buf,
+                                     [g_32, p_32, m_32],
+                                     group['exp_avg_sq'][1],
+                                     group['lr'],
+                                     beta1,
+                                     beta2,
+                                     group['eps'],
+                                     group['step'],
+                                     bias_correction,
+                                     group['weight_decay'],
+                                     grad_averaging,
+                                     self.moment_mode,
+                                     group['norm_type'])

         return loss
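
A rough usage sketch of what this change enables (illustrative only, not part of the commit): a single NovoGrad instance can now drive parameter groups that mix fp16 and fp32 tensors. step() buckets gradients, weights, and first moments by dtype, launches the fused kernel once per precision, and keeps one norm tensor per precision in group['exp_avg_sq'] (per-weight norms rather than squared norms, so the l2 and inf norm paths share the same blending code). The snippet below assumes apex is built with the CUDA extensions and a CUDA device is available; the keyword names (betas, eps, weight_decay, norm_type) are inferred from the defaults/group keys shown in this diff rather than confirmed against the full constructor.

    import torch
    from apex.optimizers import NovoGrad  # import path taken from the error message in this diff

    # one fp16 layer and one fp32 layer under the same optimizer
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4)).cuda()
    model[0].half()

    opt = NovoGrad(model.parameters(), lr=1e-2, betas=(0.95, 0.98),
                   eps=1e-8, weight_decay=0.0, norm_type=2)

    x = torch.randn(8, 16, device='cuda', dtype=torch.float16)
    loss = model[1](model[0](x).float()).sum()
    loss.backward()
    opt.step()  # fp16 and fp32 tensors go through separate multi_tensor_applier launches

    # each group now stores one norm tensor per precision: index 0 for fp16, index 1 for fp32
    for group in opt.param_groups:
        print(group['exp_avg_sq'][0].numel(), group['exp_avg_sq'][1].numel())
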
 import torch
-from torch.optim import Optimizer
 from apex.multi_tensor_apply import multi_tensor_applier

-class SGD(Optimizer):
+class SGD(torch.optim.Optimizer):
     r"""Implements stochastic gradient descent (optionally with momentum).

     Nesterov momentum is based on the formula from
     `On the importance of initialization and momentum in deep learning`__.
@@ -62,7 +60,7 @@ class SGD(Optimizer):
             self.multi_tensor_axpby = amp_C.multi_tensor_axpby
             self.multi_tensor_sgd = amp_C.multi_tensor_sgd
         else:
-            raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
+            raise RuntimeError('apex.optimizers.SGD requires cuda extensions')

         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
...