Unverified Commit 5c3ff9bd authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedDDP : Adding a proper DDP parity / AMP unit test, overdue (#361)

* Add a proper DDP parity / AMP unit test, overdue
* Catch PyTorch versions without AMP support
parent e3a20fef
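
The hunks below only touch the body of run_ddp_parity; the pytest entry point that spawns it across ranks is outside the displayed diff. For orientation, here is a minimal, hypothetical launcher sketch — only the run_ddp_parity(rank, world_size, backend, temp_file_name) signature is taken from the hunk header, the rest is assumed wiring:

    # Hypothetical launcher sketch: not part of this commit. Only the
    # run_ddp_parity(rank, world_size, backend, temp_file_name) signature
    # comes from the diff below; run_ddp_parity lives in the same test module.
    import tempfile

    import torch
    import torch.multiprocessing as mp

    def test_ddp_parity():
        world_size = 2
        backend = "nccl" if torch.cuda.is_available() else "gloo"  # assumption: the real test may parametrize this
        temp_file_name = tempfile.mkstemp()[1]  # rendezvous file, passed as init_method="file://..."
        # mp.spawn invokes run_ddp_parity(rank, world_size, backend, temp_file_name) on each rank
        mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)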
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 """
-Testing OssDdp class.
+Testing ShardedDDP
 """
 from contextlib import suppress
@@ -14,6 +14,7 @@ from typing import List
 import numpy as np
 import torch
+from torch.cuda.amp import GradScaler as TorchGradScaler
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
@@ -21,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
+from fairscale.optim.grad_scaler import ShardedGradScaler
 from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
@@ -132,52 +134,90 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
     torch.manual_seed(rank)
     np.random.seed(rank)
 
-    # Any model works. Add one different buffer per rank
-    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
-    model.register_buffer("test_buffer", torch.ones((1)) * rank)
-    model.to(device)
-
-    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
-
-    ddp_model_single = copy.deepcopy(model)
-    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
-    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
-
-    def check_same_model_params():
-        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
-            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
-                assert torch.allclose(
-                    p, ddp_p, atol=1e-3
-                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
-
-        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
-            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
+    def check_parity(amp: bool):
+        # Any model works. Add one different buffer per rank
+        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+        model.register_buffer("test_buffer", torch.ones((1)) * rank)
+        model.to(device)
 
-    # The model should be synchronized in between the ranks at construction time, check that
-    check_same_model_params()
+        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+        sharded_ddp_model = ShardedDataParallel(
+            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
+        )
 
-    # The models should stay the same in between the ranks
-    for i in range(20):
-        input_tensor = torch.rand((64, 2)).to(device)
+        ddp_model_single = copy.deepcopy(model)
+        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
+        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
 
-        def closure_ddp(input_tensor=input_tensor):
-            ddp_optimizer.zero_grad()
-            ddp_loss = ddp_model(input_tensor).abs().sum()
-            ddp_loss.backward()
-            return ddp_loss
+        ddp_scaler = TorchGradScaler() if amp else None
+        sharded_ddp_scaler = ShardedGradScaler() if amp else None
 
-        def closure_sharded(input_tensor=input_tensor):
-            sharded_optimizer.zero_grad()
-            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
-            sharded_loss.backward()
-            return sharded_loss
+        def check_same_model_params():
+            for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
+                for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
+                    assert torch.allclose(
+                        p, ddp_p, atol=1e-3
+                    ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
 
-        _ = ddp_optimizer.step(closure=closure_ddp)
-        _ = sharded_optimizer.step(closure=closure_sharded)
+            for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
+                assert torch.allclose(
+                    b, ddp_b, atol=1e-3
+                ), f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"
 
-        check_same_model_params()
+        # The model should be synchronized in between the ranks at construction time, check that
+        check_same_model_params()
+
+        # The models should stay the same in between the ranks
+        for i in range(10):
+            input_tensor = torch.rand((64, 2)).to(device)
+
+            def closure_ddp(input_tensor=input_tensor):
+                ddp_optimizer.zero_grad()
+
+                if ddp_scaler is not None:
+                    with torch.cuda.amp.autocast():
+                        ddp_loss = ddp_model(input_tensor).abs().sum()
+                    ddp_scaler.scale(ddp_loss).backward()
+                else:
+                    ddp_loss = ddp_model(input_tensor).abs().sum()
+                    ddp_loss.backward()
+                return ddp_loss
+
+            def closure_sharded(input_tensor=input_tensor):
+                sharded_optimizer.zero_grad()
+
+                if sharded_ddp_scaler is not None:
+                    with torch.cuda.amp.autocast():
+                        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+                    sharded_ddp_scaler.scale(sharded_loss).backward()
+                else:
+                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+                    sharded_loss.backward()
+                return sharded_loss
+
+            # Step/scale both
+            if ddp_scaler is not None:
+                _ = closure_ddp(input_tensor)
+                ddp_scaler.step(ddp_optimizer)
+                ddp_scaler.update()
+            else:
+                ddp_optimizer.step(closure=closure_ddp)
+
+            if sharded_ddp_scaler is not None:
+                _ = closure_sharded(input_tensor)
+                sharded_ddp_scaler.step(sharded_optimizer)
+                sharded_ddp_scaler.update()
+            else:
+                sharded_optimizer.step(closure=closure_sharded)
+
+            check_same_model_params()
+
+    check_parity(amp=False)
+
+    # Catch a version of pytorch which would not support AMP
+    if hasattr(torch.cuda.amp, "autocast"):
+        check_parity(amp=True)
 
     dist.destroy_process_group()
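
Stripped of the DDP comparison and the assertions, the sharded half that this test exercises reduces to the following user-side pattern — a sketch assuming a process group is already initialized, with `model` and `inputs` as placeholders and the hyperparameters copied from the test:

    # Sketch of the ShardedDDP + AMP step validated above; `model` and `inputs`
    # are placeholders, and dist.init_process_group() is assumed to have run.
    import torch
    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(module=model, sharded_optimizer=optimizer, broadcast_buffers=True)
    scaler = ShardedGradScaler()  # stands in for torch.cuda.amp.GradScaler when gradient state is sharded

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(inputs).abs().sum()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()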