Commit feb93a2a authored by Kexin Yu's avatar Kexin Yu

check empty lists

parent 8e5699e4
@@ -83,7 +83,6 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
-        print("apex.contrib.optimiziers.FusedLAMB: testing global gradient clipping")

     def zero_grad(self):
         if self.set_grad_none:
@@ -117,24 +116,20 @@ class FusedLAMB(torch.optim.Optimizer):
                 else:
                     raise RuntimeError('FusedLAMB only support fp16 and fp32.')

-        print("====after collect")
-        print("====g_all_32:", g_all_32)
-        print("====g_all_16:", g_all_16)
         g_norm_32, g_norm_16 = 0.0, 0.0
         # compute grad norm for two lists
-        g_norm_32, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                          self._dummy_overflow_buf,
-                                                          [g_all_32], True)
-        g_norm_16, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                          self._dummy_overflow_buf,
-                                                          [g_all_16], True)
-        print("====after multi_tensor_l2norm")
+        if len(g_all_32) > 0:
+            g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [g_all_32], False)
+        if len(g_all_16) > 0:
+            g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [g_all_16], False)

         # blend two grad norms to get global grad norm
         global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
         max_grad_norm = self.defaults['max_grad_norm']
-        print("====global_grad_norm:", global_grad_norm)
-        print("====max_grad_norm:", max_grad_norm)
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
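For context, here is a minimal pure-PyTorch sketch of the guarded computation above. It is an illustration only: g_all_32 and g_all_16 are assumed to be plain Python lists of gradient tensors, torch.norm stands in for the fused multi_tensor_l2norm kernel, and the helper name global_grad_norm is hypothetical, not part of the FusedLAMB API.

import math
import torch

def global_grad_norm(g_all_32, g_all_16):
    # Default both partial norms to 0.0 so an empty list contributes nothing;
    # skipping the norm kernel for empty lists is what this commit adds.
    g_norm_32, g_norm_16 = 0.0, 0.0
    if len(g_all_32) > 0:
        g_norm_32 = torch.norm(torch.stack([torch.norm(g) for g in g_all_32])).item()
    if len(g_all_16) > 0:
        g_norm_16 = torch.norm(torch.stack([torch.norm(g.float()) for g in g_all_16])).item()
    # Blend the two partial norms into the global L2 norm, as in the diff:
    # the norm over all gradients equals sqrt(norm_fp32^2 + norm_fp16^2).
    return math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)

With the guard in place, a model whose gradients are all fp32 (or all fp16) no longer passes an empty tensor list to the fused kernel; the corresponding partial norm simply stays 0.0 and the global norm reduces to the other term.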