".github/vscode:/vscode.git/clone" did not exist on "81e6345ddf893c594f6d76406a388fa012cb0a29"
Commit 18062b69 authored by Deyu Fu

Clean up the optimizer-variant options supported by all fused optimizers:

Correctly avoid applying bias correction to epsilon (matching a recent upstream change).
Correctly avoid applying bias correction to weight decay (consistent with upstream AdamW).
Add adam_w_mode to FusedAdam/FusedLAMB to select L2 regularization or decoupled weight decay (Adam vs. AdamW).
Document reg_inside_moment in FusedNovoGrad correctly, as an option distinct from adam_w_mode.
Remove the legacy eps_mode from FusedAdam.
Use float as the internal math type across the fused optimizers.
parent 7a219aa9
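For reference, a minimal Python sketch of the per-parameter update the fused Adam kernel computes after this change (function and variable names are illustrative, not the extension API): bias correction is applied only to the moment estimates, epsilon is added after the square root without correction, and adam_w_mode selects L2 regularization versus decoupled weight decay.

```python
import math

def adam_step(p, g, m, v, lr, beta1, beta2, eps, step, weight_decay,
              adam_w_mode=True, bias_correction=True):
    """Single-parameter reference for the fused Adam update (illustrative)."""
    bc1 = 1 - beta1 ** step if bias_correction else 1.0
    bc2 = 1 - beta2 ** step if bias_correction else 1.0
    if not adam_w_mode:                 # L2 mode: decay folded into the gradient
        g = g + weight_decay * p
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    denom = math.sqrt(v / bc2) + eps    # eps added after the sqrt, uncorrected
    update = (m / bc1) / denom
    if adam_w_mode:                     # AdamW: decoupled decay, also uncorrected
        update = update + weight_decay * p
    return p - lr * update, m, v
```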
......@@ -26,10 +26,8 @@ class FusedAdam(torch.optim.Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
adam_w_mode (boolean, optional): Apply L2 regularization (False) or decoupled
weight decay, also known as AdamW (True). (default: True)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
......@@ -37,8 +35,8 @@ class FusedAdam(torch.optim.Optimizer):
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
def __init__(self, params, lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
weight_decay=0., amsgrad=False):
if amsgrad:
......@@ -46,7 +44,7 @@ class FusedAdam(torch.optim.Optimizer):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay)
super(FusedAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self.adam_w_mode = 1 if adam_w_mode else 0
self.dummy_overflow_buf = torch.cuda.IntTensor([0])
def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
......@@ -103,7 +101,7 @@ class FusedAdam(torch.optim.Optimizer):
beta2,
group['eps'],
group['step'],
self.eps_mode,
self.adam_w_mode,
bias_correction,
group['weight_decay'])
......
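A minimal usage sketch of the updated FusedAdam constructor, assuming apex is installed with its CUDA extensions; the toy model and data here are placeholders.

```python
import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 10).cuda()

# adam_w_mode=True (default) applies decoupled weight decay (AdamW);
# adam_w_mode=False folds weight decay into the gradient as L2 regularization.
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=0.01, adam_w_mode=True)

loss = model(torch.randn(4, 10, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```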
......@@ -20,9 +20,8 @@ class FusedLAMB(torch.optim.Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
NOT SUPPORTED now! (default: False)
reg_inside_moment (bool, optional): whether do regularization (norm and L2)
in momentum calculation. True for include, False for not include and
only do it on update term. (default: False)
adam_w_mode (boolean, optional): Apply L2 regularization (False) or decoupled
weight decay, also known as AdamW (True). (default: True)
grad_averaging (bool, optional): whether to apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
set_grad_none (bool, optional): whether to set grad to None when zero_grad()
......@@ -34,7 +33,7 @@ class FusedLAMB(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, reg_inside_moment=False,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0):
if amsgrad:
......@@ -52,7 +51,7 @@ class FusedLAMB(torch.optim.Optimizer):
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
self.moment_mode = 0 if reg_inside_moment else 1
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
def zero_grad(self):
......@@ -129,7 +128,7 @@ class FusedLAMB(torch.optim.Optimizer):
bias_correction,
group['weight_decay'],
grad_averaging,
self.moment_mode,
self.adam_w_mode,
group['max_grad_norm'])
if(len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb,
......@@ -143,7 +142,7 @@ class FusedLAMB(torch.optim.Optimizer):
bias_correction,
group['weight_decay'],
grad_averaging,
self.moment_mode,
self.adam_w_mode,
group['max_grad_norm'])
return loss
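The same keyword is now available on FusedLAMB; a minimal sketch under the same assumptions (apex built with CUDA extensions, placeholder model).

```python
import torch
from apex.optimizers import FusedLAMB

model = torch.nn.Linear(10, 10).cuda()

# As with FusedAdam, adam_w_mode=True uses decoupled weight decay and
# adam_w_mode=False folds the decay into the gradient (L2 regularization).
optimizer = FusedLAMB(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-6, weight_decay=0.01, adam_w_mode=True,
                      max_grad_norm=1.0)

loss = model(torch.randn(4, 10, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```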
......@@ -42,7 +42,7 @@ void multi_tensor_adam_cuda(
const float beta2,
const float epsilon,
const int step,
const int eps_mode,
const int mode,
const int bias_correction,
const float weight_decay);
......@@ -59,7 +59,7 @@ void multi_tensor_novograd_cuda(
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int moment_mode,
const int mode,
const int norm_type);
void multi_tensor_lamb_cuda(
......@@ -74,7 +74,7 @@ void multi_tensor_lamb_cuda(
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int moment_mode,
const int mode,
const float max_grad_norm);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
......@@ -14,10 +14,11 @@
#define ILP 4
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
ADAM_MODE_0 =0, // L2 regularization mode
ADAM_MODE_1 =1 // Decoupled weight decay mode (AdamW)
} adamMode_t;
using MATH_T = float;
template<typename T>
struct AdamFunctor
......@@ -28,8 +29,10 @@ struct AdamFunctor
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float eps,
const float step_size,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float lr,
adamMode_t mode,
const float decay)
{
......@@ -64,10 +67,10 @@ struct AdamFunctor
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
T r_g[ILP];
T r_p[ILP];
T r_m[ILP];
T r_v[ILP];
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
......@@ -79,24 +82,34 @@ struct AdamFunctor
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = T(0);
r_p[ii] = T(0);
r_m[ii] = T(0);
r_v[ii] = T(0);
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
T denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(r_v[ii] + eps);
else // Mode 1
denom = sqrtf(r_v[ii]) + eps;
T update = (r_m[ii] / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (step_size * update);
if(mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
}
else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
......@@ -122,20 +135,17 @@ void multi_tensor_adam_cuda(
const float beta2,
const float epsilon,
const int step,
const int eps_mode,
const int mode,
const int bias_correction,
const float weight_decay)
{
using namespace at;
float step_size = 0;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
// Assume single type across p,g,m1,m2 now
......@@ -149,9 +159,11 @@ void multi_tensor_adam_cuda(
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
step_size,
(adamMode_t) eps_mode,
lr,
(adamMode_t) mode,
weight_decay); )
AT_CUDA_CHECK(cudaGetLastError());
......
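The host-side change matters because the old code folded bias correction into a single step_size, which implicitly rescaled epsilon (and the weight decay term) along with the learning rate; a small Python comparison of the two formulations, with arbitrary values and the decay term omitted for brevity.

```python
import math

beta1, beta2, eps, lr, step = 0.9, 0.999, 1e-8, 1e-3, 10
m, v = 0.05, 0.002  # arbitrary first/second moment values

bc1, bc2 = 1 - beta1 ** step, 1 - beta2 ** step

# Old formulation: bias correction folded into the step size, so eps is
# effectively divided by sqrt(bias_correction2) as well.
step_size = lr * math.sqrt(bc2) / bc1
old_update = step_size * m / (math.sqrt(v) + eps)

# New formulation: the moments are de-biased explicitly and eps is left untouched.
new_update = lr * (m / bc1) / (math.sqrt(v / bc2) + eps)

print(old_update, new_update)  # the two differ only in how eps enters the denominator
```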
......@@ -14,9 +14,9 @@
#define ILP 4
typedef enum{
MOMENT_MODE_0 =0, // Momentum with denom/decay, optional grad averaging after
MOMENT_MODE_1 =1 // Momentum without denom/decay
} momentMode_t;
MOMENT_MODE_0 =0, // L2 regularization mode
MOMENT_MODE_1 =1 // Decoupled weight decay mode
} adamMode_t;
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size,
......@@ -39,7 +39,7 @@ struct LAMBStage1Functor
const float beta1_correction,
const float beta2_correction,
const float epsilon,
momentMode_t m_mode,
adamMode_t mode,
const float decay,
float* global_grad_norm,
float max_global_grad_norm)
......@@ -103,15 +103,15 @@ struct LAMBStage1Functor
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (m_mode == MOMENT_MODE_0) {
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on grad
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = std::sqrt(next_v_unbiased) + epsilon;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
......@@ -120,7 +120,7 @@ struct LAMBStage1Functor
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = std::sqrt(next_v_unbiased) + epsilon;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
......@@ -220,7 +220,7 @@ void multi_tensor_lamb_cuda(
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int moment_mode,
const int mode,
const float max_grad_norm)
{
using namespace at;
......@@ -263,7 +263,7 @@ void multi_tensor_lamb_cuda(
bias_correction1,
bias_correction2,
epsilon,
(momentMode_t) moment_mode,
(adamMode_t) mode,
weight_decay,
std::get<0>(grad_norm_tuple).data<float>(),
max_grad_norm); )
......
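A Python sketch of the per-element stage-1 direction the LAMB functor now produces in each mode; global-norm clipping and the later trust-ratio stage are outside this snippet, and all names are illustrative.

```python
import math

def lamb_stage1(p, g, m, v, beta1, beta2, beta3, eps,
                beta1_correction, beta2_correction,
                weight_decay, clipped_global_grad_norm, adam_w_mode=True):
    """Per-element reference for the LAMB stage-1 update direction (illustrative)."""
    scaled_grad = g / clipped_global_grad_norm
    if not adam_w_mode:                  # MOMENT_MODE_0: L2 on the scaled grad
        scaled_grad = scaled_grad + weight_decay * p
    m = beta1 * m + beta3 * scaled_grad  # beta3 is 1-beta1 or 1 (grad averaging)
    v = beta2 * v + (1 - beta2) * scaled_grad * scaled_grad
    update = (m / beta1_correction) / (math.sqrt(v / beta2_correction) + eps)
    if adam_w_mode:                      # MOMENT_MODE_1: decoupled weight decay
        update = update + weight_decay * p
    return update, m, v                  # stage 2 scales this by the trust ratio
```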
......@@ -14,8 +14,8 @@
#define ILP 4
typedef enum{
MOMENT_MODE_0 =0, // Momentum with denom/decay, optional grad averaging after
MOMENT_MODE_1 =1 // Momentum without denom/decay
MOMENT_MODE_0 =0, // Novograd paper mode, momentum calculation with denom then decay inside
MOMENT_MODE_1 =1 // Decoupled weight decay mode
} momentMode_t;
void multi_tensor_norm_out_cuda(
......@@ -27,6 +27,8 @@ void multi_tensor_norm_out_cuda(
const float beta,
const int norm_type);
using MATH_T = float;
template<typename T>
struct NovoGradFunctor
{
......@@ -37,8 +39,10 @@ struct NovoGradFunctor
const float beta1,
const float beta2,
const float beta3,
const float eps,
const float step_size,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float lr,
momentMode_t m_mode,
const float decay,
const float* per_tensor_grad_norm)
......@@ -70,9 +74,9 @@ struct NovoGradFunctor
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
T r_g[ILP];
T r_p[ILP];
T r_m[ILP];
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
......@@ -83,25 +87,29 @@ struct NovoGradFunctor
r_p[ii] = p[i];
r_m[ii] = m[i];
} else {
r_g[ii] = T(0);
r_p[ii] = T(0);
r_m[ii] = T(0);
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (m_mode == MOMENT_MODE_0) {
T denom = grad_norm + eps;
MATH_T next_v_unbiased = grad_norm / beta2_correction;
MATH_T denom = next_v_unbiased + epsilon;
r_g[ii] = (r_g[ii] / denom) + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];
r_p[ii] = r_p[ii] - (step_size * r_m[ii]);
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
r_p[ii] = r_p[ii] - (lr * next_m_unbiased);
}
else {
r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];
T denom = grad_norm + eps;
T update = (r_m[ii] / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (step_size * update);
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = grad_norm / beta2_correction;
MATH_T denom = next_v_unbiased + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
......@@ -137,14 +145,10 @@ void multi_tensor_novograd_cuda(
using namespace at;
// Handle bias correction mode
float step_size = 0;
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = std::sqrt(1 - std::pow(beta2, step));
}
// Handle grad averaging mode
......@@ -171,8 +175,10 @@ void multi_tensor_novograd_cuda(
beta1,
beta2,
beta3, // 1-beta1 or 1 depending on averaging mode
bias_correction1,
bias_correction2,
epsilon,
step_size,
lr,
(momentMode_t) moment_mode,
weight_decay,
grad_norms.data<float>()); )
......
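Finally, a per-element Python sketch of the two NovoGrad branches, taking the running per-tensor norm as an input the way the kernel does; reg_inside_moment corresponds to MOMENT_MODE_0, and all names are illustrative.

```python
def novograd_step(p, g, m, grad_norm, lr, beta1, beta3, eps,
                  beta1_correction, beta2_correction, weight_decay,
                  reg_inside_moment=False):
    """Per-element reference for the NovoGrad functor (illustrative).

    grad_norm is the running per-tensor second-moment norm; on the host,
    beta2_correction is sqrt(1 - beta2**step), so the norm is de-biased directly.
    """
    denom = grad_norm / beta2_correction + eps
    if reg_inside_moment:                 # MOMENT_MODE_0: paper-style NovoGrad
        g = g / denom + weight_decay * p  # decay applied inside the momentum term
        m = beta1 * m + beta3 * g
        p = p - lr * (m / beta1_correction)
    else:                                 # MOMENT_MODE_1: decoupled weight decay
        m = beta1 * m + beta3 * g
        update = (m / beta1_correction) / denom + weight_decay * p
        p = p - lr * update
    return p, m
```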