Unverified Commit df493a29 authored by Benjamin Lefaudeux, committed by GitHub

[ci][SDP] extending the test matrix which checks for equivalence with DDP (#542)

parent fa1b85fb
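
For context, a minimal standalone sketch (not part of this diff) of how the "test matrix" in the title is built: stacking `pytest.mark.parametrize` decorators makes pytest generate the Cartesian product of all axes, so the new `clip_grad_norm` and `amp` parameters multiply the number of DDP-parity configurations that get spawned.

```python
# Standalone sketch: stacked parametrize decorators expand into the Cartesian
# product of their values, one generated test per combination.
import pytest


@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("clip_grad_norm", [True, False])
@pytest.mark.parametrize("amp", [False, True])
def test_matrix_expansion(grad_accumulation, clip_grad_norm, amp):
    # pytest collects 2 * 2 * 2 = 8 cases from this single definition
    assert isinstance(grad_accumulation, bool)
    assert isinstance(clip_grad_norm, bool)
    assert isinstance(amp, bool)
```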
@@ -34,13 +34,26 @@ _test_fp16_reduction = [False]
if hasattr(dist, "algorithms.ddp_com_hooks.default_hooks"):
_test_fp16_reduction.append(True)
_test_amp = [False]
if hasattr(torch.cuda.amp, "autocast"):
_test_amp.append(True)
def _get_mlp():
return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def run_ddp_parity(
rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction
rank,
world_size,
backend,
temp_file_name,
reduce_buffer_size,
grad_accumulation,
change_train_graph,
fp16_reduction,
clip_grad_norm,
amp,
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
@@ -51,7 +64,7 @@ def run_ddp_parity(
NUMBER_BATCHS = 5
BATCH_SIZE = 8
def check_parity(amp: bool, manual_reduction: bool):
def check_parity(manual_reduction: bool):
# The API should be exactly the same between the sharded and non-sharded variants; generic closure
def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
@@ -108,7 +121,7 @@ def run_ddp_parity(
ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook) # type: ignore
ddp_scaler = TorchGradScaler() if amp else None
sharded_ddp_scaler = ShardedGradScaler() if amp else None
sharded_scaler = ShardedGradScaler() if amp else None
# The model should be synchronized between the ranks at construction time; check that
check_same_model_params(sharded_ddp_model, ddp_model)
@@ -117,35 +130,44 @@ def run_ddp_parity(
for i in range(NUMBER_BATCHS):
input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
def closure_ddp(input_tensor=input_tensor):
def ddp_closure(input_tensor=input_tensor):
return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
def closure_sharded(input_tensor=input_tensor):
def sharded_closure(input_tensor=input_tensor):
return closure(
sharded_ddp_model,
sharded_ddp_scaler,
sharded_scaler,
input_tensor,
grad_accumulation,
_manual_reduction=manual_reduction,
)
# Step/scale both
if ddp_scaler is not None:
_ = closure_ddp(input_tensor)
ddp_scaler.step(ddp_optimizer)
ddp_scaler.update()
else:
ddp_optimizer.step(closure=closure_ddp)
if sharded_ddp_scaler is not None:
_ = closure_sharded(input_tensor)
sharded_ddp_scaler.step(sharded_optimizer)
sharded_ddp_scaler.update()
else:
sharded_optimizer.step(closure=closure_sharded)
for _scaler, _closure, _optimizer in (
(ddp_scaler, ddp_closure, ddp_optimizer),
(sharded_scaler, sharded_closure, sharded_optimizer),
):
if _scaler is not None:
_ = _closure(input_tensor)
_scaler.step(_optimizer)
_scaler.update()
else:
_optimizer.step(closure=_closure)
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
# Check that the two grad norms are equivalent
# NOTE: The grads can occasionally be NaNs; the scaler will skip the step in that case.
# This is not ShardedDDP specific: if the grads are not NaN for DDP, then they should also
# be valid for ShardedDDP
if clip_grad_norm:
total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0) # type: ignore
if not torch.isnan(total_norm):
oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
assert torch.allclose(
oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
else:
print(rank, "NaN grad norm in DDP", flush=True)
# Flip the trainability of the first parameter back and forth
if i == 0 and change_train_graph:
next(sharded_ddp_model.parameters()).requires_grad = not next(
@@ -155,24 +177,19 @@ def run_ddp_parity(
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
# Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
amp_tests = [False]
if hasattr(torch.cuda.amp, "autocast"):
amp_tests.append(True)
manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
for manual_reduction in manual_reductions:
for amp in amp_tests:
print(
f"Checking configuration: accumulate {grad_accumulation}"
+ f" - change train graph {change_train_graph}"
+ f" - amp {amp}"
+ f" - manual reduction {manual_reduction}"
+ f" - buffers {reduce_buffer_size}",
flush=True,
)
check_parity(
amp=amp, manual_reduction=manual_reduction,
)
print(
f"{rank}: Checking configuration: accumulate {grad_accumulation}"
+ f" - change train graph {change_train_graph}"
+ f" - amp {amp}"
+ f" - manual reduction {manual_reduction}"
+ f" - buffers {reduce_buffer_size}",
flush=True,
)
check_parity(manual_reduction=manual_reduction)
torch.cuda.synchronize()
torch.distributed.barrier()
dist.destroy_process_group()
@@ -183,7 +200,9 @@ def run_ddp_parity(
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("change_train_graph", [True, False])
@pytest.mark.parametrize("fp16_reduction", _test_fp16_reduction)
def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction):
@pytest.mark.parametrize("clip_grad_norm", [True, False])
@pytest.mark.parametrize("amp", _test_amp)
def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp):
world_size = torch.cuda.device_count()
backend = dist.Backend.NCCL
mp.spawn(
@@ -196,6 +215,8 @@ def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, f
grad_accumulation,
change_train_graph,
fp16_reduction,
clip_grad_norm,
amp,
),
nprocs=world_size,
join=True,
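
For the grad-norm parity check added above, the key point is that `torch.nn.utils.clip_grad_norm_` clips in place but returns the total (pre-clip) gradient norm; per the test's assertion, OSS's `clip_grad_norm` is expected to return the same global value on the sharded side. A minimal single-process sketch of that return value (no fairscale or distributed setup involved):

```python
# Single-process sketch: clip_grad_norm_ returns the pre-clip global L2 norm
# over all parameter gradients, which is the value the parity test compares.
import torch
from torch.nn import Linear, Sequential

model = Sequential(Linear(2, 3), Linear(3, 3))
model(torch.rand(8, 2)).sum().backward()

# Global L2 norm computed by hand, before any clipping
manual_norm = torch.norm(torch.stack([p.grad.norm(2) for p in model.parameters()]), 2)

# clip_grad_norm_ scales the gradients in place and returns that same total norm
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.3, norm_type=2.0)

assert torch.allclose(total_norm, manual_norm)
```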