OpenDAS / ColossalAI · Commit c44d7970 (unverified)
Authored Mar 30, 2022 by LuGY; committed by GitHub Mar 30, 2022
Parent: 014bac0c

[docs] updated docs of hybrid adam and cpu adam (#552)

Showing 5 changed files with 104 additions and 13 deletions:

    colossalai/nn/optimizer/cpu_adam.py                        +48  -6
    colossalai/nn/optimizer/fused_adam.py                        +2  -2
    colossalai/nn/optimizer/hybrid_adam.py                      +48  -5
    docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst      +5  -0
    docs/colossalai/colossalai.nn.optimizer.rst                  +1  -0
colossalai/nn/optimizer/cpu_adam.py
import math

import torch

from colossalai.registry import OPTIMIZERS


@OPTIMIZERS.register_module
class CPUAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``.
This version of CPU Adam accelates parameters updating on CPU with SIMD.
Support of AVX2 or AVX512 is required.
The GPU part is implemented in an naive way.
CPU Adam also supports the hybrid precision calculation, eg. fp32 parameters and fp16 gradients.
:class:`colossalai.nn.optimizer.CPUAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
model_params (iterable): iterable of parameters of dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
    optimizer_id = 0
    # Number of fp32 shards per parameter:
    # param weight, grad, momentum and variance
    ...
@@ -18,11 +65,6 @@ class CPUAdam(torch.optim.Optimizer):
                 weight_decay=0,
                 adamw_mode=True,
                 simd_log=False):
        """
        An implementation equivalent to `torch.optim.Adam`.
        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
        The sharded params of model_params can reside on both CPU and CUDA.
        """
        default_args = dict(lr=lr,
                            betas=betas,
                            eps=eps,
                            weight_decay=weight_decay,
                            bias_correction=bias_correction)
        super(CPUAdam, self).__init__(model_params, default_args)
        ...
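Going by the docstring above, ``CPUAdam`` is meant as a drop-in replacement for ``torch.optim.AdamW`` (or ``torch.optim.Adam`` when ``adamw_mode=False``), with parameters and gradients resident on the same device. A minimal usage sketch under that assumption; the model, data and hyperparameters below are illustrative and not part of this commit, and ColossalAI must be installed with its AVX2/AVX512 extension built:

```python
import torch

from colossalai.nn.optimizer import CPUAdam

# Illustrative CPU-only model: CPUAdam requires parameters and gradients
# to live on the same device, so everything here stays on the CPU.
model = torch.nn.Linear(16, 4)

optimizer = CPUAdam(model.parameters(),
                    lr=1e-3,
                    betas=(0.9, 0.999),
                    eps=1e-8,
                    weight_decay=0,
                    adamw_mode=True)  # True -> decoupled weight decay (AdamW behaviour)

loss = model(torch.randn(8, 16)).sum()
loss.backward()        # gradients are created on the CPU, next to the parameters
optimizer.step()       # SIMD-accelerated CPU update
optimizer.zero_grad()
```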
colossalai/nn/optimizer/fused_adam.py
...
@@ -72,8 +72,8 @@ class FusedAdam(torch.optim.Optimizer):
         else:
             raise RuntimeError('FusedAdam requires cuda extensions')

-    def zero_grad(self):
-        if self.set_grad_none:
+    def zero_grad(self, set_to_none=False):
+        if set_to_none:
             for group in self.param_groups:
                 for p in group['params']:
                     p.grad = None
...
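The change above moves ``FusedAdam.zero_grad`` from reading a ``set_grad_none`` attribute fixed at construction time to taking a ``set_to_none`` argument per call, matching the ``torch.optim.Optimizer`` convention. A minimal sketch of the resulting behaviour; the class below is illustrative (not the real ``FusedAdam``), and the in-place zeroing fallback is an assumption borrowed from stock PyTorch, since this hunk only shows the ``set_to_none`` branch:

```python
import torch


class _ZeroGradSketch:
    """Illustrative only: mirrors the updated zero_grad signature."""

    def __init__(self, params):
        self.param_groups = [{'params': list(params)}]

    def zero_grad(self, set_to_none=False):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if set_to_none:
                    # Drop the gradient tensor entirely; the next backward
                    # pass allocates a fresh one.
                    p.grad = None
                else:
                    # Assumed fallback (matching stock torch.optim.Optimizer):
                    # keep the tensor but reset its values in place.
                    p.grad.detach_()
                    p.grad.zero_()


# Usage sketch
w = torch.nn.Parameter(torch.randn(4))
(w * 2).sum().backward()
opt = _ZeroGradSketch([w])
opt.zero_grad(set_to_none=True)
assert w.grad is None
```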
colossalai/nn/optimizer/hybrid_adam.py
import torch

from colossalai.utils import multi_tensor_applier
from colossalai.registry import OPTIMIZERS


@OPTIMIZERS.register_module
class HybridAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``
This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam.
* For parameters updating on CPU, it uses CPUAdam.
* For parameters updating on GPU, it uses FusedAdam.
* Hybird precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients.
:class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
model_params (iterable): iterable of parameters of dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
    optimizer_id = 0
    # Number of fp32 shards per parameter:
    # param weight, grad, momentum and variance
    ...
@@ -16,11 +64,6 @@ class HybridAdam(torch.optim.Optimizer):
                 weight_decay=0,
                 adamw_mode=True,
                 simd_log=False):
        """
        An implementation equivalent to `torch.optim.Adam`.
        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
        The sharded params of model_params can reside on both CPU and CUDA (fused adam).
        """
        default_args = dict(lr=lr,
                            betas=betas,
                            eps=eps,
                            weight_decay=weight_decay,
                            bias_correction=bias_correction)
        super(HybridAdam, self).__init__(model_params, default_args)
        ...
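As the docstring states, ``HybridAdam`` updates CPU-resident parameters via the CPUAdam path and GPU-resident parameters via the FusedAdam path. A hedged usage sketch with parameters split across devices; the model, device placement and hyperparameters are illustrative (in ColossalAI this optimizer is typically driven by the ZeRO/ShardedModelV2 machinery, which manages placement), and a CUDA device plus the built extensions are assumed:

```python
import torch

from colossalai.nn.optimizer import HybridAdam

# Illustrative two-layer model: one layer kept on the CPU, one moved to the GPU.
# HybridAdam updates each parameter on whichever device it lives on, provided
# its gradient lives there too.
cpu_layer = torch.nn.Linear(16, 16)           # updated via the CPUAdam path
gpu_layer = torch.nn.Linear(16, 4).cuda()     # updated via the FusedAdam path

params = list(cpu_layer.parameters()) + list(gpu_layer.parameters())
optimizer = HybridAdam(params, lr=1e-3, weight_decay=0, adamw_mode=True)

x = torch.randn(8, 16)
loss = gpu_layer(cpu_layer(x).cuda()).sum()
loss.backward()        # CPU grads next to CPU params, GPU grads next to GPU params
optimizer.step()
optimizer.zero_grad()
```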
docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst (new file, mode 100644)

colossalai.nn.optimizer.hybrid\_adam
====================================

.. automodule:: colossalai.nn.optimizer.hybrid_adam
   :members:
docs/colossalai/colossalai.nn.optimizer.rst
...
@@ -13,5 +13,6 @@ colossalai.nn.optimizer
     colossalai.nn.optimizer.fused_adam
     colossalai.nn.optimizer.fused_lamb
     colossalai.nn.optimizer.fused_sgd
+    colossalai.nn.optimizer.hybrid_adam
     colossalai.nn.optimizer.lamb
     colossalai.nn.optimizer.lars