Commit 85e4af76 authored by Kexin Yu's avatar Kexin Yu
Browse files

make use_nvlamb a class attribute for FusedLAMB

parent 3fd3e2c8
......@@ -84,6 +84,7 @@ class FusedLAMB(torch.optim.Optimizer):
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
self.use_nvlamb = use_nvlamb
def zero_grad(self):
if self.set_grad_none:
......@@ -190,7 +191,7 @@ class FusedLAMB(torch.optim.Optimizer):
self.adam_w_mode,
global_grad_norm,
max_grad_norm,
use_nvlamb)
self.use_nvlamb)
if(len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
......@@ -206,6 +207,6 @@ class FusedLAMB(torch.optim.Optimizer):
self.adam_w_mode,
global_grad_norm,
max_grad_norm,
use_nvlamb)
self.use_nvlamb)
return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment