Unverified commit cd499737 authored by Tim Moon, committed by GitHub

Add features to distributed Adam for Megatron support (#1414)

* Add features to distributed Adam for Megatron support

Support gradient clipping, gradient scaling, FP32 grad accumulation, and multiple dtypes and devices.

* Restore closure arg to distributed Adam

Review suggestion from @crcrpar
parent bf3c008e
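
For context, the sketch below shows how the features added here are driven from user code, mirroring the new tests in the diff that follows. It is illustrative only: the import path, model, and hyperparameters are assumptions, and the script is expected to be launched with torchrun (as the tests assume).

    import torch
    from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

    # Process group setup; torchrun provides the rank/world-size environment.
    torch.distributed.init_process_group(backend='nccl')
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)

    # Placeholder model, optimizer, and data (hypothetical sizes and learning rate).
    model = torch.nn.Linear(16, 16).cuda()
    optim = DistributedFusedAdam(model.parameters(), lr=1e-3)
    x = torch.randn(4, 16, device='cuda')

    # Gradient clipping (exercised by test_clip_grad_norm below): clip_grad_norm
    # is a method on the distributed optimizer and returns the total grad norm.
    optim.zero_grad()
    model(x).sum().backward()
    grad_norm = optim.clip_grad_norm(3.5)
    optim.step()

    # Gradient scaling (exercised by test_grad_scaler below): the optimizer is
    # driven through torch.cuda.amp.GradScaler, which skips steps that produce
    # inf/NaN gradients and adjusts the loss scale.
    scaler = torch.cuda.amp.GradScaler()
    optim.zero_grad()
    scaler.scale(model(x).sum()).backward()
    scaler.step(optim)
    scaler.update()
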
@@ -21,11 +21,17 @@ class SimpleModel(torch.nn.Module):
             y += (i+1) * l(x)
         return y
 
-def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):
+def make_models(
+        num_layers,
+        size,
+        dtype=torch.float32,
+        device='cuda',
+        overlap_communication=True,
+):
 
     # Construct models with same parameters
-    ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
-    dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device='cuda')
+    ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
+    dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device)
     with torch.no_grad():
         for ref_param, dist_param in zip(dist_model.parameters(),
                                          ref_model.parameters()):
@@ -35,22 +41,22 @@ def make_models(num_layers, size, dtype=torch.float32, overlap_communication=True):
     rank = torch.distributed.get_rank()
     ref_model = torch.nn.parallel.DistributedDataParallel(
         ref_model,
-        device_ids=[rank],
-        output_device=rank,
+        device_ids=[rank] if device=='cuda' else None,
+        output_device=rank if device=='cuda' else None,
     )
 
     # Construct optimizers with same hyperparameters
-    optim_args = { 'lr': 1, 'betas': (0.1,0.2), 'eps': 0.1, 'weight_decay': 0.1 }
+    optim_args = dict(lr=0.1, betas=(0.1,0.2), eps=0.25, weight_decay=0.1)
     ref_optim = torch.optim.AdamW(
         [
-            {'params': list(ref_model.parameters())[1::2], 'lr': 0.5},
+            {'params': list(ref_model.parameters())[1::2], 'lr': 0.2},
             {'params': list(ref_model.parameters())[0::2]},
         ],
         **optim_args,
     )
     dist_optim = DistributedFusedAdam(
         [
-            {'params': list(dist_model.parameters())[1::2], 'lr': 0.5},
+            {'params': list(dist_model.parameters())[1::2], 'lr': 0.2},
             {'params': list(dist_model.parameters())[0::2]},
         ],
         overlap_grad_sync=overlap_communication,
@@ -81,8 +87,9 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             overlap_communication=True,
             use_nosync=True,
             dtype=torch.float32,
-            rtol=1e-5,
-            atol=1e-5,
+            device='cuda',
+            rtol=None,
+            atol=None,
     ):
 
         torch.manual_seed(self.seed + self.rank)
@@ -92,6 +99,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
             num_layers,
             layer_size,
             dtype=dtype,
+            device=device,
             overlap_communication=overlap_communication,
         )
 
@@ -106,10 +114,10 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            for micro_step in range(micro_batch_steps):
 
                # Synthetic data
-               x = torch.rand(batch_size, layer_size) + 0.5
-               dy = torch.rand_like(x) + 0.5
-               x = x.to(dtype=dtype, device='cuda')
-               dy = dy.to(dtype=dtype, device='cuda')
+               x = torch.rand(batch_size, layer_size) - 0.5
+               dy = torch.rand_like(x) - 0.5
+               x = x.to(dtype=dtype, device=device)
+               dy = dy.to(dtype=dtype, device=device)
 
                # Reference implementation
                x_ref = x.detach().clone().requires_grad_(True)
@@ -136,8 +144,8 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            dist_optim.step()
 
            # Check that parameters match
-           for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
-                                                           dist_model.parameters())):
+           for ref_param, dist_param in zip(ref_model.parameters(),
+                                            dist_model.parameters()):
               torch.testing.assert_close(
                   dist_param, ref_param, rtol=rtol, atol=atol)
 
@@ -150,6 +158,20 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
     def test_matches_pytorch_sync_every_step(self):
         self.test_matches_pytorch(use_nosync=False)
 
+    def test_matches_pytorch_fp64(self):
+        self.test_matches_pytorch(
+            dtype=torch.float64,
+            rtol=1.3e-6,
+            atol=1e-5,
+        )
+
+    def test_matches_pytorch_fp16(self):
+        self.test_matches_pytorch(
+            dtype=torch.float16,
+            rtol=1e-2,
+            atol=1e-2,
+        )
+
     def test_raises_on_mismatch(self):
 
         torch.manual_seed(self.seed + self.rank)
@@ -172,16 +194,89 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
         dist_optim.step()
 
         # Check that parameters do not match
-        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
-                                                        dist_model.parameters())):
+        for ref_param, dist_param in zip(ref_model.parameters(),
+                                         dist_model.parameters()):
             self.assertRaises(
                 AssertionError,
                 torch.testing.assert_close,
                 dist_param, ref_param,
-                rtol=1e-5,
-                atol=1e-5,
             )
+
+    def test_clip_grad_norm(self):
+
+        torch.manual_seed(self.seed + self.rank)
+
+        # Identical models with data-parallel and ZeRO
+        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
+
+        # Training steps with pre-determined gradients
+        xs = [3, 1, 4, 1, 5, 9]
+        dys = [1, -1, 1, -1, 1, -1]
+        for x, dy in zip(xs, dys):
+            x = torch.tensor([x], dtype=torch.float32, device='cuda')
+            dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
+
+            # Reference implementation
+            ref_optim.zero_grad()
+            y_ref = ref_model(x.detach())
+            y_ref.backward(dy.detach())
+            ref_grad_norm = torch.nn.utils.clip_grad_norm_(ref_model.parameters(), 3.5)
+            ref_optim.step()
+
+            # Distributed implementation
+            dist_optim.zero_grad()
+            y_dist = dist_model(x.detach())
+            y_dist.backward(dy.detach())
+            dist_grad_norm = dist_optim.clip_grad_norm(3.5)
+            dist_optim.step()
+
+            # Check that parameters match
+            torch.testing.assert_close(dist_grad_norm, ref_grad_norm)
+            for ref_param, dist_param in zip(ref_model.parameters(),
+                                             dist_model.parameters()):
+                torch.testing.assert_close(dist_param, ref_param)
+
+    def test_grad_scaler(self):
+
+        torch.manual_seed(self.seed + self.rank)
+
+        # Identical models with data-parallel and ZeRO
+        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)
+        grad_scaler_args = dict(
+            init_scale=3.21,
+            growth_factor=1.23,
+            backoff_factor=0.876,
+            growth_interval=1,
+        )
+        ref_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
+        dist_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args)
+
+        # Training steps with pre-determined gradients
+        xs = [3, 1, 4, 1, 5, 9]
+        dys = [1, float('inf'), 1, 1, float('nan'), -1]
+        for x, dy in zip(xs, dys):
+            x = torch.tensor([x], dtype=torch.float32, device='cuda')
+            dy = torch.tensor([dy], dtype=torch.float32, device='cuda')
+
+            # Reference implementation
+            ref_optim.zero_grad()
+            y_ref = ref_model(x.detach())
+            ref_scaler.scale(y_ref).backward(dy.detach())
+            ref_scaler.step(ref_optim)
+            ref_scaler.update()
+
+            # Distributed implementation
+            dist_optim.zero_grad()
+            y_dist = dist_model(x.detach())
+            dist_scaler.scale(y_dist).backward(dy.detach())
+            dist_scaler.step(dist_optim)
+            dist_scaler.update()
+
+            # Check that parameters match
+            for ref_param, dist_param in zip(ref_model.parameters(),
+                                             dist_model.parameters()):
+                torch.testing.assert_close(dist_param, ref_param)
 
 if __name__ == "__main__":
     # Assume script has been run with torchrun
     common_utils.run_tests()