Commit f560bd0b authored by Kexin Yu

save a sync when calculating global gradient norm

parent ac4ef2d6
 import torch
-import math
 from apex.multi_tensor_apply import multi_tensor_applier
 
 class FusedLAMB(torch.optim.Optimizer):
@@ -123,14 +122,17 @@ class FusedLAMB(torch.optim.Optimizer):
         if len(g_all_32) > 0:
             g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                              self._dummy_overflow_buf,
-                                             [g_all_32], False)[0].item()
+                                             [g_all_32], False)[0]
         if len(g_all_16) > 0:
             g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                              self._dummy_overflow_buf,
-                                             [g_all_16], False)[0].item()
 
         # blend two grad norms to get global grad norm
-        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
+        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [[g_norm_32, g_norm_16]],
+                                                False)[0].item()
         max_grad_norm = self.defaults['max_grad_norm']
 
         for group in self.param_groups:
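The change keeps g_norm_32 and g_norm_16 on the GPU as tensors instead of calling .item() on each, since every .item() forces a device-to-host synchronization. The two partial norms are then blended into the global gradient norm with one more multi_tensor_l2norm launch, so the only remaining sync is the final .item(). A minimal sketch of the pattern in plain PyTorch (the l2_norm and global_grad_norm_* helpers below are hypothetical names for illustration; torch.norm stands in for apex's fused multi_tensor_l2norm kernel):

import torch

def l2_norm(tensors, device):
    # L2 norm of a list of tensors, kept on the device (no host sync).
    if not tensors:
        return torch.zeros((), device=device)
    return torch.norm(torch.stack([torch.norm(t.float()) for t in tensors]))

def global_grad_norm_two_syncs(g_all_32, g_all_16, device):
    # Old pattern: .item() on each partial norm -> two device-to-host syncs,
    # then the blend happens on the CPU.
    g_norm_32 = l2_norm(g_all_32, device).item()
    g_norm_16 = l2_norm(g_all_16, device).item()
    return (g_norm_32 ** 2 + g_norm_16 ** 2) ** 0.5

def global_grad_norm_one_sync(g_all_32, g_all_16, device):
    # New pattern: keep the partial norms as device tensors, blend them with
    # one more norm on the device, and sync to the host exactly once.
    g_norm_32 = l2_norm(g_all_32, device)
    g_norm_16 = l2_norm(g_all_16, device)
    return torch.norm(torch.stack([g_norm_32, g_norm_16])).item()

The blend itself is unchanged: the L2 norm of the two partial norms equals sqrt(g_norm_32^2 + g_norm_16^2), which is what the removed math.sqrt line computed on the host.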