Commit 8e5699e4 authored by Kexin Yu
Browse files

more debugging

parent 9b96c824
......@@ -118,13 +118,15 @@ class FusedLAMB(torch.optim.Optimizer):
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
print("====after collect")
print("====g_all_32:", g_all_32)
print("====g_all_16:", g_all_16)
# compute grad norm for two lists
g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)
g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)
g_norm_32, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], True)
g_norm_16, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], True)
print("====after multi_tensor_l2norm")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment