Commit 8405d436 authored by Kexin Yu

revert to gradient pre-normalization

parent a3ffb8a7
@@ -104,20 +104,20 @@ struct LAMBStage1Functor
       for(int ii = 0; ii < ILP; ii++)
       {
         if (mode == MOMENT_MODE_0) {
-          // no gradient pre-normalization
-          MATH_T grad = r_g[ii];
-          grad = grad + decay*r_p[ii];
-          r_m[ii] = r_m[ii] * beta1 + beta3 * grad;
-          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * grad * grad;
+          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+          // L2 on scaled grad
+          scaled_grad = scaled_grad + decay*r_p[ii];
+          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
+          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
           MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
           MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
           r_p[ii] = next_m_unbiased / denom;
         }
         else {
-          MATH_T grad = r_g[ii] / clipped_global_grad_norm;
-          r_m[ii] = r_m[ii] * beta1 + beta3 * grad;
-          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * grad * grad;
+          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
+          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
           MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
           MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
...
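
For reference, here is a minimal scalar sketch of what the re-enabled MOMENT_MODE_0 path computes per element, written in plain Python rather than the ILP-unrolled CUDA kernel. The function name and flat parameter list are illustrative only, and beta3 is taken as an input (in the kernel it is derived from the gradient-averaging setting).

    import math

    # Hypothetical scalar version of the MOMENT_MODE_0 branch above;
    # lamb_stage1_element is not part of apex, just an illustration.
    def lamb_stage1_element(g, p, m, v, clipped_global_grad_norm,
                            beta1, beta2, beta3,
                            beta1_correction, beta2_correction,
                            epsilon, decay):
        scaled_grad = g / clipped_global_grad_norm      # gradient pre-normalization
        scaled_grad = scaled_grad + decay * p           # L2 weight decay on the scaled grad
        m = m * beta1 + beta3 * scaled_grad             # first moment
        v = v * beta2 + (1.0 - beta2) * scaled_grad**2  # second moment
        m_unbiased = m / beta1_correction               # bias corrections
        v_unbiased = v / beta2_correction
        update = m_unbiased / (math.sqrt(v_unbiased) + epsilon)  # stage-1 output (stored in r_p)
        return update, m, v

Stage 2 of the fused optimizer then rescales this per-parameter update by the LAMB trust ratio and the learning rate before applying it to the weights.
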
@@ -80,7 +80,6 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
-        print("using apex.contrib.optimizers.FusedLamb")

     def zero_grad(self):
         if self.set_grad_none:
...