Commit a0bf956a authored by Kexin Yu

more debugging

parent feb93a2a
@@ -83,6 +83,7 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
+        print("debugging LAMB")
 
     def zero_grad(self):
         if self.set_grad_none:
@@ -130,6 +131,8 @@ class FusedLAMB(torch.optim.Optimizer):
         # blend two grad norms to get global grad norm
         global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
         max_grad_norm = self.defaults['max_grad_norm']
+        print("====global_grad_norm:", global_grad_norm)
+        print("====max_grad_norm:", max_grad_norm)
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
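
For context on the second hunk: a minimal sketch of how the blended global gradient norm is formed, assuming the fp32 and fp16 gradients live in disjoint parameter groups, so `sqrt(n32^2 + n16^2)` equals the L2 norm over all gradients combined. The tensors and the `max_grad_norm` value below are illustrative placeholders, not apex's actual buffers.

```python
# Sketch only, not the apex implementation: blending per-dtype gradient
# norms into one global norm, mirroring the computation in the diff above.
import math
import torch

# Hypothetical gradients standing in for the optimizer's fp32/fp16 buckets.
grads_fp32 = [torch.randn(4), torch.randn(3)]
grads_fp16 = [torch.randn(5, dtype=torch.float16)]

# Partial L2 norms per dtype group (upcast fp16 before accumulating).
g_norm_32 = math.sqrt(sum(float(g.float().norm() ** 2) for g in grads_fp32))
g_norm_16 = math.sqrt(sum(float(g.float().norm() ** 2) for g in grads_fp16))

# Blend the two partial norms; valid because the groups are disjoint.
global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)

# Assumed placeholder for self.defaults['max_grad_norm']; the optimizer
# would use it to derive a clipping ratio from the blended norm.
max_grad_norm = 1.0
clip_ratio = max(global_grad_norm / max_grad_norm, 1.0)
print("====global_grad_norm:", global_grad_norm)
print("====max_grad_norm:", max_grad_norm)
```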