Commit 7fa74925 authored by Deyu Fu, committed by mcarilli

Fix issues in fused_adam (#469)

* move import of amp_C to __init__()

* make fp16/32 separate lists to support mixed param types, disable double test

* make zero_grad consistent between adam/novograd/lamb

parent 35a85789
 import torch
 from apex.multi_tensor_apply import multi_tensor_applier
-from amp_C import multi_tensor_adam

 class FusedAdam(torch.optim.Optimizer):
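The first hunk moves the `amp_C` import out of module scope. A minimal sketch of the effect, assuming apex was installed on a machine without its CUDA extensions built: importing the module now succeeds, and the failure is deferred until a `FusedAdam` instance is constructed.

```python
import torch
# No longer raises ImportError at module load time:
from apex.optimizers import FusedAdam

params = [torch.nn.Parameter(torch.zeros(2))]
try:
    opt = FusedAdam(params)  # amp_C is now imported inside __init__()
except RuntimeError as e:
    print(e)  # 'apex.optimizers.FusedAdam requires cuda extensions'
```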
@@ -51,6 +50,8 @@ class FusedAdam(torch.optim.Optimizer):
             (default: False) NOT SUPPORTED in FusedAdam!
         adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay (also known as AdamW) (default: True)
+        set_grad_none (bool, optional): whether to set grad to None when the
+            zero_grad() method is called. (default: True)

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -60,7 +61,7 @@ class FusedAdam(torch.optim.Optimizer):

     def __init__(self, params, lr=1e-3, bias_correction=True,
                  betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
-                 weight_decay=0., amsgrad=False):
+                 weight_decay=0., amsgrad=False, set_grad_none=True):
         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
@@ -68,7 +69,22 @@ class FusedAdam(torch.optim.Optimizer):
                         betas=betas, eps=eps, weight_decay=weight_decay)
         super(FusedAdam, self).__init__(params, defaults)
         self.adam_w_mode = 1 if adam_w_mode else 0
-        self.dummy_overflow_buf = torch.cuda.IntTensor([0])
+        self.set_grad_none = set_grad_none
+        if multi_tensor_applier.available:
+            import amp_C
+            # Skip buffer
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_adam = amp_C.multi_tensor_adam
+        else:
+            raise RuntimeError('apex.optimizers.FusedAdam requires cuda extensions')
+
+    def zero_grad(self):
+        if self.set_grad_none:
+            for group in self.param_groups:
+                for p in group['params']:
+                    p.grad = None
+        else:
+            super(FusedAdam, self).zero_grad()

     def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
         """Performs a single optimization step.
@@ -97,7 +113,8 @@ class FusedAdam(torch.optim.Optimizer):
                 group['step'] = 1

             # create lists for multi-tensor apply
-            p_list, g_list, m1_list, m2_list = [], [], [], []
+            g_16, p_16, m_16, v_16 = [], [], [], []
+            g_32, p_32, m_32, v_32 = [], [], [], []

             for p in group['params']:
                 if p.grad is None:
@@ -113,22 +130,43 @@ class FusedAdam(torch.optim.Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)

-                p_list.append(p.data)
-                g_list.append(p.grad.data)
-                m1_list.append(state['exp_avg'])
-                m2_list.append(state['exp_avg_sq'])
-
-            multi_tensor_applier(multi_tensor_adam,
-                                 self.dummy_overflow_buf,
-                                 [g_list, p_list, m1_list, m2_list],
-                                 group['lr'],
-                                 beta1,
-                                 beta2,
-                                 group['eps'],
-                                 group['step'],
-                                 self.adam_w_mode,
-                                 bias_correction,
-                                 group['weight_decay'])
+                if p.dtype == torch.float16:
+                    g_16.append(p.grad.data)
+                    p_16.append(p.data)
+                    m_16.append(state['exp_avg'])
+                    v_16.append(state['exp_avg_sq'])
+                elif p.dtype == torch.float32:
+                    g_32.append(p.grad.data)
+                    p_32.append(p.data)
+                    m_32.append(state['exp_avg'])
+                    v_32.append(state['exp_avg_sq'])
+                else:
+                    raise RuntimeError('FusedAdam only supports fp16 and fp32.')
+
+            if len(g_16) > 0:
+                multi_tensor_applier(self.multi_tensor_adam,
+                                     self._dummy_overflow_buf,
+                                     [g_16, p_16, m_16, v_16],
+                                     group['lr'],
+                                     beta1,
+                                     beta2,
+                                     group['eps'],
+                                     group['step'],
+                                     self.adam_w_mode,
+                                     bias_correction,
+                                     group['weight_decay'])
+            if len(g_32) > 0:
+                multi_tensor_applier(self.multi_tensor_adam,
+                                     self._dummy_overflow_buf,
+                                     [g_32, p_32, m_32, v_32],
+                                     group['lr'],
+                                     beta1,
+                                     beta2,
+                                     group['eps'],
+                                     group['step'],
+                                     self.adam_w_mode,
+                                     bias_correction,
+                                     group['weight_decay'])

         return loss
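A short sketch of the mixed-precision path this hunk enables, assuming a CUDA device: parameters are bucketed by dtype, and `multi_tensor_adam` is launched once per non-empty bucket, so fp16 and fp32 parameters can now share one optimizer.

```python
import torch
from apex.optimizers import FusedAdam

w16 = torch.nn.Parameter(torch.randn(16, device='cuda', dtype=torch.float16))
w32 = torch.nn.Parameter(torch.randn(16, device='cuda', dtype=torch.float32))
opt = FusedAdam([w16, w32], lr=1e-3)  # mixed param dtypes now supported

for p in (w16, w32):
    p.grad = torch.ones_like(p)
opt.step()        # one fused kernel launch per dtype bucket
opt.zero_grad()   # grads set to None by default
```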
@@ -68,9 +68,6 @@ class TestFusedAdam(unittest.TestCase):
         self.assertLessEqual(max_abs_diff, self.max_abs_diff)
         self.assertLessEqual(max_rel_diff, self.max_rel_diff)

-    def test_double(self):
-        self.gen_single_type_test(param_type=torch.double)
-
     def test_float(self):
         self.gen_single_type_test(param_type=torch.float)
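The double-precision test is removed because the rewritten `step()` now accepts only fp16 and fp32 parameters. A hedged sketch of what the deleted test would now hit, assuming a CUDA device:

```python
import torch
from apex.optimizers import FusedAdam

p = torch.nn.Parameter(torch.zeros(4, device='cuda', dtype=torch.float64))
opt = FusedAdam([p])
p.grad = torch.ones_like(p)
opt.step()  # raises RuntimeError: FusedAdam only supports fp16 and fp32.
```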