Unverified commit 2e025ab5, authored by Tim Moon, committed by GitHub

Improvements in distributed Adam optimizer for Megatron (#1432)

* Improvements in distributed Adam optimizer for Megatron

  - Add option to allocate gradient buckets out of one large buffer.
  - Add option to initialize params in user-provided order.
  - Perform communication when saving optimizer state.
  - Support param sync with any dtype.

* Style fixes in distributed Adam helper classes

Review suggestions from @crcrpar
parent fb21698e
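For context, the test changes below exercise the new param sync dtype support by keeping optimizer state in FP32 while all-gathering updated parameters in FP16. The following minimal sketch is not part of this commit's diff: it assumes torch.distributed has already been initialized with the NCCL backend, and the import path and lr value are illustrative; the dtype, param_sync_dtype, and overlap_grad_sync arguments are the ones that appear in the diff itself.

# Minimal usage sketch (not part of this commit's diff).
# Assumes torch.distributed is initialized with the NCCL backend.
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(16, 16).cuda()
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,                         # illustrative value
    dtype=torch.float32,             # dtype of optimizer state and master params
    param_sync_dtype=torch.float16,  # dtype used when synchronizing params
    overlap_grad_sync=True,
)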
from contextlib import contextmanager
import io
import os
import torch
@@ -25,6 +26,7 @@ def make_models(
    num_layers,
    size,
    dtype=torch.float32,
    param_sync_dtype=None,
    device='cuda',
    overlap_communication=True,
):
@@ -61,6 +63,8 @@ def make_models(
        ],
        overlap_grad_sync=overlap_communication,
        bucket_cap_mb=71/(4*1024*1024),
        dtype=torch.float32,
        param_sync_dtype=param_sync_dtype,
        **optim_args,
    )
@@ -87,6 +91,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            overlap_communication=True,
            use_nosync=True,
            dtype=torch.float32,
            param_sync_dtype=None,
            device='cuda',
            rtol=None,
            atol=None,
@@ -99,6 +104,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            num_layers,
            layer_size,
            dtype=dtype,
            param_sync_dtype=param_sync_dtype,
            device=device,
            overlap_communication=overlap_communication,
        )
@@ -172,6 +178,14 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            atol=1e-2,
        )

    def test_matches_pytorch_allgather_fp16(self):
        self.test_matches_pytorch(
            dtype=torch.float32,
            param_sync_dtype=torch.float16,
            rtol=1e-2,
            atol=1e-2,
        )

    def test_raises_on_mismatch(self):
        torch.manual_seed(self.seed + self.rank)
@@ -277,6 +291,101 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
                                         dist_model.parameters()):
            torch.testing.assert_close(dist_param, ref_param)
    def test_checkpoint(self):

        # Construct two models with same config and different params
        num_layers = 5
        layer_size = 2
        torch.manual_seed(self.seed + self.rank)
        _, _, model_save, optim_save = make_models(num_layers, layer_size)
        _, _, model_load, optim_load = make_models(num_layers, layer_size)

        # Train one of the models
        num_steps = 3
        micro_batch_steps = 2
        batch_size = 4
        for step in range(num_steps):
            optim_save.zero_grad()
            for micro_step in range(micro_batch_steps):
                x = torch.rand(batch_size, layer_size) - 0.5
                dy = torch.rand_like(x) - 0.5
                x = x.cuda()
                dy = dy.cuda()
                y = model_save(x)
                y.backward(dy)
            optim_save.step()

        # Make sure models are different
        for param_save, param_load in zip(model_save.parameters(),
                                          model_load.parameters()):
            self.assertRaises(
                AssertionError,
                torch.testing.assert_close,
                param_load, param_save,
            )

        # Save state on root rank and load on all ranks
        state_dict = {
            'model': model_save.state_dict(),
            'optim': optim_save.state_dict(),
        }
        if self.rank == 0:
            state_bytes = io.BytesIO()
            torch.save(state_dict, state_bytes)
            state_bytes = [state_bytes.getvalue()]
        else:
            state_bytes = [None]
        torch.distributed.broadcast_object_list(state_bytes, src=0)
        state_bytes = io.BytesIO(state_bytes[0])
        state_dict = torch.load(state_bytes, map_location='cuda')
        model_load.load_state_dict(state_dict['model'])
        optim_load.load_state_dict(state_dict['optim'])

        # Make sure models are identical
        for param_save, param_load in zip(model_save.parameters(),
                                          model_load.parameters()):
            torch.testing.assert_close(param_load, param_save)
        # Train both models
        num_steps = 3
        micro_batch_steps = 3
        batch_size = 5
        for step in range(num_steps):

            # Reset gradients
            optim_save.zero_grad()
            optim_load.zero_grad()

            # Forward and backward passes
            for micro_step in range(micro_batch_steps):

                # Synthetic data
                x = torch.rand(batch_size, layer_size) - 0.5
                dy = torch.rand_like(x) - 0.5
                x = x.cuda()
                dy = dy.cuda()

                # Forward and backward pass
                x_save = x.detach().clone().requires_grad_(True)
                y_save = model_save(x_save)
                y_save.backward(dy)
                x_load = x.detach().clone().requires_grad_(True)
                y_load = model_load(x_load)
                y_load.backward(dy)

                # Check that data tensors match
                torch.testing.assert_close(y_load, y_save)
                torch.testing.assert_close(x_load.grad, x_save.grad)

            # Optimizer step
            optim_save.step()
            optim_load.step()

            # Check that parameters match
            for param_save, param_load in zip(model_save.parameters(),
                                              model_load.parameters()):
                torch.testing.assert_close(param_load, param_save)
if __name__ == "__main__":
    # Assume script has been run with torchrun
    common_utils.run_tests()
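A note on the checkpoint flow exercised by test_checkpoint: the commit message says saving optimizer state now performs communication, and accordingly the test calls optim_save.state_dict() on every rank even though only rank 0's serialized bytes are broadcast and reloaded. Continuing the earlier sketch, a condensed save-side pattern might look like the following (the checkpoint path is an illustrative assumption):

# Condensed save-side sketch, continuing the model/optimizer from the sketch above.
state_dict = {
    'model': model.state_dict(),
    'optim': optimizer.state_dict(),  # involves communication: call on every rank
}
if torch.distributed.get_rank() == 0:
    torch.save(state_dict, 'checkpoint.pt')  # illustrative path
torch.distributed.barrier()  # keep ranks in step before resuming training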