Unverified Commit cd499737 authored by Tim Moon, committed by GitHub

Add features to distributed Adam for Megatron support (#1414)

* Add features to distributed Adam for Megatron support

Support gradient clipping, gradient scaling, FP32 grad accumulation, and multiple dtypes and devices.

* Restore closure arg to distributed Adam

Review suggestion from @crcrpar
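
Rough usage sketch of the new features (the import path, model, and training-loop details are illustrative assumptions; `DistributedFusedAdam`, its `overlap_grad_sync` option, and its `clip_grad_norm` method come from the tests below):

```python
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Assumes torch.distributed is already initialized (e.g. the script was launched
# with torchrun) and that a CUDA device is available.
model = torch.nn.Linear(16, 16).cuda()
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    overlap_grad_sync=True,  # overlap gradient communication with the backward pass
)

for _ in range(10):
    # Synthetic data standing in for a real data loader
    x = torch.randn(8, 16, device='cuda')
    target = torch.randn(8, 16, device='cuda')

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.clip_grad_norm(1.0)  # gradient clipping through the optimizer
    optimizer.step()
```

Loss scaling goes through the usual `torch.cuda.amp.GradScaler` interface (`scaler.scale(loss).backward()`, `scaler.step(optimizer)`, `scaler.update()`), as exercised in `test_grad_scaler` below.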
parent bf3c008e
@@ -21,11 +21,17 @@ class SimpleModel(torch.nn.Module):
             y += (i+1) * l(x)
         return y
 
-def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):
+def make_models(
+        num_layers,
+        size,
+        dtype=torch.float32,
+        device='cuda',
+        overlap_communication=True,
+):
 
     # Construct models with same parameters
-    ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
-    dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
+    ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
+    dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
     with torch.no_grad():
         for ref_param, dist_param in zip(dist_model.parameters(),
                                          ref_model.parameters()):
@@ -35,22 +41,22 @@ def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):
     rank = torch.distributed.get_rank()
     ref_model = torch.nn.parallel.DistributedDataParallel(
         ref_model,
-        device_ids=[rank],
-        output_device=rank,
+        device_ids=[rank] if device=='cuda' else None,
+        output_device=rank if device=='cuda' else None,
     )
 
     # Construct optimizers with same hyperparameters
-    optim_args = { 'lr': 1, 'betas': (0.1,0.2), 'eps': 0.1, 'weight_decay': 0.1 }
+    optim_args = dict(lr=0.1, betas=(0.1,0.2), eps=0.25, weight_decay=0.1)
     ref_optim = torch.optim.AdamW(
         [
-            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
+            {'params': list(ref_model.parameters())[1::2], 'lr': 0.2},
             {'params': list(ref_model.parameters())[0::2]},
         ],
         **optim_args,
     )
     dist_optim = DistributedFusedAdam(
         [
-            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
+            {'params': list(dist_model.parameters())[1::2], 'lr': 0.2},
             {'params': list(dist_model.parameters())[0::2]},
         ],
         overlap_grad_sync=overlap_communication,
@@ -81,8 +87,9 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             overlap_communication=True,
             use_nosync=True,
             dtype=torch.float32,
-            rtol=1e-5,
-            atol=1e-5,
+            device='cuda',
+            rtol=None,
+            atol=None,
     ):
 
         torch.manual_seed(self.seed + self.rank)
@@ -92,6 +99,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             num_layers,
             layer_size,
             dtype=dtype,
+            device=device,
             overlap_communication=overlap_communication,
         )
@@ -106,10 +114,10 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             for micro_step in range(micro_batch_steps):
 
                 # Synthetic data
-                x = torch.rand(batch_size, layer_size) + 0.5
-                dy = torch.rand_like(x) + 0.5
-                x = x.to(dtype=dtype, device='cuda')
-                dy = dy.to(dtype=dtype, device='cuda')
+                x = torch.rand(batch_size, layer_size) - 0.5
+                dy = torch.rand_like(x) - 0.5
+                x = x.to(dtype=dtype, device=device)
+                dy = dy.to(dtype=dtype, device=device)
 
                 # Reference implementation
                 x_ref = x.detach().clone().requires_grad_(True)
@@ -136,8 +144,8 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             dist_optim.step()
 
             # Check that parameters match
-            for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
-                                                            dist_model.parameters())):
+            for ref_param, dist_param in zip(ref_model.parameters(),
+                                             dist_model.parameters()):
                 torch.testing.assert_close(
                     dist_param, ref_param, rtol=rtol, atol=atol)
@@ -150,6 +158,20 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
     def test_matches_pytorch_sync_every_step(self):
         self.test_matches_pytorch(use_nosync=False)
 
+    def test_matches_pytorch_fp64(self):
+        self.test_matches_pytorch(
+            dtype=torch.float64,
+            rtol=1.3e-6,
+            atol=1e-5,
+        )
+
+    def test_matches_pytorch_fp16(self):
+        self.test_matches_pytorch(
+            dtype=torch.float16,
+            rtol=1e-2,
+            atol=1e-2,
+        )
+
     def test_raises_on_mismatch(self):
         torch.manual_seed(self.seed + self.rank)
@@ -172,16 +194,89 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
         dist_optim.step()
 
         # Check that parameters do not match
-        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
-                                                        dist_model.parameters())):
+        for ref_param, dist_param in zip(ref_model.parameters(),
+                                         dist_model.parameters()):
             self.assertRaises(
                 AssertionError,
                 torch.testing.assert_close,
                 dist_param, ref_param,
                 rtol=1e-5,
                 atol=1e-5,
             )
 
+    def test_clip_grad_norm(self):
+        torch.manual_seed(self.seed + self.rank)
+
+        # Identical models with data-parallel and ZeRO
+        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
+
+        # Training steps with pre-determined gradients
+        xs = [3, 1, 4, 1, 5, 9]
+        dys = [1, -1, 1, -1, 1, -1]
+        for x, dy in zip(xs, dys):
+            x = torch.tensor([x], dtype=torch.float32, device='cuda')
+            dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
+
+            # Reference implementation
+            ref_optim.zero_grad()
+            y_ref = ref_model(x.detach())
+            y_ref.backward(dy.detach())
+            ref_grad_norm = torch.nn.utils.clip_grad_norm_(ref_model.parameters(), 3.5)
+            ref_optim.step()
+
+            # Distributed implementation
+            dist_optim.zero_grad()
+            y_dist = dist_model(x.detach())
+            y_dist.backward(dy.detach())
+            dist_grad_norm = dist_optim.clip_grad_norm(3.5)
+            dist_optim.step()
+
+            # Check that parameters match
+            torch.testing.assert_close(dist_grad_norm, ref_grad_norm)
+            for ref_param, dist_param in zip(ref_model.parameters(),
+                                             dist_model.parameters()):
+                torch.testing.assert_close(dist_param, ref_param)
+
+    def test_grad_scaler(self):
+        torch.manual_seed(self.seed + self.rank)
+
+        # Identical models with data-parallel and ZeRO
+        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
+        grad_scaler_args = dict(
+            init_scale=3.21,
+            growth_factor=1.23,
+            backoff_factor=0.876,
+            growth_interval=1,
+        )
+        ref_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
+        dist_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
+
+        # Training steps with pre-determined gradients
+        xs = [3, 1, 4, 1, 5, 9]
+        dys = [1, float('inf'), 1, 1, float('nan'), -1]
+        for x, dy in zip(xs, dys):
+            x = torch.tensor([x], dtype=torch.float32, device='cuda')
+            dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
+
+            # Reference implementation
+            ref_optim.zero_grad()
+            y_ref = ref_model(x.detach())
+            ref_scaler.scale(y_ref).backward(dy.detach())
+            ref_scaler.step(ref_optim)
+            ref_scaler.update()
+
+            # Distributed implementation
+            dist_optim.zero_grad()
+            y_dist = dist_model(x.detach())
+            dist_scaler.scale(y_dist).backward(dy.detach())
+            dist_scaler.step(dist_optim)
+            dist_scaler.update()
+
+            # Check that parameters match
+            for ref_param, dist_param in zip(ref_model.parameters(),
+                                             dist_model.parameters()):
+                torch.testing.assert_close(dist_param, ref_param)
+
 if __name__ == "__main__":
     # Assume script has been run with torchrun
     common_utils.run_tests()