Unverified commit 138b2033, authored by Benjamin Lefaudeux, committed by GitHub

[fix] Check ShardedDDP / DDP parity + bugfix (#242)

* Add a unit test checking DDP and ShardedDDP equivalence, reproducing the issue that Sean spotted
* Fix the issue: the number of requests in flight was not counted properly (sketched just below)
* Add a multiple-optimizers case
parent 6afbe677
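For context on the "requests in flight" wording above: ShardedDDP issues its reductions and broadcasts as asynchronous collectives and later waits on the returned work handles, so it must predict exactly how many handles will be produced. The sketch below is illustrative only, not fairscale code; the function and tensor names are made up, and an already-initialized process group is assumed.

import torch
import torch.distributed as dist


def flush_async_broadcasts(tensors, src_rank):
    # Issue one asynchronous broadcast per tensor; each call returns a work handle
    # immediately instead of blocking.
    handles = [dist.broadcast(t, src=src_rank, async_op=True) for t in tensors]

    # Every handle must be waited on before the tensors are consumed. If the expected
    # number of handles is under-counted (the bug fixed in this commit), some
    # collectives are still in flight when their results are read.
    for handle in handles:
        handle.wait()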
@@ -74,7 +74,7 @@ class OSS(Optimizer):
         broadcast_buffer_size: int = 2 ** 17,
         **default: Any,
     ):
-        logging.warning("Disabling bucketing for now, error prone for some models")
+        # logging.warning("Disabling bucketing for now, error prone for some models")
         broadcast_buffer_size = 0
 
         # Hold all the model params in the root .param_groups
@@ -105,8 +105,9 @@ class OSS(Optimizer):
         self._all_states: List[Dict[str, Any]] = []
         # Current default device is set by the parameters allocated to this rank
-        self._device = self.partition_parameters()[self.rank][0]["params"][0].device
+        self._device = list(self.per_device_params.keys())[0]
         self.buckets: Dict[torch.device, List[Bucket]] = {}
+        self.bucket_size = broadcast_buffer_size
 
         for device, per_device in self.per_device_params.items():
             # Allocate one buffer per rank and per device to group the small parameters
             self.buckets[device] = [
@@ -597,11 +598,11 @@ class OSS(Optimizer):
             for dst_rank, params in enumerate(per_rank_params):
                 offset = 0
 
-                for param in params:
+                # Only consider the params which will require a gradient
+                for param in filter(lambda p: p.requires_grad, params):
                     # Criteria to decide whether this parameter is to be bucketed or not:
                     # - enough room in the bucket
-                    # - param not the first one in the DAG, because this may be kicked out of autograd (depending on inputs)
-                    if (offset + param.numel()) < self.buckets[device][dst_rank].max_size and param.is_leaf:
+                    if (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
                         self.should_bucket_param[param] = True
                         offset += param.numel()
                     else:
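To make the criterion in the hunk above concrete, here is a small, self-contained illustration (not fairscale code; the parameter sizes are invented) of how a rank's parameters are split between the bucket and direct per-parameter collectives, and why parameters that do not require a gradient are skipped: they never produce a reduce request, so counting them would inflate the expected work.

def split_bucketed_vs_direct(param_sizes, max_size):
    # param_sizes: list of (numel, requires_grad) tuples, mimicking one rank's shard
    bucketed, direct = [], []
    offset = 0
    for numel, requires_grad in param_sizes:
        if not requires_grad:
            # Frozen parameters never get a gradient reduced, so they are ignored
            continue
        if offset + numel < max_size:
            bucketed.append(numel)
            offset += numel
        else:
            direct.append(numel)
    return bucketed, direct


# With a 2 ** 17 element budget, the small layer is bucketed and the large one
# falls back to a direct (per-parameter) collective.
print(split_bucketed_vs_direct([(1_000, True), (200, False), (300_000, True)], 2 ** 17))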
@@ -613,5 +614,9 @@ class OSS(Optimizer):
                 self.buckets[device][dst_rank].global_rank = self.global_rank
 
         # Determine the max work handles in flight:
-        # - all the direct reduce/broadcast + 1 bucket
-        self._max_work_handles = sum(not value for value in self.should_bucket_param.values()) + 1
+        # - all the direct reduce/broadcast
+        self._max_work_handles = sum(not value for value in self.should_bucket_param.values())
+
+        # - if we're bucketing, this means more work handles: one per rank and per device
+        if self.bucket_size > 0:
+            self._max_work_handles += len(self.per_device_params.keys()) * self.world_size
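A worked example of the corrected count, with made-up numbers: the old "+ 1" allowed for only one bucket handle in total, whereas one bucket is flushed per rank and per device.

# Hypothetical setup: 4 parameters reduced directly, 1 CUDA device, 8 ranks
n_direct = 4
n_devices = 1
world_size = 8

old_max_handles = n_direct + 1                       # previous accounting
new_max_handles = n_direct + n_devices * world_size  # one bucket per (device, rank)

print(old_max_handles, new_max_handles)  # 5 vs 12: the old bound under-counted the in-flight work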
@@ -7,6 +7,7 @@
 Testing OssDdp class.
 """
 
+import copy
 import tempfile
 from typing import List
@@ -16,6 +17,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
@@ -27,16 +29,6 @@ from contextlib import suppress
 from fairscale.utils.testing import GPT2
 
 
-def test_step_on_cpu():
-    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
-
-
-@skip_if_no_cuda
-@skip_if_single_gpu
-def test_step_on_gpu():
-    run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
-
-
 def run_one_step(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
@@ -123,6 +115,153 @@ def run_test(backend, device, world_size=2):
     mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
 
 
+def test_step_on_cpu():
+    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_step_on_gpu():
+    run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
+
+
+def run_ddp_parity(rank, world_size, backend, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    device = torch.device("cuda")
+    torch.cuda.set_device(rank)
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Any model works. Add one different buffer per rank
+    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
+
+    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
+
+    ddp_model_single = copy.deepcopy(model)
+    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
+    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
+
+    def check_same_model_params():
+        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
+            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
+                assert torch.allclose(
+                    p, ddp_p, atol=1e-3
+                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
+
+        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
+            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
+
+    # The model should be synchronized in between the ranks at construction time, check that
+    check_same_model_params()
+
+    # The models should stay the same in between the ranks
+    for i in range(20):
+        input_tensor = torch.rand((64, 2)).to(device)
+
+        def closure_ddp(input_tensor=input_tensor):
+            ddp_optimizer.zero_grad()
+            ddp_loss = ddp_model(input_tensor).abs().sum()
+            ddp_loss.backward()
+            return ddp_loss
+
+        def closure_sharded(input_tensor=input_tensor):
+            sharded_optimizer.zero_grad()
+            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+            sharded_loss.backward()
+            return sharded_loss
+
+        _ = ddp_optimizer.step(closure=closure_ddp)
+        _ = sharded_optimizer.step(closure=closure_sharded)
+
+        check_same_model_params()
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_ddp_parity():
+    temp_file_name = tempfile.mkstemp()[1]
+    world_size = torch.cuda.device_count()
+    backend = dist.Backend.NCCL
+    mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
+
+
+def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    device = torch.device("cuda")
+    torch.cuda.set_device(rank)
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Any model works. Add one different buffer per rank
+    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
+    n_half_params = len(list(model.parameters())) // 2
+
+    sharded_optimizer = OSS(
+        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
+    )
+    sharded_optimizer_2 = OSS(
+        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
+    )
+    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
+
+    ddp_model_single = copy.deepcopy(model)
+    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
+    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
+    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
+
+    def check_same_model_params():
+        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
+            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
+                assert torch.allclose(
+                    p, ddp_p, atol=1e-3
+                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
+
+        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
+            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
+
+    check_same_model_params()
+
+    # The models should stay the same in between the ranks
+    for i in range(20):
+        input_tensor = torch.rand((64, 2)).to(device)
+
+        # Run DDP
+        ddp_optimizer.zero_grad()
+        ddp_optimizer_2.zero_grad()
+        ddp_loss = ddp_model(input_tensor).abs().sum()
+        ddp_loss.backward()
+        ddp_optimizer.step()
+        ddp_optimizer_2.step()
+
+        # Run Sharded
+        sharded_optimizer.zero_grad()
+        sharded_optimizer_2.zero_grad()
+        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+        sharded_loss.backward()
+        sharded_optimizer.step()
+        sharded_optimizer_2.step()
+
+        check_same_model_params()
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_ddp_parity_two_optim():
+    temp_file_name = tempfile.mkstemp()[1]
+    world_size = 2
+    backend = dist.Backend.NCCL
+    mp.spawn(run_ddp_parity_two_optim, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
+
+
 def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
...
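To exercise only the new parity tests locally, something like the following should work, assuming at least two GPUs are available and that pytest collects this test file (the -k expression simply matches the test names added above):

python -m pytest -k "ddp_parity"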