Unverified Commit 57f890a7 authored by Tim Moon, committed by GitHub

Move distributed Adam unit test to contrib dir (#1406)

* Increase default bucket size in distributed Adam

* Move distributed Adam unit test to contrib tests

Integrate into unit testing framework

* Tweak hyperparameters for dist Adam optimizer test

Improves numerical stability so we can keep tight tolerances. Adopting suggestions from @crcrpar.

* Use distributed test infrastructure in distributed Adam unit test

Suggestion from @crcrpar.
parent 81f8ba79
@@ -124,7 +124,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  model_parallel_rank=0,
                  average_grad_sync=True,
                  overlap_grad_sync=True,
-                 bucket_cap_mb=15,
+                 bucket_cap_mb=100,
                  pipeline_size=2,
                  fused_grad_copy=False,
                  max_grad_norm=0.,
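The hunk above raises the default bucket_cap_mb from 15 to 100. For context, a minimal sketch of constructing the optimizer with an explicit bucket size; the keyword arguments come from the signature shown above, while the model, learning rate, and process-group initialization are assumptions for illustration.

# Minimal sketch (not from this commit): constructing DistributedFusedAdam with an
# explicit bucket size. Assumes torch.distributed has already been initialized with
# the NCCL backend and each rank has a GPU assigned; the model and learning rate are
# placeholders.
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    bucket_cap_mb=100,       # new default from this commit; the previous default was 15
    overlap_grad_sync=True,  # overlap gradient synchronization with the backward pass
)

The unit test below passes a deliberately tiny bucket_cap_mb (71/(4*1024*1024)), presumably so that even the small test model is split across several buckets.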
New file: distributed Adam unit test (contrib test suite)
from contextlib import contextmanager
import os

import torch
from torch.testing._internal import common_utils

from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase


class SimpleModel(torch.nn.Module):
    def __init__(self, num_layers, size):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(size, size, bias=(i%3==0))
            for i in range(num_layers)
        ])

    def forward(self, x):
        y = 0
        for i, l in enumerate(self.layers):
            y += (i+1) * l(x)
        return y
def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):

    # Construct models with same parameters
    ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
    dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
    with torch.no_grad():
        for ref_param, dist_param in zip(dist_model.parameters(),
                                         ref_model.parameters()):
            dist_param.copy_(ref_param)

    # Initialize reference model with data-parallelism
    rank = torch.distributed.get_rank()
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[rank],
        output_device=rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = { 'lr': 1, 'betas': (0.1,0.2), 'eps': 0.1, 'weight_decay': 0.1 }
    ref_optim = torch.optim.AdamW(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        overlap_grad_sync=overlap_communication,
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim


@contextmanager
def dummy_context():
    try:
        yield
    finally:
        pass
class TestDistributedFusedAdam(NcclDistributedTestBase):

    seed = 1234

    def test_matches_pytorch(
            self,
            num_layers=11,
            layer_size=7,
            batch_size=3,
            num_steps=3,
            micro_batch_steps=3,
            overlap_communication=True,
            use_nosync=True,
            dtype=torch.float32,
            rtol=1e-5,
            atol=1e-5,
    ):

        torch.manual_seed(self.seed + self.rank)

        # Identical models with data-parallel and ZeRO
        ref_model, ref_optim, dist_model, dist_optim = make_models(
            num_layers,
            layer_size,
            dtype=dtype,
            overlap_communication=overlap_communication,
        )

        # Training loop
        for step in range(num_steps):

            # Reset gradients
            ref_optim.zero_grad()
            dist_optim.zero_grad()

            # Forward and backward passes
            for micro_step in range(micro_batch_steps):

                # Synthetic data
                x = torch.rand(batch_size, layer_size) + 0.5
                dy = torch.rand_like(x) + 0.5
                x = x.to(dtype=dtype, device='cuda')
                dy = dy.to(dtype=dtype, device='cuda')

                # Reference implementation
                x_ref = x.detach().clone().requires_grad_(True)
                y_ref = ref_model(x_ref)
                y_ref.backward(dy)

                # Distributed implementation
                x_dist = x.detach().clone().requires_grad_(True)
                y_dist = dist_model(x_dist)
                backward_context = dummy_context
                if use_nosync and micro_step < micro_batch_steps-1:
                    backward_context = dist_optim.no_sync
                with backward_context():
                    y_dist.backward(dy)

                # Check that data tensors match
                torch.testing.assert_close(
                    y_dist, y_ref, rtol=rtol, atol=atol)
                torch.testing.assert_close(
                    x_dist.grad, x_ref.grad, rtol=rtol, atol=atol)

            # Optimization step
            ref_optim.step()
            dist_optim.step()

            # Check that parameters match
            for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                            dist_model.parameters())):
                torch.testing.assert_close(
                    dist_param, ref_param, rtol=rtol, atol=atol)
    def test_matches_pytorch_no_overlap(self):
        self.test_matches_pytorch(
            overlap_communication=False,
            use_nosync=False,
        )

    def test_matches_pytorch_sync_every_step(self):
        self.test_matches_pytorch(use_nosync=False)

    def test_raises_on_mismatch(self):

        torch.manual_seed(self.seed + self.rank)

        # Identical models with data-parallel and ZeRO
        num_layers = 11
        layer_size = 7
        ref_model, ref_optim, dist_model, dist_optim = make_models(
            num_layers,
            layer_size,
        )

        # Only perform training step with distributed model
        dist_optim.zero_grad()
        x = torch.rand(3, layer_size) + 0.5
        x = x.to(dtype=torch.float32, device='cuda')
        dy = torch.rand_like(x) + 0.5
        y = dist_model(x)
        y.backward(dy)
        dist_optim.step()

        # Check that parameters do not match
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            self.assertRaises(
                AssertionError,
                torch.testing.assert_close,
                dist_param, ref_param,
                rtol=1e-5,
                atol=1e-5,
            )
if __name__ == "__main__":
    # Assume script has been run with torchrun
    common_utils.run_tests()
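As the __main__ comment notes, the module is meant to be launched with torchrun so that each rank runs in its own process. A minimal sketch of such a launch on two GPUs, written in Python to match the rest of this page; the test file name is an assumption for illustration.

# Hedged example (not part of the commit): launching the unit test above with
# torchrun on two GPUs. The file name "test_dist_adam.py" is assumed; an
# equivalent shell invocation of torchrun works the same way.
import subprocess

subprocess.run(
    ["torchrun", "--nproc_per_node=2", "test_dist_adam.py"],
    check=True,
)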
Deleted file: the previous standalone distributed Adam test script, superseded by the unit test above.

import argparse
import os
import random

import torch

from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim)
            for _ in range(args.layers)
        ])

    def forward(self, x):
        y = 0
        for i, l in enumerate(self.linear):
            y += (i+1) * l(x)
        return y
def setup(args):

    # Construct models with same parameters
    ref_model = TestModel(args).float().cuda()
    dist_model = TestModel(args).float().cuda()
    with torch.no_grad():
        for ref_param, dist_param in zip(dist_model.parameters(),
                                         ref_model.parameters()):
            dist_param.data.copy_(ref_param.data)
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[args.rank],
        output_device=args.rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = { 'lr': 1, 'betas': (0.5,0.75), 'eps': 0.1, 'weight_decay': 0.1 }
    ref_optim = torch.optim.AdamW(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=3)
    parser.add_argument('--batch', type=int, default=5)
    parser.add_argument('--dim', type=int, default=7)
    parser.add_argument('--layers', type=int, default=11)
    parser.add_argument('--atol', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-5)
    args = parser.parse_args()
    return args
def setup_env(args):

    # Initialize NCCL
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.getenv('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.world_size = torch.distributed.get_world_size()

    # Initialize RNG
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args
def main():
    args = parse_args()
    args = setup_env(args)
    torch.set_printoptions(precision=16)

    def assert_allclose(ref_x, dist_x, message):
        message = (
            f'Rank {args.rank}: {message}\n'
            f'Reference Adam: {ref_x}\n'
            f'Distributed Adam: {dist_x}\n'
            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
        )
        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message

    # Train model with data-parallelism and ZeRO
    ref_model, ref_optim, dist_model, dist_optim = setup(args)
    for step in range(args.steps):

        # Synthetic data
        x = torch.randn(args.batch, args.dim).cuda()
        dy = torch.randn_like(x).cuda()

        # Reference implementation
        ref_optim.zero_grad()
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = ref_model(x_ref)
        y_ref.backward(dy)
        ref_optim.step()

        # Distributed implementation
        dist_optim.zero_grad()
        x_dist = x.detach().clone().requires_grad_(True)
        y_dist = dist_model(x_dist)
        y_dist.backward(dy)
        dist_optim.step()

        # Check values
        torch.cuda.synchronize()
        torch.distributed.barrier()
        assert_allclose(
            y_ref,
            y_dist,
            f'inconsistent output in step {step}',
        )
        assert_allclose(
            x_ref.grad,
            x_dist.grad,
            f'inconsistent input grad in step {step}',
        )
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            assert_allclose(
                ref_param,
                dist_param,
                f'inconsistent param {i} in step {step}',
            )
if __name__ == "__main__":
    main()