Commit 33f21d68 authored by Kexin Yu

add FusedLamb in __init__

parent b4c32010
apex/contrib/optimizers/__init__.py
 from .fp16_optimizer import FP16_Optimizer
 from .fused_adam import FusedAdam
+from .fused_lamb import FusedLamb
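
With this re-export in place, the optimizer can be pulled in at the package level instead of through the submodule. A minimal sketch of the import this enables (assuming apex is installed with its contrib CUDA extensions; not part of the commit itself):

    # Package-level import made possible by the line added above ...
    from apex.contrib.optimizers import FusedLamb
    # ... equivalent to importing from the submodule directly:
    # from apex.contrib.optimizers.fused_lamb import FusedLamb
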
apex/contrib/optimizers/fused_lamb.py
@@ -14,9 +14,9 @@ class FusedLAMB(torch.optim.Optimizer):
     * Fusion of the LAMB update's elementwise operations
     * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

-    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
+    :class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::

-        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
+        opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)
         ...
         opt.step()
@@ -70,7 +70,8 @@ class FusedLAMB(torch.optim.Optimizer):
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
             fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
             self.multi_tensor_lamb = fused_lamb_cuda.lamb
@@ -80,6 +81,7 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
+        print("using apex.contrib.optimizers.FusedLamb")

     def zero_grad(self):
         if self.set_grad_none:
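
To illustrate the docstring's claim that FusedLAMB is used like any ordinary PyTorch optimizer, here is a hedged training-loop sketch (not part of the commit; it assumes a CUDA build of apex with the fused_lamb extension and uses a toy model and random data for illustration):

    # Hypothetical usage sketch following the docstring above; not from the commit itself.
    import torch
    import apex.contrib.optimizers

    model = torch.nn.Linear(1024, 1024).cuda()
    opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr=1e-3, weight_decay=0.01)

    loss_fn = torch.nn.MSELoss()
    for _ in range(10):
        x = torch.randn(32, 1024, device="cuda")
        y = torch.randn(32, 1024, device="cuda")
        opt.zero_grad()   # clears (or sets to None, if set_grad_none) all parameter gradients
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()        # one fused, multi-tensor LAMB update across all parameters
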