Unverified commit 138b2033 authored by Benjamin Lefaudeux, committed by GitHub

[fix] Check ShardedDDP / DDP parity + bugfix (#242)

* unit test checking DDP and ShardedDDP equivalence, reproducing the issue that Sean spotted
* fixing the issue: the requests in flight were not counted properly
* adding a multiple optimizers case
parent 6afbe677
@@ -74,7 +74,7 @@ class OSS(Optimizer):
         broadcast_buffer_size: int = 2 ** 17,
         **default: Any,
     ):
-        logging.warning("Disabling bucketing for now, error prone for some models")
+        # logging.warning("Disabling bucketing for now, error prone for some models")
         broadcast_buffer_size = 0
         # Hold all the model params in the root .param_groups
@@ -105,8 +105,9 @@ class OSS(Optimizer):
         self._all_states: List[Dict[str, Any]] = []
         # Current default device is set by the parameters allocated to this rank
-        self._device = self.partition_parameters()[self.rank][0]["params"][0].device
+        self._device = list(self.per_device_params.keys())[0]
         self.buckets: Dict[torch.device, List[Bucket]] = {}
+        self.bucket_size = broadcast_buffer_size
         for device, per_device in self.per_device_params.items():
             # Allocate one buffer per rank and per device to group the small parameters
             self.buckets[device] = [
@@ -597,11 +598,11 @@ class OSS(Optimizer):
             for dst_rank, params in enumerate(per_rank_params):
                 offset = 0
-                for param in params:
+                # Only consider the params which will require a gradient
+                for param in filter(lambda p: p.requires_grad, params):
                     # Criteria to decide whether this parameter is to be bucketed or not:
                     # - enough room in the bucket
-                    # - param not the first one in the DAG, because this may be kicked out of autograd (depending on inputs)
-                    if (offset + param.numel()) < self.buckets[device][dst_rank].max_size and param.is_leaf:
+                    if (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
                         self.should_bucket_param[param] = True
                         offset += param.numel()
                     else:
@@ -613,5 +614,9 @@ class OSS(Optimizer):
             self.buckets[device][dst_rank].global_rank = self.global_rank
         # Determine the max work handles in flight:
-        # - all the direct reduce/broadcast + 1 bucket
-        self._max_work_handles = sum(not value for value in self.should_bucket_param.values()) + 1
+        # - all the direct reduce/broadcast
+        self._max_work_handles = sum(not value for value in self.should_bucket_param.values())
+        # - if we're bucketing, this means more work handles: one per rank and per device
+        if self.bucket_size > 0:
+            self._max_work_handles += len(self.per_device_params.keys()) * self.world_size
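To make the work-handle accounting in the hunk above concrete (this is the "requests in flight were not counted properly" bug from the commit message), here is a minimal standalone sketch. It is not fairscale code: the function name `max_work_handles` and the string-keyed dict are illustrative only, but the rule mirrors the change above: one async handle per non-bucketed ("direct") parameter, plus one bucket handle per device and per rank when bucketing is enabled.

```python
# Illustrative sketch, not fairscale's API: how the max number of in-flight
# work handles follows from the bucketing decision.
from typing import Dict


def max_work_handles(should_bucket: Dict[str, bool], n_devices: int, world_size: int, bucket_size: int) -> int:
    # One async reduce/broadcast handle per "direct" (non-bucketed) parameter
    handles = sum(not bucketed for bucketed in should_bucket.values())
    # With bucketing enabled, add one bucket handle per device and per rank
    if bucket_size > 0:
        handles += n_devices * world_size
    return handles


# Toy numbers: 10 direct params, 6 bucketed ones, 1 device, 4 ranks, bucketing on
should_bucket = {f"p{i}": i >= 10 for i in range(16)}
assert max_work_handles(should_bucket, n_devices=1, world_size=4, bucket_size=2 ** 17) == 10 + 1 * 4
```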
@@ -7,6 +7,7 @@
 Testing OssDdp class.
 """
 import copy
+import tempfile
 from typing import List
@@ -16,6 +17,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
+from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
@@ -27,16 +29,6 @@ from contextlib import suppress
 from fairscale.utils.testing import GPT2
-def test_step_on_cpu():
-    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
-@skip_if_no_cuda
-@skip_if_single_gpu
-def test_step_on_gpu():
-    run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
 def run_one_step(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
@@ -123,6 +115,153 @@ def run_test(backend, device, world_size=2):
     mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+
+
+def test_step_on_cpu():
+    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_step_on_gpu():
+    run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
+
+
+def run_ddp_parity(rank, world_size, backend, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    device = torch.device("cuda")
+    torch.cuda.set_device(rank)
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Any model works. Add one different buffer per rank
+    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
+
+    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
+
+    ddp_model_single = copy.deepcopy(model)
+    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
+    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
+
+    def check_same_model_params():
+        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
+            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
+                assert torch.allclose(
+                    p, ddp_p, atol=1e-3
+                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
+
+        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
+            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
+
+    # The model should be synchronized in between the ranks at construction time, check that
+    check_same_model_params()
+
+    # The models should stay the same in between the ranks
+    for i in range(20):
+        input_tensor = torch.rand((64, 2)).to(device)
+
+        def closure_ddp(input_tensor=input_tensor):
+            ddp_optimizer.zero_grad()
+            ddp_loss = ddp_model(input_tensor).abs().sum()
+            ddp_loss.backward()
+            return ddp_loss
+
+        def closure_sharded(input_tensor=input_tensor):
+            sharded_optimizer.zero_grad()
+            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+            sharded_loss.backward()
+            return sharded_loss
+
+        _ = ddp_optimizer.step(closure=closure_ddp)
+        _ = sharded_optimizer.step(closure=closure_sharded)
+
+        check_same_model_params()
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_ddp_parity():
+    temp_file_name = tempfile.mkstemp()[1]
+    world_size = torch.cuda.device_count()
+    backend = dist.Backend.NCCL
+    mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
+
+
+def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    device = torch.device("cuda")
+    torch.cuda.set_device(rank)
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Any model works. Add one different buffer per rank
+    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
+    n_half_params = len(list(model.parameters())) // 2
+
+    sharded_optimizer = OSS(
+        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
+    )
+    sharded_optimizer_2 = OSS(
+        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
+    )
+
+    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
+
+    ddp_model_single = copy.deepcopy(model)
+    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
+    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
+    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
+
+    def check_same_model_params():
+        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
+            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
+                assert torch.allclose(
+                    p, ddp_p, atol=1e-3
+                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
+
+        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
+            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
+
+    check_same_model_params()
+
+    # The models should stay the same in between the ranks
+    for i in range(20):
+        input_tensor = torch.rand((64, 2)).to(device)
+
+        # Run DDP
+        ddp_optimizer.zero_grad()
+        ddp_optimizer_2.zero_grad()
+        ddp_loss = ddp_model(input_tensor).abs().sum()
+        ddp_loss.backward()
+        ddp_optimizer.step()
+        ddp_optimizer_2.step()
+
+        # Run Sharded
+        sharded_optimizer.zero_grad()
+        sharded_optimizer_2.zero_grad()
+        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+        sharded_loss.backward()
+        sharded_optimizer.step()
+        sharded_optimizer_2.step()
+
+        check_same_model_params()
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_ddp_parity_two_optim():
+    temp_file_name = tempfile.mkstemp()[1]
+    world_size = 2
+    backend = dist.Backend.NCCL
+    mp.spawn(run_ddp_parity_two_optim, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
 def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
......
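A side note on the closure pattern used in `run_ddp_parity` above: both `torch.optim.Optimizer.step` and `OSS.step` accept an optional closure that re-evaluates the loss, which is why the test can drive the plain SGD optimizer and the sharded one through the same call shape. Below is a minimal single-process sketch of that pattern, purely illustrative (plain SGD only, no process group, no ShardedDDP):

```python
import torch
from torch.nn import Linear

# Single-process illustration of step(closure), mirroring closure_ddp / closure_sharded above
model = Linear(2, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
input_tensor = torch.rand(64, 2)


def closure():
    # zero_grad + forward + backward, then hand the loss back to the optimizer
    optimizer.zero_grad()
    loss = model(input_tensor).abs().sum()
    loss.backward()
    return loss


loss = optimizer.step(closure)  # the optimizer calls the closure once, then updates the params
print(float(loss))
```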