Commit 846f7f8a authored by Tim Moon

Update documentation to reflect that DistributedFusedAdam uses AdamW

Adjust test options to have tighter tolerances.
parent e2af089c
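For context on the rename: with decoupled weight decay (AdamW), the decay is applied directly to the parameters rather than folded into the gradient, so it is not rescaled by the adaptive second-moment denominator. A minimal single-tensor sketch of the difference (not Apex's fused implementation; bias correction omitted for brevity):

```python
import torch

lr, wd, eps = 1e-3, 0.01, 1e-6
beta1, beta2 = 0.9, 0.999

p = torch.randn(4)
grad = torch.randn(4)

# Adam + L2 regularization: the decay term enters the gradient and is
# divided by the adaptive denominator along with it.
g = grad + wd * p
m = (1 - beta1) * g          # first moment (zero-initialized)
v = (1 - beta2) * g * g      # second moment (zero-initialized)
p_adam = p - lr * m / (v.sqrt() + eps)

# AdamW: the moments see only the raw gradient; the decay is applied
# to the parameters directly, scaled only by the learning rate.
m = (1 - beta1) * grad
v = (1 - beta2) * grad * grad
p_adamw = p - lr * m / (v.sqrt() + eps) - lr * wd * p
```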
@@ -12,7 +12,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
 from torch.distributed.distributed_c10d import _get_default_group
 class DistributedFusedAdam(torch.optim.Optimizer):
-    """Adam optimizer with ZeRO algorithm.
+    """AdamW optimizer with ZeRO algorithm.
     Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
@@ -24,9 +24,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     the parallel processes. Options are provided to overlap the
     gradient synchronization with the backward pass compute.
-    Adam was proposed in `Adam: A Method for Stochastic Optimization`_
-    and ZeRO in
-    `ZeRO: Memory Optimizations Toward Training Trillion Parameter Models`_
+    Adam was proposed in `Adam: A Method for Stochastic
+    Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
+    and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
+    Parameter Models`_.
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts
@@ -87,6 +88,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         https://arxiv.org/abs/1412.6980
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
+    .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
    .. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
         https://arxiv.org/abs/1910.02054
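As the docstring notes, the optimizer distributes its state over the parallel processes. A conceptual sketch of that ZeRO-style sharding, assuming a flattened parameter buffer (the names here are illustrative, not Apex's internals):

```python
import torch

world_size, rank = 4, 1
flat_params = torch.randn(64)                   # all parameters, flattened
shard_size = flat_params.numel() // world_size  # entries owned by each rank

# Each rank allocates Adam moments only for its own shard of the
# parameters, cutting optimizer-state memory by a factor of world_size.
lo, hi = rank * shard_size, (rank + 1) * shard_size
exp_avg = torch.zeros(shard_size)     # first moment, local shard only
exp_avg_sq = torch.zeros(shard_size)  # second moment, local shard only

# Each rank steps its own slice; in the real optimizer the updated
# shards are then reassembled across ranks with an all-gather.
local_params = flat_params[lo:hi]
```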
@@ -327,10 +329,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             shard_start = min(max(shard_start, 0), self.shard_size)
             shard_end = min(max(shard_end, 0), self.shard_size)
             in_local_shard = shard_start < shard_end
-            shard_bucket_start = shard_start + self.shard_size*shard_id
-            shard_bucket_end = shard_bucket_start + shard_end - shard_start
-            shard_param_start = shard_bucket_start - bucket_start + param_start
-            shard_param_end = shard_param_start + shard_end - shard_start
+            if in_local_shard:
+                shard_bucket_start = shard_start + self.shard_size*shard_id
+                shard_bucket_end = shard_bucket_start + shard_end - shard_start
+                shard_param_start = shard_bucket_start - bucket_start + param_start
+                shard_param_end = shard_param_start + shard_end - shard_start
+            else:
+                shard_bucket_start, shard_bucket_end = None, None
+                shard_param_start, shard_param_end = None, None
             # Record fragment info
             fragment = {
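A worked example of the index arithmetic in this hunk. The surrounding variables are defined outside the excerpt, so their roles here are assumptions: bucket_start/param_start locate the parameter fragment within the bucket and within the flattened parameter, and shard_start/shard_end are the fragment's overlap with the local shard, already clamped to [0, shard_size]:

```python
shard_size = 4    # each rank owns 4 bucket entries
shard_id = 1      # so this rank owns bucket positions [4, 8)
bucket_start = 2  # the fragment starts at bucket position 2 ...
param_start = 0   # ... which is offset 0 within the parameter
shard_start, shard_end = 2, 4  # overlap with the local shard

in_local_shard = shard_start < shard_end  # True
if in_local_shard:
    # Convert shard-relative offsets to bucket-relative positions.
    shard_bucket_start = shard_start + shard_size * shard_id         # 6
    shard_bucket_end = shard_bucket_start + shard_end - shard_start  # 8
    # Map those bucket positions back into the parameter.
    shard_param_start = shard_bucket_start - bucket_start + param_start  # 4
    shard_param_end = shard_param_start + shard_end - shard_start        # 6
```

The new else branch sets all four offsets to None when the fragment does not overlap the local shard, instead of leaving stale values from the clamped arithmetic.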
@@ -761,14 +767,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Fuse param fragments if possible
         if len(buffers) == 1:
-            for group_id in buffers.keys():
-                buffers[group_id] = [(
-                    bucket['params_shard'],
-                    bucket['exp_avg_shard'],
-                    bucket['exp_avg_sq_shard'],
-                    bucket['grads_shard'],
-                    params_shard_copy,
-                )]
+            group_id = list(buffers.keys())[0]
+            buffers[group_id] = [(
+                bucket['params_shard'],
+                bucket['exp_avg_shard'],
+                bucket['exp_avg_sq_shard'],
+                bucket['grads_shard'],
+                params_shard_copy,
+            )]
         # Apply optimizer step to each param group
         for group_id, group_buffers in buffers.items():
@@ -16,8 +16,8 @@ class TestModel(torch.nn.Module):
     def forward(self, x):
         y = 0
-        for l in self.linear:
-            y += l(x)
+        for i, l in enumerate(self.linear):
+            y += (i+1) * l(x)
         return y
 def setup(args):
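The scaling change makes each layer's contribution to the output distinct, so a mismatch in any single layer's parameters between the reference and distributed optimizers shows up in the comparison. A self-contained sketch of the test model, assuming self.linear is an nn.ModuleList of identical Linear layers (the constructor is outside this excerpt):

```python
import torch

class TestModel(torch.nn.Module):
    def __init__(self, dim=7, layers=11):  # defaults mirror parse_args below
        super().__init__()
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(dim, dim) for _ in range(layers)]
        )

    def forward(self, x):
        y = 0
        for i, l in enumerate(self.linear):
            y += (i + 1) * l(x)  # weight each layer's output differently
        return y

y = TestModel()(torch.randn(5, 7))  # batch=5, dim=7
```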
@@ -36,17 +36,17 @@ def setup(args):
     )
     # Construct optimizers with same hyperparameters
-    optim_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
-    ref_optim = torch.optim.Adam(
+    optim_args = { 'lr': 1, 'betas': (0.5,0.75), 'eps': 0.1, 'weight_decay': 0.1 }
+    ref_optim = torch.optim.AdamW(
         [
-            {'params': list(ref_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
             {'params': list(ref_model.parameters())[0::2]},
         ],
         **optim_args,
     )
     dist_optim = DistributedFusedAdam(
         [
-            {'params': list(dist_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
             {'params': list(dist_model.parameters())[0::2]},
         ],
         bucket_cap_mb=71/(4*1024*1024),
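A note on bucket_cap_mb=71/(4*1024*1024): however the cap is converted internally, the value works out to only a few dozen bytes, so each gradient bucket holds just a handful of fp32 elements. Presumably this is deliberate: it forces the small test model to span many buckets, exercising the fragment and shard bookkeeping rather than the trivial one-bucket path.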
@@ -59,12 +59,12 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument('--steps', type=int, default=11)
+    parser.add_argument('--steps', type=int, default=3)
     parser.add_argument('--batch', type=int, default=5)
     parser.add_argument('--dim', type=int, default=7)
     parser.add_argument('--layers', type=int, default=11)
-    parser.add_argument('--atol', type=float, default=1e-3)
-    parser.add_argument('--rtol', type=float, default=1e-3)
+    parser.add_argument('--atol', type=float, default=1e-5)
+    parser.add_argument('--rtol', type=float, default=1e-5)
     args = parser.parse_args()