Unverified commit 2e025ab5 authored by Tim Moon, committed by GitHub

Improvements in distributed Adam optimizer for Megatron (#1432)

* Improvements in distributed Adam optimizer for Megatron

Add option to allocate gradient buckets out of one large buffer. Add option to initialize params in user-provided order. Perform communication when saving optimizer state. Support param sync with any dtype.

* Style fixes in distributed Adam helper classes

Review suggestions from @crcrpar
parent fb21698e
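
As context for the diff below, here is a minimal sketch of how the new param_sync_dtype option might be used with the distributed Adam optimizer. It is not part of this commit: the import path and the surrounding setup are assumptions, and only the keyword arguments dtype, param_sync_dtype, overlap_grad_sync, and bucket_cap_mb are taken from the test changes that follow.

# Hedged sketch, not from this commit: import path and setup are assumptions;
# only dtype, param_sync_dtype, overlap_grad_sync, and bucket_cap_mb appear in
# the test diff below.
import os
import torch
import torch.distributed
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Assume a torchrun-style launch: one process per GPU.
torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
torch.distributed.init_process_group(backend='nccl')

model = torch.nn.Linear(32, 32).cuda()
optim = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    dtype=torch.float32,              # optimizer state dtype, as in the tests
    param_sync_dtype=torch.float16,   # new: synchronize params in fp16
    overlap_grad_sync=True,
    bucket_cap_mb=4,
)

# One training step: gradients are accumulated into the optimizer's buckets and
# the updated parameters are synchronized across ranks in param_sync_dtype.
y = model(torch.randn(8, 32, device='cuda'))
y.sum().backward()
optim.step()
optim.zero_grad()

The new test_matches_pytorch_allgather_fp16 case below exercises exactly this combination (fp32 optimizer dtype, fp16 param sync) and compares against a plain PyTorch reference with relaxed tolerances.
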
from contextlib import contextmanager
import io
import os
import torch
@@ -25,6 +26,7 @@ def make_models(
        num_layers,
        size,
        dtype=torch.float32,
        param_sync_dtype=None,
        device='cuda',
        overlap_communication=True,
):
@@ -61,6 +63,8 @@ def make_models(
        ],
        overlap_grad_sync=overlap_communication,
        bucket_cap_mb=71/(4*1024*1024),
        dtype=torch.float32,
        param_sync_dtype=param_sync_dtype,
        **optim_args,
    )
@@ -87,6 +91,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            overlap_communication=True,
            use_nosync=True,
            dtype=torch.float32,
            param_sync_dtype=None,
            device='cuda',
            rtol=None,
            atol=None,
@@ -99,6 +104,7 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            num_layers,
            layer_size,
            dtype=dtype,
            param_sync_dtype=param_sync_dtype,
            device=device,
            overlap_communication=overlap_communication,
        )
@@ -172,6 +178,14 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
            atol=1e-2,
        )

    def test_matches_pytorch_allgather_fp16(self):
        self.test_matches_pytorch(
            dtype=torch.float32,
            param_sync_dtype=torch.float16,
            rtol=1e-2,
            atol=1e-2,
        )

    def test_raises_on_mismatch(self):
        torch.manual_seed(self.seed + self.rank)
@@ -277,6 +291,101 @@ class TestDistributedFusedAdam(NcclDistributedTestBase):
                                         dist_model.parameters()):
            torch.testing.assert_close(dist_param, ref_param)
    def test_checkpoint(self):

        # Construct two models with same config and different params
        num_layers = 5
        layer_size = 2
        torch.manual_seed(self.seed + self.rank)
        _, _, model_save, optim_save = make_models(num_layers, layer_size)
        _, _, model_load, optim_load = make_models(num_layers, layer_size)

        # Train one of the models
        num_steps = 3
        micro_batch_steps = 2
        batch_size = 4
        for step in range(num_steps):
            optim_save.zero_grad()
            for micro_step in range(micro_batch_steps):
                x = torch.rand(batch_size, layer_size) - 0.5
                dy = torch.rand_like(x) - 0.5
                x = x.cuda()
                dy = dy.cuda()
                y = model_save(x)
                y.backward(dy)
            optim_save.step()

        # Make sure models are different
        for param_save, param_load in zip(model_save.parameters(),
                                          model_load.parameters()):
            self.assertRaises(
                AssertionError,
                torch.testing.assert_close,
                param_load, param_save,
            )

        # Save state on root rank and load on all ranks
        state_dict = {
            'model': model_save.state_dict(),
            'optim': optim_save.state_dict(),
        }
        if self.rank == 0:
            state_bytes = io.BytesIO()
            torch.save(state_dict, state_bytes)
            state_bytes = [state_bytes.getvalue()]
        else:
            state_bytes = [None]
        torch.distributed.broadcast_object_list(state_bytes, src=0)
        state_bytes = io.BytesIO(state_bytes[0])
        state_dict = torch.load(state_bytes, map_location='cuda')
        model_load.load_state_dict(state_dict['model'])
        optim_load.load_state_dict(state_dict['optim'])

        # Make sure models are identical
        for param_save, param_load in zip(model_save.parameters(),
                                          model_load.parameters()):
            torch.testing.assert_close(param_load, param_save)

        # Train both models
        num_steps = 3
        micro_batch_steps = 3
        batch_size = 5
        for step in range(num_steps):

            # Reset gradients
            optim_save.zero_grad()
            optim_load.zero_grad()

            # Forward and backward passes
            for micro_step in range(micro_batch_steps):

                # Synthetic data
                x = torch.rand(batch_size, layer_size) - 0.5
                dy = torch.rand_like(x) - 0.5
                x = x.cuda()
                dy = dy.cuda()

                # Forward and backward pass
                x_save = x.detach().clone().requires_grad_(True)
                y_save = model_save(x_save)
                y_save.backward(dy)
                x_load = x.detach().clone().requires_grad_(True)
                y_load = model_load(x_load)
                y_load.backward(dy)

                # Check that data tensors match
                torch.testing.assert_close(y_load, y_save)
                torch.testing.assert_close(x_load.grad, x_save.grad)

            # Optimizer step
            optim_save.step()
            optim_load.step()

            # Check that parameters match
            for param_save, param_load in zip(model_save.parameters(),
                                              model_load.parameters()):
                torch.testing.assert_close(param_load, param_save)


if __name__ == "__main__":
    # Assume script has been run with torchrun
    common_utils.run_tests()