Commit 8405d436 authored by Kexin Yu

revert to gradient pre-normalization

parent a3ffb8a7
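
The kernel change below restores gradient pre-normalization: each raw gradient element is divided by the clipped global gradient norm before the L2 term and the Adam moment updates. As a minimal scalar sketch of that stage-1 arithmetic, here is a plain-Python illustration with hypothetical names mirroring the kernel's (beta3 standing in for 1 - beta1, bias corrections of the form 1 - beta^t); it is not the fused multi-tensor implementation:

import math

def lamb_stage1_element(g, p, m, v, clipped_global_grad_norm,
                        beta1, beta2, beta3,
                        beta1_correction, beta2_correction,
                        epsilon, decay):
    """One element of the MOMENT_MODE_0 (L2) path with gradient pre-normalization."""
    # Pre-normalize: scale the raw gradient by the clipped global gradient norm.
    scaled_grad = g / clipped_global_grad_norm
    # L2 regularization is applied to the *scaled* gradient.
    scaled_grad += decay * p
    m = m * beta1 + beta3 * scaled_grad
    v = v * beta2 + (1 - beta2) * scaled_grad * scaled_grad
    next_m_unbiased = m / beta1_correction
    next_v_unbiased = v / beta2_correction
    denom = math.sqrt(next_v_unbiased) + epsilon
    update = next_m_unbiased / denom  # stage-1 update direction for this element
    return update, m, v

# Example call at step 1 (bias corrections 1 - beta**1):
u, m, v = lamb_stage1_element(g=0.5, p=1.0, m=0.0, v=0.0,
                              clipped_global_grad_norm=2.0,
                              beta1=0.9, beta2=0.999, beta3=0.1,
                              beta1_correction=0.1, beta2_correction=0.001,
                              epsilon=1e-6, decay=0.01)
print(u)

The else branch in the hunk applies the same pre-normalization but does not fold weight decay into the gradient.
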
@@ -104,20 +104,20 @@ struct LAMBStage1Functor
       for(int ii = 0; ii < ILP; ii++)
       {
         if (mode == MOMENT_MODE_0) {
-          // no gradient pre-normalization
-          MATH_T grad = r_g[ii];
-          grad = grad + decay*r_p[ii];
-          r_m[ii] = r_m[ii] * beta1 + beta3 * grad;
-          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * grad * grad;
+          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+          // L2 on scaled grad
+          scaled_grad = scaled_grad + decay*r_p[ii];
+          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
+          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
           MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
           MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
           r_p[ii] = next_m_unbiased / denom;
         }
         else {
-          MATH_T grad = r_g[ii] / clipped_global_grad_norm;
-          r_m[ii] = r_m[ii] * beta1 + beta3 * grad;
-          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * grad * grad;
+          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
+          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
           MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
           MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
......
@@ -80,7 +80,6 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
-        print("using apex.contrib.optimizers.FusedLamb")

     def zero_grad(self):
         if self.set_grad_none:
......
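
For context on the set_grad_none flag shown above: the body of zero_grad is elided in this hunk, but a set_grad_none-style zero_grad typically either drops the gradient tensors or falls back to the base class's in-place behavior. The minimal stand-in class below (pure PyTorch; _LambLike is a hypothetical name, showing only the zero_grad pattern, not the actual fused optimizer) sketches that behavior:

import torch

class _LambLike(torch.optim.Optimizer):
    """Minimal stand-in illustrating the set_grad_none pattern only."""
    def __init__(self, params, lr=1e-3, set_grad_none=True):
        super().__init__(params, defaults=dict(lr=lr))
        self.set_grad_none = set_grad_none

    def zero_grad(self):
        if self.set_grad_none:
            # Drop gradients entirely instead of zeroing them in place.
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            # Otherwise defer to the base-class zero_grad().
            super().zero_grad()

# Usage sketch:
w = torch.nn.Parameter(torch.randn(4))
opt = _LambLike([w], set_grad_none=True)
w.sum().backward()
opt.zero_grad()
assert w.grad is None
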