OpenDAS / apex · Commits

Commit 33f21d68
Authored Mar 20, 2020 by Kexin Yu

add FusedLamb in __init__

Parent: b4c32010
Showing 2 changed files with 6 additions and 3 deletions.

  apex/contrib/optimizers/__init__.py     +1  -0
  apex/contrib/optimizers/fused_lamb.py   +5  -3
apex/contrib/optimizers/__init__.py

 from .fp16_optimizer import FP16_Optimizer
 from .fused_adam import FusedAdam
+from .fused_lamb import FusedLamb
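With this export in place, the optimizer can be imported from the package itself rather than from the fused_lamb submodule. A minimal sketch of the new import path, assuming apex and its fused_lamb_cuda contrib extension are built and installed; note that the export added here spells the name FusedLamb, while the hunk headers below spell the class FusedLAMB, so use whichever name your build actually defines:

    # Hedged sketch: exercises only the re-export added by this commit.
    from apex.contrib.optimizers import FusedLamb  # resolves once __init__.py exports it

    print(FusedLamb)  # should print the optimizer class object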
apex/contrib/optimizers/fused_lamb.py

@@ -14,9 +14,9 @@ class FusedLAMB(torch.optim.Optimizer):
     * Fusion of the LAMB update's elementwise operations
     * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

-    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
+    :class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::

-        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
+        opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)
         ...
         opt.step()
@@ -70,7 +70,8 @@ class FusedLAMB(torch.optim.Optimizer):
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
             fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
             self.multi_tensor_lamb = fused_lamb_cuda.lamb
@@ -80,6 +81,7 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
+        print("using apex.contrib.optimizers.FusedLamb")

     def zero_grad(self):
         if self.set_grad_none:
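Following the usage pattern in the docstring above, here is a hedged end-to-end sketch. It assumes apex's contrib optimizers and the fused_lamb_cuda extension are installed and a CUDA device is available; the model, data, and learning rate are placeholders, and the class name follows the docstring's spelling (FusedLAMB):

    import torch
    import apex.contrib.optimizers as contrib_opt

    # Placeholder model and data; any CUDA model/tensor pair works here.
    model = torch.nn.Linear(64, 8).cuda()
    inputs = torch.randn(32, 64, device="cuda")
    targets = torch.randn(32, 8, device="cuda")

    # Construct the fused optimizer as the docstring shows.
    opt = contrib_opt.FusedLAMB(model.parameters(), lr=1e-3)

    # One optimization step: clear grads, forward, backward, update.
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    opt.step()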