Commit 92186863 authored by Kexin Yu's avatar Kexin Yu
Browse files

import amp_C.multi_tensor_l2norm

parent 96b017a8
......@@ -73,9 +73,11 @@ class FusedLAMB(torch.optim.Optimizer):
max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
self.multi_tensor_lamb = fused_lamb_cuda.lamb
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
else:
raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment