OpenDAS / apex

Commit 427e82cd, authored Aug 27, 2019 by Michael Carilli

Updating docstrings for fused optimizers

parent 15648029
Showing 6 changed files with 115 additions and 19 deletions.
  apex/amp/_initialize.py            +1   -5
  apex/optimizers/fused_adam.py      +31  -8
  apex/optimizers/fused_lamb.py      +28  -3
  apex/optimizers/fused_novograd.py  +24  -3
  apex/optimizers/fused_sgd.py       +22  -0
  docs/source/optimizers.rst         +9   -0
apex/amp/_initialize.py

@@ -123,11 +123,7 @@ def check_optimizers(optimizers):
             raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
                                "The optimizer(s) passed to amp.initialize() must be bare \n"
                                "instances of either ordinary Pytorch optimizers, or Apex fused \n"
-                               "optimizers (FusedAdam or FusedSGD). \n"
-                               "You should not manually wrap your optimizer in either \n"
-                               "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
-                               "amp.initialize will take care of that for you (if necessary) based \n"
-                               "on the specified opt_level (and optional overridden properties).")
+                               "optimizers. \n")


 def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
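Editor's note: the tightened message reflects the workflow amp expects — construct a bare optimizer (an ordinary Pytorch optimizer or an Apex fused one) and let ``amp.initialize`` decide whether any wrapping is needed. A minimal sketch of that workflow, not part of this commit; the toy model and hyperparameters are placeholders:

    import torch
    from apex import amp

    # Apex amp and the fused optimizers are GPU-only.
    model = torch.nn.Linear(128, 10).cuda()

    # Pass the bare optimizer; do not wrap it in FP16_Optimizer yourself.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # amp.initialize performs any master-weight / loss-scaling setup implied by opt_level.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")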
apex/optimizers/fused_adam.py

@@ -4,15 +4,36 @@ from amp_C import multi_tensor_adam

 class FusedAdam(torch.optim.Optimizer):

-    """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
-    ``python setup.py install --cuda_ext --cpp_ext``.
-
-    This version of fused adam implements 2 fusion:
-    - Fusion of operations within adam optimizer
-    - Apply operation on a list of tensor in single multi-tensor kernel by group
-    It is a breaking change over last version, as API changes and it no longer fuse grad norm and loss scaling.
-
-    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+    """Implements Adam algorithm.
+
+    Currently GPU-only.  Requires Apex to be installed via
+    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
+
+    This version of fused Adam implements 2 fusions:
+
+    - Fusion of the Adam update's elementwise operations
+    - A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
+
+    :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.Adam``::
+
+        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
+        ...
+        opt.step()
+
+    :class:`apex.optimizers.FusedAdam` may be used with or without Amp.  If you wish to use :class:`FusedAdam` with Amp,
+    you may choose any `opt_level`::
+
+        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
+        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
+        ...
+        opt.step()
+
+    In general, `opt_level="O1"` is recommended.
+
+    .. warning::
+        A previous version of :class:`FusedAdam` allowed a number of additional arguments to `step`.  These additional arguments
+        are now deprecated and unnecessary.
+
+    Adam was proposed in `Adam: A Method for Stochastic Optimization`_.

     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
...

@@ -53,9 +74,11 @@ class FusedAdam(torch.optim.Optimizer):
         Arguments:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.
+
+        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
         """
         if any(p is not None for p in [grads, output_params, scale, grad_norms]):
-            raise RuntimeError('FusedAdam has been updated, please use with AMP for mixed precision.')
+            raise RuntimeError('FusedAdam has been updated.  Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')

         loss = None
         if closure is not None:
             loss = closure()
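Editor's note: taken together, the two hunks say to construct :class:`FusedAdam` exactly like ``torch.optim.Adam`` and call ``step()`` with no extra arguments, leaving loss scaling to Amp. A short illustrative sketch, not part of the diff; the toy model, shapes, and learning rate are placeholders:

    import torch
    from apex import amp
    from apex.optimizers import FusedAdam

    model = torch.nn.Linear(128, 10).cuda()
    optimizer = FusedAdam(model.parameters(), lr=1e-3)     # same constructor call as torch.optim.Adam
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    data = torch.randn(32, 128, device="cuda")
    target = torch.randn(32, 10, device="cuda")

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    with amp.scale_loss(loss, optimizer) as scaled_loss:   # Amp owns loss scaling now
        scaled_loss.backward()
    optimizer.step()                                       # no grads/scale arguments, per the new docstring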
apex/optimizers/fused_lamb.py

@@ -3,10 +3,31 @@ from apex.multi_tensor_apply import multi_tensor_applier

 class FusedLAMB(torch.optim.Optimizer):

-    """Implements LAMB algorithm. Currently GPU-only. Requires Apex to be installed via
-    ``python setup.py install --cuda_ext --cpp_ext``.
-
-    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+    """Implements LAMB algorithm.
+
+    Currently GPU-only.  Requires Apex to be installed via
+    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
+
+    This version of fused LAMB implements 2 fusions:
+
+    - Fusion of the LAMB update's elementwise operations
+    - A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
+
+    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
+
+        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
+        ...
+        opt.step()
+
+    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,
+    you may choose any `opt_level`::
+
+        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
+        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
+        ...
+        opt.step()
+
+    In general, `opt_level="O1"` is recommended.
+
+    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
...

@@ -29,6 +50,10 @@ class FusedLAMB(torch.optim.Optimizer):
         max_grad_norm (float, optional): value used to clip global grad norm
             (default: 1.0)

+    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
     """

     def __init__(self, params, lr=1e-3, bias_correction=True,
...
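Editor's note: a rough illustration of the usage described above, not part of the diff. ``max_grad_norm`` is the only LAMB-specific knob visible in this hunk (documented default 1.0); the model and values are placeholders:

    import torch
    from apex.optimizers import FusedLAMB

    model = torch.nn.Linear(128, 10).cuda()
    # max_grad_norm: value used to clip the global grad norm, per the docstring (default: 1.0).
    optimizer = FusedLAMB(model.parameters(), lr=1e-3, max_grad_norm=1.0)

    for _ in range(3):                                     # stand-in for a real training loop
        optimizer.zero_grad()
        loss = model(torch.randn(32, 128, device="cuda")).pow(2).mean()
        loss.backward()
        optimizer.step()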
apex/optimizers/fused_novograd.py

@@ -3,8 +3,29 @@ from apex.multi_tensor_apply import multi_tensor_applier

 class FusedNovoGrad(torch.optim.Optimizer):

-    """Implements NovoGrad algorithm. Currently GPU-only. Requires Apex to be installed via
-    ``python setup.py install --cuda_ext --cpp_ext``.
+    """Implements NovoGrad algorithm.
+
+    Currently GPU-only.  Requires Apex to be installed via
+    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
+
+    This version of fused NovoGrad implements 2 fusions:
+
+    - Fusion of the NovoGrad update's elementwise operations
+    - A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
+
+    :class:`apex.optimizers.FusedNovoGrad`'s usage is identical to any Pytorch optimizer::
+
+        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)
+        ...
+        opt.step()
+
+    :class:`apex.optimizers.FusedNovoGrad` may be used with or without Amp.  If you wish to use :class:`FusedNovoGrad` with Amp,
+    you may choose any `opt_level`::
+
+        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)
+        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
+        ...
+        opt.step()
+
+    In general, `opt_level="O1"` is recommended.

     It has been proposed in `Jasper: An End-to-End Convolutional Neural Acoustic Model`_.
     More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd
...

@@ -35,7 +56,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)

-    .. _Jasper\: An End-to-End Convolutional Neural Acoustic Mode:
+    .. _Jasper\: An End-to-End Convolutional Neural Acoustic Model:
         https://arxiv.org/abs/1904.03288
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
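Editor's note: a minimal sketch of the ``set_grad_none`` behavior described in the argument list above, not part of the diff; model and hyperparameters are placeholders, and the final check simply restates what the docstring promises:

    import torch
    from apex.optimizers import FusedNovoGrad

    model = torch.nn.Linear(128, 10).cuda()
    # With set_grad_none=True (the documented default), zero_grad() sets .grad to None
    # rather than zero-filling the gradient tensors.
    optimizer = FusedNovoGrad(model.parameters(), lr=1e-3, set_grad_none=True)

    loss = model(torch.randn(32, 128, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    assert all(p.grad is None for p in model.parameters())  # per the set_grad_none docstring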
apex/optimizers/fused_sgd.py

@@ -6,6 +6,28 @@ from apex.multi_tensor_apply import multi_tensor_applier

 class FusedSGD(Optimizer):
     r"""Implements stochastic gradient descent (optionally with momentum).

+    Currently GPU-only.  Requires Apex to be installed via
+    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
+
+    This version of fused SGD implements 2 fusions:
+
+    - Fusion of the SGD update's elementwise operations
+    - A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
+
+    :class:`apex.optimizers.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``::
+
+        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)
+        ...
+        opt.step()
+
+    :class:`apex.optimizers.FusedSGD` may be used with or without Amp.  If you wish to use :class:`FusedSGD` with Amp,
+    you may choose any `opt_level`::
+
+        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)
+        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
+        ...
+        opt.step()
+
+    In general, `opt_level="O1"` is recommended.
+
     Nesterov momentum is based on the formula from
     `On the importance of initialization and momentum in deep learning`__.
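Editor's note: since the docstring presents :class:`FusedSGD` as a drop-in replacement for ``torch.optim.SGD``, construction is assumed to mirror the stock optimizer's keywords. A sketch with Nesterov momentum (mentioned just below in the docstring); not part of the diff, placeholder model and values:

    import torch
    from apex.optimizers import FusedSGD

    model = torch.nn.Linear(128, 10).cuda()
    # momentum/nesterov follow the torch.optim.SGD keywords, per the drop-in claim above.
    optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)

    optimizer.zero_grad()
    loss = model(torch.randn(32, 128, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()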
docs/source/optimizers.rst

@@ -12,3 +12,12 @@ apex.optimizers

 .. autoclass:: FusedAdam
     :members:
+
+.. autoclass:: FusedLAMB
+    :members:
+
+.. autoclass:: FusedNovoGrad
+    :members:
+
+.. autoclass:: FusedSGD
+    :members: