Unverified commit 5c3ff9bd, authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedDDP : Adding a proper DDP parity / AMP unit test, overdue (#361)

* Add a proper DDP parity / AMP unit test (overdue)
* Catch PyTorch versions without AMP support
parent e3a20fef
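
For reference, a test like this runs once per rank: each spawned process joins the same process group and then executes the parity checks. The sketch below shows that wiring under stated assumptions: the run_ddp_parity signature matches the hunk header in the diff, but the wrapper name test_ddp_parity, the world size, the NCCL backend and the file-based rendezvous are illustrative guesses, not part of this commit.

    # Hypothetical launcher sketch -- not part of this commit. It only illustrates how a
    # per-rank function with the signature run_ddp_parity(rank, world_size, backend, temp_file_name)
    # is usually driven from a test entry point.
    import tempfile

    import torch.distributed as dist
    import torch.multiprocessing as mp


    def run_ddp_parity(rank, world_size, backend, temp_file_name):
        # Join the group, run the parity checks (body elided, see the diff below), then tear down.
        dist.init_process_group(
            init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
        )
        ...
        dist.destroy_process_group()


    def test_ddp_parity():  # assumed wrapper name
        world_size = 2
        temp_file_name = tempfile.mkstemp()[1]
        mp.spawn(run_ddp_parity, args=(world_size, "nccl", temp_file_name), nprocs=world_size, join=True)
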
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-Testing OssDdp class.
+Testing ShardedDDP
 """
 
 from contextlib import suppress
@@ -14,6 +14,7 @@ from typing import List
 
 import numpy as np
 import torch
+from torch.cuda.amp import GradScaler as TorchGradScaler
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
@@ -21,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
+from fairscale.optim.grad_scaler import ShardedGradScaler
 from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
@@ -132,52 +134,90 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
     torch.manual_seed(rank)
     np.random.seed(rank)
 
-    # Any model works. Add one different buffer per rank
-    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
-    model.register_buffer("test_buffer", torch.ones((1)) * rank)
-    model.to(device)
-
-    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
-
-    ddp_model_single = copy.deepcopy(model)
-    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
-    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
-
-    def check_same_model_params():
-        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
-            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
-                assert torch.allclose(
-                    p, ddp_p, atol=1e-3
-                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
-        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
-            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
-
-    # The model should be synchronized in between the ranks at construction time, check that
-    check_same_model_params()
-
-    # The models should stay the same in between the ranks
-    for i in range(20):
-        input_tensor = torch.rand((64, 2)).to(device)
-
-        def closure_ddp(input_tensor=input_tensor):
-            ddp_optimizer.zero_grad()
-            ddp_loss = ddp_model(input_tensor).abs().sum()
-            ddp_loss.backward()
-            return ddp_loss
-
-        def closure_sharded(input_tensor=input_tensor):
-            sharded_optimizer.zero_grad()
-            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
-            sharded_loss.backward()
-            return sharded_loss
-
-        _ = ddp_optimizer.step(closure=closure_ddp)
-        _ = sharded_optimizer.step(closure=closure_sharded)
-
-        check_same_model_params()
+    def check_parity(amp: bool):
+        # Any model works. Add one different buffer per rank
+        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+        model.register_buffer("test_buffer", torch.ones((1)) * rank)
+        model.to(device)
+
+        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+        sharded_ddp_model = ShardedDataParallel(
+            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
+        )
+
+        ddp_model_single = copy.deepcopy(model)
+        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
+        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
+
+        ddp_scaler = TorchGradScaler() if amp else None
+        sharded_ddp_scaler = ShardedGradScaler() if amp else None
+
+        def check_same_model_params():
+            for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
+                for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
+                    assert torch.allclose(
+                        p, ddp_p, atol=1e-3
+                    ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
+            for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
+                assert torch.allclose(
+                    b, ddp_b, atol=1e-3
+                ), f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"
+
+        # The model should be synchronized in between the ranks at construction time, check that
+        check_same_model_params()
+
+        # The models should stay the same in between the ranks
+        for i in range(10):
+            input_tensor = torch.rand((64, 2)).to(device)
+
+            def closure_ddp(input_tensor=input_tensor):
+                ddp_optimizer.zero_grad()
+
+                if ddp_scaler is not None:
+                    with torch.cuda.amp.autocast():
+                        ddp_loss = ddp_model(input_tensor).abs().sum()
+                    ddp_scaler.scale(ddp_loss).backward()
+                else:
+                    ddp_loss = ddp_model(input_tensor).abs().sum()
+                    ddp_loss.backward()
+                return ddp_loss
+
+            def closure_sharded(input_tensor=input_tensor):
+                sharded_optimizer.zero_grad()
+
+                if sharded_ddp_scaler is not None:
+                    with torch.cuda.amp.autocast():
+                        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+                    sharded_ddp_scaler.scale(sharded_loss).backward()
+                else:
+                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+                    sharded_loss.backward()
+                return sharded_loss
+
+            # Step/scale both
+            if ddp_scaler is not None:
+                _ = closure_ddp(input_tensor)
+                ddp_scaler.step(ddp_optimizer)
+                ddp_scaler.update()
+            else:
+                ddp_optimizer.step(closure=closure_ddp)
+
+            if sharded_ddp_scaler is not None:
+                _ = closure_sharded(input_tensor)
+                sharded_ddp_scaler.step(sharded_optimizer)
+                sharded_ddp_scaler.update()
+            else:
+                sharded_optimizer.step(closure=closure_sharded)
+
+            check_same_model_params()
+
+    check_parity(amp=False)
+
+    # Catch a version of pytorch which would not support AMP
+    if hasattr(torch.cuda.amp, "autocast"):
+        check_parity(amp=True)
 
     dist.destroy_process_group()
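
For context, the AMP branch that the new test exercises reduces to the usual scale / step / update sequence, with ShardedGradScaler standing in for the stock torch.cuda.amp.GradScaler so that the sharded OSS optimizer state is handled correctly. The sketch below isolates that per-iteration pattern; it assumes the sharded_ddp_model, sharded_optimizer and scaler are already built as in the test (ShardedDataParallel around the module, OSS as the optimizer, an initialized process group), and the loss is just a placeholder.

    # Minimal per-iteration AMP sketch, mirroring the closure_sharded path in the diff above.
    # The surrounding setup (process group, ShardedDataParallel, OSS, ShardedGradScaler)
    # is assumed to exist already; the loss below is a placeholder.
    import torch


    def sharded_amp_step(sharded_ddp_model, sharded_optimizer, scaler, batch):
        sharded_optimizer.zero_grad()
        # Forward under autocast, then scale the loss before backward so the
        # gradients carry the loss scale.
        with torch.cuda.amp.autocast():
            loss = sharded_ddp_model(batch).abs().sum()
        scaler.scale(loss).backward()
        # step() unscales the (sharded) gradients and skips the update on overflow;
        # update() then adjusts the scale factor for the next iteration.
        scaler.step(sharded_optimizer)
        scaler.update()
        return loss
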