"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9d49b45b190bc953eb965abd3d70ec30a799f505"
Unverified commit df493a29, authored by Benjamin Lefaudeux, committed by GitHub

[ci][SDP] extending the test matrix which checks for equivalence with DDP (#542)

parent fa1b85fb
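
For context, the change below promotes AMP and gradient clipping from ad-hoc inner loops to first-class pytest parameters, so every configuration runs as its own test case. Here is a minimal, standalone sketch of that pattern (the test name `test_matrix_entry` is hypothetical, not the fairscale file touched by this commit): stacked `@pytest.mark.parametrize` decorators expand into the full cross-product, and a capability-gated list such as `_test_amp` keeps configurations the runtime cannot support out of the matrix.

```python
# Minimal sketch (hypothetical test, not the fairscale test itself) of the pattern
# this commit extends: stacked parametrize decorators yield a cross-product matrix,
# and capability-gated lists keep unsupported configurations out of it.
import pytest
import torch

_test_amp = [False]
if hasattr(torch.cuda.amp, "autocast"):  # only add AMP runs when autocast is available
    _test_amp.append(True)


@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("clip_grad_norm", [True, False])
@pytest.mark.parametrize("amp", _test_amp)
def test_matrix_entry(grad_accumulation, clip_grad_norm, amp):
    # pytest generates one test case per combination:
    # 2 (grad_accumulation) x 2 (clip_grad_norm) x len(_test_amp)
    assert isinstance(grad_accumulation, bool)
    assert isinstance(clip_grad_norm, bool)
    assert isinstance(amp, bool)
```

Moving `amp` out of `check_parity` and into the parametrization, as the diff does, gives each configuration its own process group and its own pass/fail entry in CI instead of hiding several configurations inside one test.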
@@ -34,13 +34,26 @@ _test_fp16_reduction = [False]
 if hasattr(dist, "algorithms.ddp_com_hooks.default_hooks"):
     _test_fp16_reduction.append(True)
 
+_test_amp = [False]
+if hasattr(torch.cuda.amp, "autocast"):
+    _test_amp.append(True)
+
 
 def _get_mlp():
     return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
 
 
 def run_ddp_parity(
-    rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction
+    rank,
+    world_size,
+    backend,
+    temp_file_name,
+    reduce_buffer_size,
+    grad_accumulation,
+    change_train_graph,
+    fp16_reduction,
+    clip_grad_norm,
+    amp,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
@@ -51,7 +64,7 @@ def run_ddp_parity(
     NUMBER_BATCHS = 5
     BATCH_SIZE = 8
 
-    def check_parity(amp: bool, manual_reduction: bool):
+    def check_parity(manual_reduction: bool):
         # The API should be the exact same in between the sharded and non-sharded variants, generic closure
         def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
@@ -108,7 +121,7 @@ def run_ddp_parity(
             ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore
 
         ddp_scaler = TorchGradScaler() if amp else None
-        sharded_ddp_scaler = ShardedGradScaler() if amp else None
+        sharded_scaler = ShardedGradScaler() if amp else None
 
         # The model should be synchronized in between the ranks at construction time, check that
         check_same_model_params(sharded_ddp_model, ddp_model)
@@ -117,35 +130,44 @@ def run_ddp_parity(
         for i in range(NUMBER_BATCHS):
             input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
 
-            def closure_ddp(input_tensor=input_tensor):
+            def ddp_closure(input_tensor=input_tensor):
                 return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
 
-            def closure_sharded(input_tensor=input_tensor):
+            def sharded_closure(input_tensor=input_tensor):
                 return closure(
                     sharded_ddp_model,
-                    sharded_ddp_scaler,
+                    sharded_scaler,
                     input_tensor,
                     grad_accumulation,
                     _manual_reduction=manual_reduction,
                 )
 
             # Step/scale both
-            if ddp_scaler is not None:
-                _ = closure_ddp(input_tensor)
-                ddp_scaler.step(ddp_optimizer)
-                ddp_scaler.update()
-            else:
-                ddp_optimizer.step(closure=closure_ddp)
-
-            if sharded_ddp_scaler is not None:
-                _ = closure_sharded(input_tensor)
-                sharded_ddp_scaler.step(sharded_optimizer)
-                sharded_ddp_scaler.update()
-            else:
-                sharded_optimizer.step(closure=closure_sharded)
+            for _scaler, _closure, _optimizer in (
+                (ddp_scaler, ddp_closure, ddp_optimizer),
+                (sharded_scaler, sharded_closure, sharded_optimizer),
+            ):
+                if _scaler is not None:
+                    _ = _closure(input_tensor)
+                    _scaler.step(_optimizer)
+                    _scaler.update()
+                else:
+                    _optimizer.step(closure=_closure)
 
             check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
 
+            # Check that the two grad norm are equivalent
+            # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
+            # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
+            # be valid for ShardedDDP
+            if clip_grad_norm:
+                total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
+                if not torch.isnan(total_norm):
+                    oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
+                    assert torch.allclose(
+                        oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
+                    ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
+                else:
+                    print(rank, "NaN grad norm in DDP", flush=True)
+
             # Flip the trainability of the first parameter back and forth
             if i == 0 and change_train_graph:
                 next(sharded_ddp_model.parameters()).requires_grad = not next(
@@ -155,24 +177,19 @@ def run_ddp_parity(
                 check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
 
     # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
-    amp_tests = [False]
-    if hasattr(torch.cuda.amp, "autocast"):
-        amp_tests.append(True)
-
     manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
     for manual_reduction in manual_reductions:
-        for amp in amp_tests:
-            print(
-                f"Checking configuration: accumulate {grad_accumulation}"
-                + f" - change train graph {change_train_graph}"
-                + f" - amp {amp}"
-                + f" - manual reduction {manual_reduction}"
-                + f" - buffers {reduce_buffer_size}",
-                flush=True,
-            )
-            check_parity(
-                amp=amp, manual_reduction=manual_reduction,
-            )
+        print(
+            f"{rank}: Checking configuration: accumulate {grad_accumulation}"
+            + f" - change train graph {change_train_graph}"
+            + f" - amp {amp}"
+            + f" - manual reduction {manual_reduction}"
+            + f" - buffers {reduce_buffer_size}",
+            flush=True,
+        )
+        check_parity(manual_reduction=manual_reduction)
+        torch.cuda.synchronize()
+        torch.distributed.barrier()
 
     dist.destroy_process_group()
@@ -183,7 +200,9 @@ def run_ddp_parity(
 @pytest.mark.parametrize("grad_accumulation", [True, False])
 @pytest.mark.parametrize("change_train_graph", [True, False])
 @pytest.mark.parametrize("fp16_reduction", _test_fp16_reduction)
-def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction):
+@pytest.mark.parametrize("clip_grad_norm", [True, False])
+@pytest.mark.parametrize("amp", _test_amp)
+def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp):
     world_size = torch.cuda.device_count()
     backend = dist.Backend.NCCL
     mp.spawn(
@@ -196,6 +215,8 @@ def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, f
             grad_accumulation,
             change_train_graph,
             fp16_reduction,
+            clip_grad_norm,
+            amp,
         ),
         nprocs=world_size,
         join=True,
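
The newly added `clip_grad_norm` branch asserts that fairscale's sharded optimizer reports the same total gradient norm as `torch.nn.utils.clip_grad_norm_`. Below is a CPU-only sketch of that equivalence, with the sharded optimizer's value replaced by a hand-computed global 2-norm; this is an illustration under that assumption, not the test's code.

```python
# CPU-only sketch of the grad-norm equivalence the new `clip_grad_norm` branch asserts.
# The reference norm is recomputed by hand instead of coming from a sharded optimizer.
import torch
from torch.nn import Linear, Sequential

model = Sequential(Linear(2, 3), Linear(3, 3))
model(torch.rand(8, 2)).sum().backward()

# Global 2-norm over every parameter gradient, computed manually.
reference_norm = torch.norm(
    torch.stack([p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]), 2
)

# clip_grad_norm_ scales gradients in place and returns the pre-clipping total norm.
reported_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.3, norm_type=2.0)

assert torch.allclose(reported_norm, reference_norm, atol=1e-6), (reported_norm, reference_norm)
```

In the test itself the reference comes from `sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)`, and the tolerance is loosened to 1e-2 under AMP, as the diff above shows.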