Commit 33f21d68 authored by Kexin Yu

add FusedLamb in __init__

parent b4c32010
from .fp16_optimizer import FP16_Optimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLamb
@@ -14,9 +14,9 @@ class FusedLAMB(torch.optim.Optimizer):
* Fusion of the LAMB update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
-:class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
+:class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
-opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
+opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)
...
opt.step()
@@ -70,7 +70,8 @@ class FusedLAMB(torch.optim.Optimizer):
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
self.multi_tensor_lamb = fused_lamb_cuda.lamb
@@ -80,6 +81,7 @@ class FusedLAMB(torch.optim.Optimizer):
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
print("using apex.contrib.optimizers.FusedLamb")
def zero_grad(self):
if self.set_grad_none:
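For reference, a minimal usage sketch of the optimizer this commit exports from apex.contrib.optimizers. This is not part of the commit: the model, tensor shapes, and hyperparameter values are illustrative, and it assumes an apex build that provides the fused_lamb_cuda extension and an available GPU.

import torch
from apex.contrib.optimizers import FusedLamb   # import enabled by this commit's __init__ change

# Toy model; FusedLamb applies its fused multi-tensor LAMB update to CUDA parameters.
model = torch.nn.Linear(16, 4).cuda()
optimizer = FusedLamb(model.parameters(), lr=1e-3, weight_decay=0.01)

inputs = torch.randn(8, 16, device="cuda")
loss = model(inputs).sum()        # dummy loss for illustration
loss.backward()
optimizer.step()                  # batched elementwise LAMB update across all parameters
optimizer.zero_grad()             # honors the set_grad_none option shown in the diff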