".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "6425d46cfc15b5a7073a30cee4024434eeeac7a0"
Commit 92186863 authored by Kexin Yu's avatar Kexin Yu
Browse files

import amp_C.multi_tensor_l2norm

parent 96b017a8
...@@ -73,9 +73,11 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -73,9 +73,11 @@ class FusedLAMB(torch.optim.Optimizer):
max_grad_norm=max_grad_norm) max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults) super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available: if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
fused_lamb_cuda = importlib.import_module("fused_lamb_cuda") fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
self.multi_tensor_lamb = fused_lamb_cuda.lamb self.multi_tensor_lamb = fused_lamb_cuda.lamb
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
else: else:
raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions') raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment