Unverified commit b191fe5f, authored by Benjamin Lefaudeux and committed by GitHub

[SDP] Adding a unit test which checks for multiple FW passes on the same block (#596)

* Adding a unit test which checks for multiple FW passes on the same block
* Adding an embedding table, but this still does not surface a problem
parent e9693976
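Note on the change (sketch, not part of the commit): the pattern exercised by the new test is a block whose parameters are used in more than one forward pass before a single backward pass. A minimal illustration of that pattern, assuming plain torch.nn modules (class name and shapes here are illustrative only):

    import torch
    from torch.nn import Linear, Sequential

    class TwoPassHead(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.head = Sequential(Linear(3, 3), Linear(3, 3))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The same `head` parameters are used twice per iteration, so one
            # backward pass accumulates gradient contributions from both calls.
            return self.head(self.head(x))

The parity test below checks that ShardedDataParallel still matches DDP results under this kind of parameter reuse.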
@@ -38,9 +38,33 @@ _test_amp = [False]
if hasattr(torch.cuda.amp, "autocast"):
    _test_amp.append(True)

+EMB_SIZE = 32
+BATCH_SIZE = 8

def _get_mlp():
    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))

+def _get_mlp_emb(multiple_fw: bool = False):
+    class MLP(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.trunk = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3))
+            self.head = Sequential(Linear(3, 3), Linear(3, 3))
+            self.multiple_fw = multiple_fw
+            self.embedding = torch.nn.Embedding(EMB_SIZE, 2)
+
+        def forward(self, indices: torch.Tensor) -> torch.Tensor:  # type: ignore
+            inputs = self.embedding(indices)
+            inputs = self.trunk(inputs)  # type: ignore
+            if self.multiple_fw:
+                return self.head(self.head(inputs))  # type: ignore
+            return self.head(inputs)  # type: ignore
+
+    return MLP()
+
+def _get_random_inputs(device):
+    return torch.floor(torch.rand((BATCH_SIZE, 2)) * EMB_SIZE).to(dtype=torch.long, device=device)
def run_ddp_parity(
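Aside (not part of the diff): a standalone sketch of how the two new helpers fit together, assuming a single device and a dummy scalar loss:

    # Sketch only: exercising _get_mlp_emb / _get_random_inputs outside the distributed test.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _get_mlp_emb(multiple_fw=True).to(device)
    indices = _get_random_inputs(device)  # LongTensor of shape (BATCH_SIZE, 2), values in [0, EMB_SIZE)
    output = model(indices)               # embedding lookup -> trunk -> head applied twice
    output.sum().backward()               # head parameters accumulate gradients from both passes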
@@ -55,6 +79,7 @@ def run_ddp_parity(
    clip_grad_norm,
    amp,
    manual_reduction,
+    multiple_fw,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
@@ -63,7 +88,6 @@
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5
-    BATCH_SIZE = 8
    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
@@ -71,7 +95,8 @@
        + f" - change train graph {change_train_graph}"
        + f" - amp {amp}"
        + f" - manual reduction {manual_reduction}"
-        + f" - buffers {reduce_buffer_size}",
+        + f" - buffers {reduce_buffer_size}"
+        + f" - multiple FW {multiple_fw}",
        flush=True,
    )
@@ -103,7 +128,7 @@ def run_ddp_parity(
            model.reduce()

        # Any model works. Add one different buffer per rank
-        model = _get_mlp()
+        model = _get_mlp_emb(multiple_fw)
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)
@@ -137,7 +162,7 @@ def run_ddp_parity(
        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHS):
-            input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
+            input_tensor = _get_random_inputs(device)

            def ddp_closure(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
@@ -209,8 +234,16 @@ def run_ddp_parity(
@pytest.mark.parametrize("clip_grad_norm", [True, False])
@pytest.mark.parametrize("amp", _test_amp)
@pytest.mark.parametrize("manual_reduction", [True, False])
+@pytest.mark.parametrize("multiple_fw", [True, False])
def test_ddp_parity(
-    reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp, manual_reduction
+    reduce_buffer_size,
+    grad_accumulation,
+    change_train_graph,
+    fp16_reduction,
+    clip_grad_norm,
+    amp,
+    manual_reduction,
+    multiple_fw,
):
    if manual_reduction and change_train_graph:
        pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")
@@ -230,6 +263,7 @@ def test_ddp_parity(
            clip_grad_norm,
            amp,
            manual_reduction,
+            multiple_fw,
        ),
        nprocs=world_size,
        join=True,
@@ -245,7 +279,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
    BATCHS = 20

-    model = _get_mlp()
+    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
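Aside (not part of the diff): the two-optimizer variant presumably splits the parameter list in half and hands each half to its own optimizer, which is what `n_half_params` is computed for. A minimal sketch of that split, with plain SGD standing in for the optimizers actually used by the test:

    # Sketch only: partitioning a model's parameters across two optimizers.
    params = list(model.parameters())
    n_half_params = len(params) // 2
    optim_1 = torch.optim.SGD(params[:n_half_params], lr=1e-3)
    optim_2 = torch.optim.SGD(params[n_half_params:], lr=1e-3)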
@@ -273,7 +307,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
    )

    for i in range(BATCHS):
-        input_tensor = torch.rand((64, 2)).to(device)
+        input_tensor = _get_random_inputs(device)

        # Run DDP
        ddp_optimizer.zero_grad()