"src/diffusers/schedulers/scheduling_ddim_parallel.py" did not exist on "a0520193e15951655ee2c08c24bfdca716f6f64c"
Unverified commit b191fe5f authored by Benjamin Lefaudeux, committed by GitHub

[SDP] Adding a unit test which checks for multiple FW passes on the same block (#596)

* Adding a unit test which checks for multiple FW passes on the same block
* Adding an embedding table, though it still does not surface a problem
parent e9693976
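The change below exercises the case where the same submodule runs more than one forward pass per training step. As a standalone illustration of that pattern (plain PyTorch only, no ShardedDataParallel wrapper; the class name TwoPassMLP is made up here, and EMB_SIZE / BATCH_SIZE simply mirror the constants added in the diff), the following sketch reproduces the model shape added by this commit:

import torch
from torch.nn import Embedding, Linear, Sequential

EMB_SIZE = 32   # size of the embedding table, as in the test below
BATCH_SIZE = 8


class TwoPassMLP(torch.nn.Module):
    """Same shape as the test model: embedding -> trunk -> head, with head reused twice."""

    def __init__(self, multiple_fw: bool = True) -> None:
        super().__init__()
        self.embedding = Embedding(EMB_SIZE, 2)
        self.trunk = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3))
        self.head = Sequential(Linear(3, 3), Linear(3, 3))
        self.multiple_fw = multiple_fw

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        x = self.trunk(self.embedding(indices))
        if self.multiple_fw:
            # The parameters of `head` appear twice in the autograd graph,
            # so their gradient hooks fire twice per backward pass.
            return self.head(self.head(x))
        return self.head(x)


# Valid long indices into the embedding table, same recipe as _get_random_inputs() below.
indices = torch.floor(torch.rand((BATCH_SIZE, 2)) * EMB_SIZE).to(dtype=torch.long)
TwoPassMLP()(indices).sum().backward()  # both head passes contribute to head's gradients

The parity test below checks that fairscale's ShardedDataParallel still matches DDP when the gradients of `head` accumulate contributions from two forward passes.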
@@ -38,9 +38,33 @@ _test_amp = [False]
 if hasattr(torch.cuda.amp, "autocast"):
     _test_amp.append(True)
 
+EMB_SIZE = 32
+BATCH_SIZE = 8
+
 
-def _get_mlp():
-    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+def _get_mlp_emb(multiple_fw: bool = False):
+    class MLP(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+            self.trunk = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3))
+            self.head = Sequential(Linear(3, 3), Linear(3, 3))
+            self.multiple_fw = multiple_fw
+            self.embedding = torch.nn.Embedding(EMB_SIZE, 2)
+
+        def forward(self, indices: torch.Tensor) -> torch.Tensor:  # type: ignore
+            inputs = self.embedding(indices)
+            inputs = self.trunk(inputs)  # type: ignore
+
+            if self.multiple_fw:
+                return self.head(self.head(inputs))  # type: ignore
+
+            return self.head(inputs)  # type: ignore
+
+    return MLP()
+
+
+def _get_random_inputs(device):
+    return torch.floor(torch.rand((BATCH_SIZE, 2)) * EMB_SIZE).to(dtype=torch.long, device=device)
 
 
 def run_ddp_parity(
@@ -55,6 +79,7 @@ def run_ddp_parity(
     clip_grad_norm,
     amp,
     manual_reduction,
+    multiple_fw,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
@@ -63,7 +88,6 @@ def run_ddp_parity(
     torch.manual_seed(rank)
     np.random.seed(rank)
     NUMBER_BATCHS = 5
-    BATCH_SIZE = 8
 
     # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
     print(
@@ -71,7 +95,8 @@ def run_ddp_parity(
         + f" - change train graph {change_train_graph}"
         + f" - amp {amp}"
         + f" - manual reduction {manual_reduction}"
-        + f" - buffers {reduce_buffer_size}",
+        + f" - buffers {reduce_buffer_size}"
+        + f" - multiple FW {multiple_fw}",
         flush=True,
     )
@@ -103,7 +128,7 @@ def run_ddp_parity(
             model.reduce()
 
     # Any model works. Add one different buffer per rank
-    model = _get_mlp()
+    model = _get_mlp_emb(multiple_fw)
     model.register_buffer("test_buffer", torch.ones((1)) * rank)
     model.to(device)
@@ -137,7 +162,7 @@ def run_ddp_parity(
     # Typical training loop, check that we get the exact same results as DDP
     for i in range(NUMBER_BATCHS):
-        input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
+        input_tensor = _get_random_inputs(device)
 
         def ddp_closure(input_tensor=input_tensor):
             return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
@@ -209,8 +234,16 @@ def run_ddp_parity(
 @pytest.mark.parametrize("clip_grad_norm", [True, False])
 @pytest.mark.parametrize("amp", _test_amp)
 @pytest.mark.parametrize("manual_reduction", [True, False])
+@pytest.mark.parametrize("multiple_fw", [True, False])
 def test_ddp_parity(
-    reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp, manual_reduction
+    reduce_buffer_size,
+    grad_accumulation,
+    change_train_graph,
+    fp16_reduction,
+    clip_grad_norm,
+    amp,
+    manual_reduction,
+    multiple_fw,
 ):
     if manual_reduction and change_train_graph:
         pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")
@@ -230,6 +263,7 @@ def test_ddp_parity(
             clip_grad_norm,
             amp,
             manual_reduction,
+            multiple_fw,
         ),
         nprocs=world_size,
         join=True,
@@ -245,7 +279,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
     BATCHS = 20
 
-    model = _get_mlp()
+    model = _get_mlp_emb()
     model.register_buffer("test_buffer", torch.ones((1)) * rank)
     model.to(device)
     n_half_params = len(list(model.parameters())) // 2
@@ -273,7 +307,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
     )
 
     for i in range(BATCHS):
-        input_tensor = torch.rand((64, 2)).to(device)
+        input_tensor = _get_random_inputs(device)
 
         # Run DDP
         ddp_optimizer.zero_grad()
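For context, here is a minimal single-process sketch of the parity idea the test runs across ranks: step the same initial weights once with plain DDP and once with fairscale's OSS + ShardedDataParallel, then compare parameters. Everything below (world size 1, gloo on CPU, SGD, the toy two-layer model, learning rate, port) is an illustrative assumption and not the test's actual configuration, which lives in run_ddp_parity above.

import copy
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")  # arbitrary free port, assumption
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Toy model standing in for _get_mlp_emb(); both wrappers start from the same weights.
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3))
ddp_model = DDP(copy.deepcopy(model))
ddp_optim = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)

oss_optim = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
sdp_model = ShardedDataParallel(model, oss_optim)

for _ in range(5):
    batch = torch.rand(8, 2)  # the same batch is fed to both wrappers
    for wrapped, opt in ((ddp_model, ddp_optim), (sdp_model, oss_optim)):
        opt.zero_grad()
        wrapped(batch).abs().sum().backward()
        opt.step()

# Same data and same initial weights, so the two training paths must agree.
for p_ddp, p_sdp in zip(ddp_model.module.parameters(), model.parameters()):
    assert torch.allclose(p_ddp, p_sdp)

dist.destroy_process_group()

The real test additionally sweeps AMP, gradient accumulation, manual reduction, buffer sizes and, with this commit, the multiple-forward case.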