Commit 85e4af76 authored by Kexin Yu

make use_nvlamb a class attribute for FusedLAMB

parent 3fd3e2c8
@@ -84,6 +84,7 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
+        self.use_nvlamb = use_nvlamb

     def zero_grad(self):
         if self.set_grad_none:
@@ -190,7 +191,7 @@ class FusedLAMB(torch.optim.Optimizer):
                                  self.adam_w_mode,
                                  global_grad_norm,
                                  max_grad_norm,
-                                 use_nvlamb)
+                                 self.use_nvlamb)
         if(len(g_32) > 0):
             multi_tensor_applier(self.multi_tensor_lamb,
                                  self._dummy_overflow_buf,
@@ -206,6 +207,6 @@ class FusedLAMB(torch.optim.Optimizer):
                                  self.adam_w_mode,
                                  global_grad_norm,
                                  max_grad_norm,
-                                 use_nvlamb)
+                                 self.use_nvlamb)
         return loss
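
Since use_nvlamb is now kept on the optimizer instance, it can be inspected after construction instead of being visible only as a local inside step(). A minimal usage sketch, assuming NVIDIA apex is installed with its CUDA extensions (the model and hyperparameters below are placeholders, not from this commit):

import torch
from apex.optimizers import FusedLAMB

model = torch.nn.Linear(16, 4).cuda()
# use_nvlamb is an existing constructor argument; after this commit the value
# is stored as self.use_nvlamb rather than only being read inside step().
optimizer = FusedLAMB(model.parameters(), lr=1e-3, use_nvlamb=True)

print(optimizer.use_nvlamb)  # True -- the flag is now an instance attribute
# Each optimizer.step() call passes self.use_nvlamb through to the fused LAMB kernel.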