OpenDAS / apex / Commits / 846f7f8a

Commit 846f7f8a authored Jun 14, 2022 by Tim Moon
parent e2af089c

    Update documentation to reflect DistributedFusedAdam uses AdamW

    Adjust test options to have tighter tolerances.

Showing 2 changed files with 31 additions and 25 deletions (+31 -25):

    apex/contrib/optimizers/distributed_fused_adam.py   +22 -16
    tests/L0/run_optimizers/test_dist_adam.py             +9  -9
File: apex/contrib/optimizers/distributed_fused_adam.py
@@ -12,7 +12,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
 from torch.distributed.distributed_c10d import _get_default_group
 
 class DistributedFusedAdam(torch.optim.Optimizer):
-    """Adam optimizer with ZeRO algorithm.
+    """AdamW optimizer with ZeRO algorithm.
 
     Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
@@ -24,9 +24,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     the parallel processes. Options are provided to overlap the
     gradient synchronization with the backward pass compute.
 
-    Adam was proposed in `Adam: A Method for Stochastic Optimization`_
-    and ZeRO in
-    `ZeRO: Memory Optimizations Toward Training Trillion Parameter Models`_
+    Adam was proposed in `Adam: A Method for Stochastic
+    Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
+    and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
+    Parameter Models`_.
 
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts
@@ -87,6 +88,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         https://arxiv.org/abs/1412.6980
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
+    .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
     .. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
        https://arxiv.org/abs/1910.02054
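The renamed docstring matters because Adam and AdamW differ in how weight decay enters the update: Adam folds an L2 term into the gradient, where it is rescaled by the adaptive denominator, while AdamW decays the parameters directly. A minimal single-tensor sketch of the two reference update rules from the papers cited above (plain PyTorch, not the fused Apex kernels; the default hyperparameters here are illustrative):

    import torch

    def adam_like_step(p, grad, exp_avg, exp_avg_sq, step,
                       lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                       weight_decay=0.01, decoupled=True):
        """One reference update; decoupled=True is AdamW, False is Adam + L2."""
        if decoupled:
            # AdamW: weight decay shrinks the parameters directly and never
            # touches the moment estimates.
            p.mul_(1 - lr * weight_decay)
        else:
            # Adam with L2 regularization: the decay term is folded into the
            # gradient, so it is rescaled by the adaptive denominator below.
            grad = grad + weight_decay * p
        exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
        exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
        bias1 = 1 - betas[0] ** step
        bias2 = 1 - betas[1] ** step
        denom = (exp_avg_sq / bias2).sqrt().add_(eps)
        p.addcdiv_(exp_avg / bias1, denom, value=-lr)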
@@ -327,10 +329,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             shard_start = min(max(shard_start, 0), self.shard_size)
             shard_end = min(max(shard_end, 0), self.shard_size)
             in_local_shard = shard_start < shard_end
-            shard_bucket_start = shard_start + self.shard_size*shard_id
-            shard_bucket_end = shard_bucket_start + shard_end - shard_start
-            shard_param_start = shard_bucket_start - bucket_start + param_start
-            shard_param_end = shard_param_start + shard_end - shard_start
+            if in_local_shard:
+                shard_bucket_start = shard_start + self.shard_size*shard_id
+                shard_bucket_end = shard_bucket_start + shard_end - shard_start
+                shard_param_start = shard_bucket_start - bucket_start + param_start
+                shard_param_end = shard_param_start + shard_end - shard_start
+            else:
+                shard_bucket_start, shard_bucket_end = None, None
+                shard_param_start, shard_param_end = None, None
 
             # Record fragment info
             fragment = {
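The added branch computes the shard offsets only when the fragment actually overlaps the local shard, and otherwise records them as None. A standalone sketch of the same index arithmetic with made-up sizes (the variable names mirror the diff; the values and surrounding bookkeeping are illustrative, not the optimizer's real state):

    # Illustrative values only: one rank's view of a 384-element fragment.
    shard_size = 1024      # elements owned by each rank
    shard_id = 2           # which rank's shard this is
    bucket_start = 1536    # bucket offset within the flattened buffer
    param_start = 0        # parameter offset within the bucket

    # Fragment endpoints clipped to the local shard, i.e. to [0, shard_size)
    shard_start = min(max(512, 0), shard_size)
    shard_end = min(max(896, 0), shard_size)
    in_local_shard = shard_start < shard_end

    if in_local_shard:
        # Where the fragment sits in the whole bucket and in the parameter
        shard_bucket_start = shard_start + shard_size * shard_id             # 2560
        shard_bucket_end = shard_bucket_start + shard_end - shard_start      # 2944
        shard_param_start = shard_bucket_start - bucket_start + param_start  # 1024
        shard_param_end = shard_param_start + shard_end - shard_start        # 1408
    else:
        # No overlap with this rank's shard, so the offsets are meaningless
        shard_bucket_start, shard_bucket_end = None, None
        shard_param_start, shard_param_end = None, None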
@@ -761,14 +767,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
 
                 # Fuse param fragments if possible
                 if len(buffers) == 1:
-                    for group_id in buffers.keys():
-                        buffers[group_id] = [(
-                            bucket['params_shard'],
-                            bucket['exp_avg_shard'],
-                            bucket['exp_avg_sq_shard'],
-                            bucket['grads_shard'],
-                            params_shard_copy,
-                        )]
+                    group_id = list(buffers.keys())[0]
+                    buffers[group_id] = [(
+                        bucket['params_shard'],
+                        bucket['exp_avg_shard'],
+                        bucket['exp_avg_sq_shard'],
+                        bucket['grads_shard'],
+                        params_shard_copy,
+                    )]
 
                 # Apply optimizer step to each param group
                 for group_id, group_buffers in buffers.items():
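Because of the len(buffers) == 1 guard, the removed loop could only ever execute once; indexing the single key expresses that assumption directly without changing behavior. A toy illustration of the equivalence (placeholder values standing in for the real shard buffers):

    buffers = {0: ['unfused fragments']}

    # Old form: a loop that runs exactly once under the guard above.
    for group_id in buffers.keys():
        buffers[group_id] = ['fused shard tuple']

    # New form: equivalent, and makes the single-group assumption explicit.
    group_id = list(buffers.keys())[0]
    buffers[group_id] = ['fused shard tuple']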
File: tests/L0/run_optimizers/test_dist_adam.py
@@ -16,8 +16,8 @@ class TestModel(torch.nn.Module):
     def forward(self, x):
         y = 0
-        for l in self.linear:
-            y += l(x)
+        for i, l in enumerate(self.linear):
+            y += (i+1) * l(x)
         return y
 
 
 def setup(args):
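Scaling each layer's contribution by (i + 1) presumably gives every layer, and therefore every parameter group, gradients of a different magnitude, so any mismatch between the reference and distributed optimizers is harder to mask. Only forward appears in this diff; the sketch below guesses the rest of TestModel from the --dim and --layers test options:

    import torch

    class TestModel(torch.nn.Module):
        # Hypothetical constructor (not shown in the diff): a stack of square
        # linear layers sized by the --dim/--layers arguments.
        def __init__(self, dim=7, layers=11):
            super().__init__()
            self.linear = torch.nn.ModuleList(
                torch.nn.Linear(dim, dim) for _ in range(layers)
            )

        def forward(self, x):
            y = 0
            # Each layer is weighted differently, so its gradients differ in
            # scale from every other layer's.
            for i, l in enumerate(self.linear):
                y += (i + 1) * l(x)
            return y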
@@ -36,17 +36,17 @@ def setup(args):
     )
 
     # Construct optimizers with same hyperparameters
-    optim_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
-    ref_optim = torch.optim.Adam(
+    optim_args = {'lr': 1, 'betas': (0.5, 0.75), 'eps': 0.1, 'weight_decay': 0.1}
+    ref_optim = torch.optim.AdamW(
         [
-            {'params': list(ref_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
             {'params': list(ref_model.parameters())[0::2]},
         ],
         **optim_args,
     )
     dist_optim = DistributedFusedAdam(
         [
-            {'params': list(dist_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
             {'params': list(dist_model.parameters())[0::2]},
         ],
         bucket_cap_mb=71/(4*1024*1024),
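The new hyperparameters (lr=1, betas=(0.5, 0.75), eps=0.1, weight_decay=0.1) are far more aggressive than typical training settings, presumably so that any divergence between the two implementations grows visibly within a few steps. The per-group 'lr': 0.5 entries rely on standard torch.optim behavior: a key set in a param-group dict overrides the constructor default for that group only. A minimal sketch with a toy model (not the test's TestModel):

    import torch

    model = torch.nn.Linear(4, 4)
    optim_args = {'lr': 1, 'betas': (0.5, 0.75), 'eps': 0.1, 'weight_decay': 0.1}
    opt = torch.optim.AdamW(
        [
            {'params': [model.weight], 'lr': 0.5},  # overrides the default lr
            {'params': [model.bias]},               # inherits everything from optim_args
        ],
        **optim_args,
    )
    print([g['lr'] for g in opt.param_groups])  # [0.5, 1]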
@@ -59,12 +59,12 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument('--steps', type=int, default=11)
+    parser.add_argument('--steps', type=int, default=3)
     parser.add_argument('--batch', type=int, default=5)
     parser.add_argument('--dim', type=int, default=7)
     parser.add_argument('--layers', type=int, default=11)
-    parser.add_argument('--atol', type=float, default=1e-3)
-    parser.add_argument('--rtol', type=float, default=1e-3)
+    parser.add_argument('--atol', type=float, default=1e-5)
+    parser.add_argument('--rtol', type=float, default=1e-5)
     args = parser.parse_args()
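With the reference switched from Adam to AdamW, both optimizers now implement the same update rule, so the test can demand much closer agreement; the defaults tighten from 1e-3 to 1e-5 while the step count drops. The comparison itself is outside this diff; a hypothetical sketch of the kind of check the atol/rtol arguments feed into:

    import torch

    def assert_models_match(ref_model, dist_model, rtol, atol):
        # Hypothetical final check (the real test body is not shown here):
        # compare parameters elementwise after both optimizers have run.
        for ref_p, dist_p in zip(ref_model.parameters(), dist_model.parameters()):
            assert torch.allclose(dist_p, ref_p, rtol=rtol, atol=atol), \
                'DistributedFusedAdam drifted from the torch.optim.AdamW reference'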