Commit 846f7f8a authored by Tim Moon

Update documentation to reflect DistributedFusedAdam uses AdamW

Adjust test options to have tighter tolerances.
parent e2af089c
@@ -12,7 +12,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
 from torch.distributed.distributed_c10d import _get_default_group

 class DistributedFusedAdam(torch.optim.Optimizer):
-    """Adam optimizer with ZeRO algorithm.
+    """AdamW optimizer with ZeRO algorithm.

     Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
@@ -24,9 +24,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     the parallel processes. Options are provided to overlap the
     gradient synchronization with the backward pass compute.

-    Adam was proposed in `Adam: A Method for Stochastic Optimization`_
-    and ZeRO in
-    `ZeRO: Memory Optimizations Toward Training Trillion Parameter Models`_
+    Adam was proposed in `Adam: A Method for Stochastic
+    Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
+    and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
+    Parameter Models`_.

     Arguments:
         params (iterable): iterable of parameters to optimize or dicts
@@ -87,6 +88,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         https://arxiv.org/abs/1412.6980
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
+    .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
     .. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
         https://arxiv.org/abs/1910.02054
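The docstring now distinguishes Adam, which folds weight decay into the gradient as L2 regularization, from AdamW, which decays the parameter directly. A minimal sketch of that difference for a single tensor, assuming plain PyTorch; the helper names adam_style_step and adamw_style_step are illustrative and not part of Apex or this commit:

# Minimal sketch (not from this commit) contrasting L2-regularized Adam
# with AdamW's decoupled weight decay for one parameter tensor.
import torch

def adam_style_step(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, weight_decay, step):
    # Classic Adam folds the weight-decay term into the gradient (L2 regularization).
    grad = grad + weight_decay * p
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt().add_(eps)
    p.add_(-(lr / bias_corr1) * exp_avg / denom)

def adamw_style_step(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, weight_decay, step):
    # AdamW decays the parameter directly, decoupled from the adaptive update.
    p.mul_(1 - lr * weight_decay)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt().add_(eps)
    p.add_(-(lr / bias_corr1) * exp_avg / denom)

With weight_decay=0 the two rules coincide; otherwise AdamW's decay is not rescaled by the adaptive denominator, which is why the reference optimizer in the test below is switched to torch.optim.AdamW.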
@@ -327,10 +329,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             shard_start = min(max(shard_start, 0), self.shard_size)
             shard_end = min(max(shard_end, 0), self.shard_size)
             in_local_shard = shard_start < shard_end
-            shard_bucket_start = shard_start + self.shard_size*shard_id
-            shard_bucket_end = shard_bucket_start + shard_end - shard_start
-            shard_param_start = shard_bucket_start - bucket_start + param_start
-            shard_param_end = shard_param_start + shard_end - shard_start
+            if in_local_shard:
+                shard_bucket_start = shard_start + self.shard_size*shard_id
+                shard_bucket_end = shard_bucket_start + shard_end - shard_start
+                shard_param_start = shard_bucket_start - bucket_start + param_start
+                shard_param_end = shard_param_start + shard_end - shard_start
+            else:
+                shard_bucket_start, shard_bucket_end = None, None
+                shard_param_start, shard_param_end = None, None

             # Record fragment info
             fragment = {
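The new `if in_local_shard:` guard only computes offsets when a parameter fragment actually overlaps the shard owned by this rank. A rough standalone sketch of that arithmetic follows; the helper name locate_fragment and the example sizes are assumptions for illustration, the clipping of shard_start/shard_end mirrors the min(max(...)) lines above, and the surrounding bookkeeping is simplified:

# Illustrative sketch (assumed helper, not the optimizer's API) of mapping a
# parameter fragment that fills bucket positions [bucket_start, bucket_start + fragment_len)
# onto the slice of the bucket owned by this rank's shard.
def locate_fragment(bucket_start, param_start, fragment_len, shard_id, shard_size):
    bucket_end = bucket_start + fragment_len
    # Clip the fragment to this rank's shard of the bucket.
    shard_start = min(max(bucket_start - shard_id * shard_size, 0), shard_size)
    shard_end = min(max(bucket_end - shard_id * shard_size, 0), shard_size)
    if shard_start >= shard_end:
        return None  # fragment does not touch the local shard
    shard_bucket_start = shard_start + shard_size * shard_id
    shard_param_start = shard_bucket_start - bucket_start + param_start
    length = shard_end - shard_start
    return {
        'shard_range': (shard_start, shard_end),
        'bucket_range': (shard_bucket_start, shard_bucket_start + length),
        'param_range': (shard_param_start, shard_param_start + length),
    }

# Example: shard size 8, rank 1 owns bucket positions [8, 16). A fragment
# filling bucket positions [6, 12) overlaps that shard in [8, 12), i.e.
# local shard entries [0, 4) and parameter entries [2, 6).
print(locate_fragment(bucket_start=6, param_start=0, fragment_len=6,
                      shard_id=1, shard_size=8))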
@@ -761,14 +767,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         # Fuse param fragments if possible
         if len(buffers) == 1:
-            for group_id in buffers.keys():
-                buffers[group_id] = [(
-                    bucket['params_shard'],
-                    bucket['exp_avg_shard'],
-                    bucket['exp_avg_sq_shard'],
-                    bucket['grads_shard'],
-                    params_shard_copy,
-                )]
+            group_id = list(buffers.keys())[0]
+            buffers[group_id] = [(
+                bucket['params_shard'],
+                bucket['exp_avg_shard'],
+                bucket['exp_avg_sq_shard'],
+                bucket['grads_shard'],
+                params_shard_copy,
+            )]

         # Apply optimizer step to each param group
         for group_id, group_buffers in buffers.items():
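The old loop only behaved correctly because `buffers` holds exactly one group when the fragments are fused; the rewrite states that single-key assumption directly. As a tiny side note, not from the commit, an equivalent idiom for pulling the only key out of a dict:

# Equivalent way to read the sole key of a single-entry dict without
# building a list (side note only; the dict contents are stand-in data).
buffers = {0: ['fragment_a', 'fragment_b']}
assert len(buffers) == 1
group_id = next(iter(buffers))
print(group_id)  # -> 0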
The remaining hunks adjust the DistributedFusedAdam test.
@@ -16,8 +16,8 @@ class TestModel(torch.nn.Module):
     def forward(self, x):
         y = 0
-        for l in self.linear:
-            y += l(x)
+        for i, l in enumerate(self.linear):
+            y += (i+1) * l(x)
         return y

 def setup(args):
@@ -36,17 +36,17 @@ def setup(args):
     )

     # Construct optimizers with same hyperparameters
-    optim_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
-    ref_optim = torch.optim.Adam(
+    optim_args = { 'lr': 1, 'betas': (0.5,0.75), 'eps': 0.1, 'weight_decay': 0.1 }
+    ref_optim = torch.optim.AdamW(
         [
-            {'params': list(ref_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
             {'params': list(ref_model.parameters())[0::2]},
         ],
         **optim_args,
     )
     dist_optim = DistributedFusedAdam(
         [
-            {'params': list(dist_model.parameters())[1::2], 'lr': 5e-3},
+            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
             {'params': list(dist_model.parameters())[0::2]},
         ],
         bucket_cap_mb=71/(4*1024*1024),
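Both optimizers are built with two parameter groups, where the per-group 'lr': 0.5 overrides the shared 'lr': 1 from optim_args and the remaining hyperparameters are inherited. A quick standalone check of that PyTorch behavior, using a throwaway two-layer model rather than the test's TestModel:

# Per-group 'lr' overrides the optimizer-wide default, other hyperparameters
# are inherited by every group (illustrative model only).
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
params = list(model.parameters())
optim = torch.optim.AdamW(
    [
        {'params': params[1::2], 'lr': 0.5},
        {'params': params[0::2]},
    ],
    lr=1, betas=(0.5, 0.75), eps=0.1, weight_decay=0.1,
)
print([g['lr'] for g in optim.param_groups])            # [0.5, 1]
print([g['weight_decay'] for g in optim.param_groups])  # [0.1, 0.1]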
@@ -59,12 +59,12 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--local_rank', type=int, default=-1)
-    parser.add_argument('--steps', type=int, default=11)
+    parser.add_argument('--steps', type=int, default=3)
     parser.add_argument('--batch', type=int, default=5)
     parser.add_argument('--dim', type=int, default=7)
     parser.add_argument('--layers', type=int, default=11)
-    parser.add_argument('--atol', type=float, default=1e-3)
-    parser.add_argument('--rtol', type=float, default=1e-3)
+    parser.add_argument('--atol', type=float, default=1e-5)
+    parser.add_argument('--rtol', type=float, default=1e-5)
     args = parser.parse_args()
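Tightening --atol/--rtol from 1e-3 to 1e-5 is consistent with the reference optimizer now being torch.optim.AdamW, so both optimizers apply the same update rule. A hedged sketch of the kind of parity check these tolerances feed into; the check_parity helper and loop structure are assumptions for illustration, not code copied from the test:

# Hedged sketch of a parity check driven by --atol/--rtol (assumed structure,
# not the test's actual code): step both optimizers on identical data and
# require the model parameters to stay elementwise close.
import torch

def check_parity(ref_model, dist_model, ref_optim, dist_optim, data, atol=1e-5, rtol=1e-5):
    for step, (x, target) in enumerate(data):
        for model, optim in ((ref_model, ref_optim), (dist_model, dist_optim)):
            optim.zero_grad()
            loss = ((model(x) - target) ** 2).mean()
            loss.backward()
            optim.step()
        for p_ref, p_dist in zip(ref_model.parameters(), dist_model.parameters()):
            assert torch.allclose(p_ref, p_dist.to(p_ref.device), atol=atol, rtol=rtol), \
                f"parameters diverged at step {step}"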