Commit d38e6fe4 authored by Kexin Yu

.item()

parent a0bf956a
@@ -41,8 +41,8 @@ struct LAMBStage1Functor
     const float epsilon,
     adamMode_t mode,
     const float decay,
-    float global_grad_norm,
-    float max_global_grad_norm)
+    const float global_grad_norm,
+    const float max_global_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
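The kernel-side change only adds const qualifiers: global_grad_norm and max_global_grad_norm now arrive as read-only, host-computed scalars. In LAMB-style optimizers such a pair is typically used to derive a global-norm clipping factor; the following is a minimal Python sketch of that pattern under that assumption (clip_scale is an illustrative name, not part of the kernel):

    def clip_scale(global_grad_norm: float, max_global_grad_norm: float) -> float:
        """Factor the fused kernel would divide each gradient by (sketch only)."""
        # Clip only when the blended global norm exceeds the configured maximum.
        if max_global_grad_norm > 0.0 and global_grad_norm > max_global_grad_norm:
            return global_grad_norm / max_global_grad_norm
        return 1.0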
@@ -83,7 +83,6 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
-        print("debugging LAMB")

     def zero_grad(self):
         if self.set_grad_none:
@@ -116,23 +115,22 @@ class FusedLAMB(torch.optim.Optimizer):
                     g_all_16.append(p.grad.data)
                 else:
                     raise RuntimeError('FusedLAMB only support fp16 and fp32.')
-:q!
         g_norm_32, g_norm_16 = 0.0, 0.0
         # compute grad norm for two lists
         if len(g_all_32) > 0:
-            g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                self._dummy_overflow_buf,
-                                                [g_all_32], False)
+            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_32], False)[0].item()
         if len(g_all_16) > 0:
-            g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                self._dummy_overflow_buf,
-                                                [g_all_16], False)
+            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_16], False)[0].item()
         # blend two grad norms to get global grad norm
         global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
         max_grad_norm = self.defaults['max_grad_norm']
-        print("====global_grad_norm:", global_grad_norm)
-        print("====max_grad_norm:", max_grad_norm)
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
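On the Python side, multi_tensor_applier(self.multi_tensor_l2norm, ...) is presumed to return a pair of tensors (overall norm, per-tensor norms). The old code kept the one-element norm tensor from tuple unpacking; the new code takes [0].item() so g_norm_32 and g_norm_16 become plain Python floats before they are blended with math.sqrt and handed to the kernel as the new const float global_grad_norm. A self-contained sketch of that pattern (blended_grad_norm and its parameters are illustrative, not apex API):

    import math

    def blended_grad_norm(apply, l2norm_op, overflow_buf, g_all_32, g_all_16):
        """Blend fp32 and fp16 gradient L2 norms into one Python float."""
        g_norm_32, g_norm_16 = 0.0, 0.0
        if len(g_all_32) > 0:
            # [0] selects the overall-norm tensor; .item() copies it to the host.
            g_norm_32 = apply(l2norm_op, overflow_buf, [g_all_32], False)[0].item()
        if len(g_all_16) > 0:
            g_norm_16 = apply(l2norm_op, overflow_buf, [g_all_16], False)[0].item()
        # Same blend as in the diff: sqrt(n32^2 + n16^2).
        return math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)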