Unverified commit 60c8de4a authored by Joshua Meier, committed by GitHub

[feature] OSS: add unit test for distributed checkpointing (#273)

author: Joshua Meier
parent b640cab5
@@ -565,3 +565,101 @@ def test_gradient_clipping():
    mp.spawn(
        run_gradient_clipping, args=(world_size, temp_file_name), nprocs=world_size, join=True,
    )


def run_state_dict_distributed(rank, world_size, tempfile_name):
    dist_init(rank, world_size, tempfile_name, backend="gloo")

    device = torch.device(rank)
    torch.manual_seed(rank)  # make sure that different ranks get different data

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model_oss1 = torch.nn.Sequential(
        torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, target_width),
    ).to(device)
    model_oss2 = copy.deepcopy(model_oss1)

    # For this test the gradients are all-reduced in the same way for both models.
    # Normally OSS would be paired with ShardedDDP and only reduce each gradient to its owning rank,
    # but that would not change the computation under test and would add a dependency.
    # To keep the comparison apples-to-apples, DDP is used in both cases.
    model_oss1 = DDP(module=model_oss1, device_ids=[rank],)
    sharded_optimizer1 = optim.OSS(model_oss1.parameters(), lr=0.1, momentum=0.99)

    model_oss2 = DDP(module=model_oss2, device_ids=[rank],)
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)

    def run_grad_step(device, model, optimizer):
        # One forward/backward/step/zero_grad cycle, shared by both models
        loss_fn = torch.nn.L1Loss()
        loss_fn.to(device)

        model.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)

    # check that model parameters are equal
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before saving)"

    # save the state dict for one model only
    sharded_optimizer2.consolidate_state_dict()
    state_dict2 = sharded_optimizer2.state_dict()

    # Check that the pulled state and the .param_groups attribute are in sync
    for replica in range(len(state_dict2["param_groups"])):
        for k in state_dict2["param_groups"][replica].keys():
            if k != "params":
                assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)

    # check that saving did not cause a change in the parameters
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
        assert torch.allclose(
            param1, param2
        ), "parameters of the two identical models have diverged (after consolidating)"

    # save again
    sharded_optimizer2.consolidate_state_dict()
    state_dict2 = sharded_optimizer2.state_dict()

    # reload the state_dict into a fresh optimizer
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer2.load_state_dict(state_dict2)

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)

    # check that reloading a saved state dict does not change the parameters
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (after reloading)"

    dist.destroy_process_group()


@skip_if_no_cuda
def test_state_dict_distributed():
    world_size = 8
    temp_file_name = tempfile.mkstemp()[1]

    if torch.cuda.is_available():
        world_size = min(world_size, torch.cuda.device_count())

    mp.spawn(
        run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True,
    )
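
The consolidate/save/load sequence above is the same pattern a training job would use to checkpoint an OSS optimizer to disk. The sketch below is illustrative and not part of this commit: the helper names and checkpoint path are hypothetical, and it assumes consolidate_state_dict() gathers the sharded optimizer state so that a subsequent state_dict() call returns the full state, which is what the test relies on.

def save_oss_checkpoint(sharded_optimizer, rank, path="/tmp/oss_ckpt.pt"):  # hypothetical helper
    # Consolidation is a collective: every rank must call it, even though
    # only one rank ends up writing the checkpoint file
    sharded_optimizer.consolidate_state_dict()
    if rank == 0:
        torch.save(sharded_optimizer.state_dict(), path)
    dist.barrier()  # make sure the checkpoint is on disk before any rank moves on


def load_oss_checkpoint(sharded_optimizer, device, path="/tmp/oss_ckpt.pt"):  # hypothetical helper
    # load_state_dict restores each rank's shard from the consolidated state
    sharded_optimizer.load_state_dict(torch.load(path, map_location=device))

The barrier reflects the design choice the test exercises implicitly: saving is a cooperative operation across ranks, while loading is local to each rank once the full state dict is available.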