Unverified commit c9fdf506 authored by Benjamin Lefaudeux, committed by GitHub

[fix][OSS] Adding a hard sync stream barrier before broadcast (#512)

* Adding a hard sync barrier before the broadcast, mostly useful for Gloo; NCCL is already synced behind the scenes
* Adding a proper unit test
* Adding a unit test for https://github.com/facebookresearch/fairscale/pull/510
parent 1fa778d7
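The core of the fix: when the model computes on CUDA but the broadcast backend does not share the compute stream (Gloo in particular), `OSS._broadcast_params` now issues a hard `torch.cuda.synchronize` per device before broadcasting, and only waits on the last work handle when the backend is NCCL, since all NCCL broadcasts land on the same stream. Below is a minimal sketch of that pattern, not the fairscale code itself; the `buckets` layout (one flat tensor per rank and device) and the `broadcast_buckets` helper name are assumptions made for illustration.

```python
# Minimal sketch of the sync-then-broadcast pattern introduced by this commit.
# Assumes dist.init_process_group() has already been called; `buckets` maps a
# torch.device to a list of flat tensors, one per rank (an approximation of the
# internal OSS bucket layout, not the real data structure).
from typing import Dict, List

import torch
import torch.distributed as dist


def broadcast_buckets(buckets: Dict[torch.device, List[torch.Tensor]], default_device: torch.device) -> None:
    backend = dist.get_backend()

    # Hard barrier: if compute ran on CUDA, make sure it has finished before the
    # broadcasts read the parameter buffers (mostly needed for Gloo, which does
    # not run on the CUDA compute stream).
    if default_device.type == "cuda":
        for device in buckets.keys():
            torch.cuda.synchronize(device=device)

    work_handles = [
        # src assumes the default group, where local and global ranks coincide
        dist.broadcast(tensor=bucket, src=src_rank, async_op=True)
        for per_rank in buckets.values()
        for src_rank, bucket in enumerate(per_rank)
    ]

    if work_handles and backend == dist.Backend.NCCL:
        # NCCL broadcasts are all enqueued on the same stream, so waiting on the
        # last handle is enough.
        work_handles[-1].wait()
    else:
        # Gloo handles complete independently: wait on each of them.
        for handle in work_handles:
            handle.wait()
```

Note that the barrier is a device-wide `torch.cuda.synchronize` rather than a stream-level event wait, hence the "hard" sync barrier wording in the commit title.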
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
### Fixed
- OSS: fix a compatibility problem with Lightning w.r.t. the optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
## [0.3.1] - 2021-03-09
### Added
@@ -94,6 +94,7 @@ class OSS(Optimizer):
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)
self.backend = dist.get_backend(self.group)
self.rank = dist.get_rank(self.group)
self.global_rank = self.get_global_rank(self.group, self.rank)
self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
@@ -546,6 +547,12 @@ class OSS(Optimizer):
def _broadcast_params(self) -> None:
"""Helper function to broadcast all the parameters from a given device"""
# If NCCL is used, broadcasts will be done in an independent stream;
# make sure that prior compute work is complete
if torch.device("cuda").type == self._default_device.type:
for device in self.per_device_params.keys():
torch.cuda.synchronize(device=device)
work_handles = [] # Work handles are consumed within this scope, no callback
for device in self.buckets.keys():
@@ -558,7 +565,10 @@
)
# Only check on the last handle, they're all inlined on the same CUDA stream
_ = list(filter(lambda x: x.wait(), work_handles))
if work_handles and self.backend == dist.Backend.NCCL:
work_handles[-1].wait()
else:
_ = list(filter(lambda x: x.wait(), work_handles))
def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
@@ -548,3 +548,27 @@ class DummyProcessGroup:
def size(self) -> int:
return self._size
class SGDWithPausingCompute(torch.optim.SGD):
def __init__(self, *args, **kwargs) -> None: # type: ignore
self.rank = kwargs["rank"]
del kwargs["rank"]
super().__init__(*args, **kwargs)
def step(self, closure: Optional[Any] = None) -> Any:
loss = super().step(closure=closure)
# This is used to make sure that OSS and ShardedDDP enforce a proper stream synchronization
# - Add a long cuda wait on a compute stream, non-blocking from the CPU perspective
with torch.cuda.stream(torch.cuda.Stream()):
torch.cuda._sleep(100000000)
# - optionally change the params on a per rank basis
with torch.no_grad():
for param_group in self.param_groups:
for param in param_group["params"]:
param *= 1.0 + self.rank / 10.0
return loss
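For context, here is a hedged example of how this pausing optimizer gets exercised: wrapped in OSS, the broadcast at the end of `OSS.step()` races against the sleep enqueued on the side stream unless the barrier added above is in place. The model size and hyperparameters below are illustrative only; the snippet assumes an initialized process group and an available GPU, and mirrors the call pattern used in `run_one_step` further down this diff.

```python
# Illustrative usage only (assumes dist.init_process_group() has run and CUDA
# is available); mirrors how run_one_step builds the sharded optimizer.
import torch
import torch.distributed as dist

from fairscale.optim import OSS
from fairscale.utils.testing import SGDWithPausingCompute

model = torch.nn.Linear(8, 8).cuda()
optimizer = OSS(
    params=model.parameters(),
    optim=SGDWithPausingCompute,
    lr=1e-3,
    momentum=0.99,
    rank=dist.get_rank(),  # consumed by SGDWithPausingCompute.__init__
)

loss = model(torch.ones(8, device="cuda")).sum()
loss.backward()
optimizer.step()  # the long side-stream sleep exposes any missing sync before the broadcast
```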
@@ -21,11 +21,11 @@ from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.utils.testing import (
GPT2,
SGDWithPausingCompute,
available_devices,
check_same_models_across_ranks,
skip_if_less_than_four_gpu,
skip_if_no_cuda,
skip_if_py38,
skip_if_single_gpu,
)
@@ -46,7 +46,15 @@ class _DoubleInput(torch.nn.Module):
def run_one_step(
rank, world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size,
rank,
world_size,
backend,
device,
temp_file_name,
broadcast_buffers,
grad_accumulation,
reduce_buffer_size,
optimizer_type,
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
@@ -62,7 +70,11 @@ def run_one_step(
next(model.parameters()).requires_grad = False # Test non-trainable parameters
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
optimizer_settings = {"lr": 1e-3, "momentum": 0.99}
if optimizer_type == SGDWithPausingCompute:
optimizer_settings["rank"] = rank
optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
ddp_model = ShardedDataParallel(
model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
)
@@ -85,6 +97,11 @@ def run_one_step(
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
# Force a sync of all the streams
if device.type == torch.device("cuda").type:
torch.cuda.synchronize(device=device)
# when running on cpu/gloo the "nodes" are not really different
same_params = device == torch.device("cpu") or not grad_accumulation
check_same_models_across_ranks(
@@ -94,7 +111,7 @@ def run_one_step(
dist.destroy_process_group()
def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
run_one_step,
@@ -109,21 +126,33 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_gpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
@pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
@pytest.mark.parametrize(
"setup",
[
[dist.Backend.NCCL, torch.device("cuda")],
[dist.Backend.GLOO, torch.device("cpu")],
[dist.Backend.GLOO, torch.device("cuda")],
],
)
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, setup):
world_size = 2
run_test(
dist.Backend.NCCL, torch.device("cuda"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
)
temp_file_name = tempfile.mkstemp()[1]
@skip_if_py38
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_cpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
world_size = 2
run_test(
dist.Backend.GLOO, torch.device("cpu"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
mp.spawn(
run_one_step,
args=(
world_size,
setup[0],
setup[1],
temp_file_name,
broadcast_buffers,
grad_accumulation,
reduce_buffer_size,
optimizer_type,
),
nprocs=world_size,
join=True,
)
@@ -443,6 +443,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
# Load the optimizer state dict
optimizer.load_state_dict(optimizer_state_dict)
# Check that the states are not None, but {}
for state in optimizer.state.values():
for _, _ in state.items():
pass
dist.destroy_process_group()
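The loop above only needs to iterate: if any state entry were `None` instead of a (possibly empty) dict, the `state.items()` call would raise, which is the Lightning-compatibility problem tracked in #510. A slightly more explicit form of the same check, written as an assumed invariant rather than a claim about the test's exact intent:

```python
# Assumed invariant behind the check above: every state entry is a dict,
# possibly empty, never None.
for state in optimizer.state.values():
    assert isinstance(state, dict)
```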
@@ -792,7 +797,7 @@ def test_state_dict_distributed():
)
def run_ddp_parity(rank, world_size, backend, temp_file_name):
def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
@@ -910,16 +915,18 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
check_step()
for opt in [torch.optim.Adam, torch.optim.SGD]:
check_optimizer_equivalence(opt, change_train_graph=False)
check_optimizer_equivalence(opt, change_train_graph=True)
check_optimizer_equivalence(opt, change_train_graph=change_train_graph)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_parity():
@pytest.mark.parametrize("change_train_graph", [True, False])
@pytest.mark.parametrize("backend", [dist.Backend.NCCL, dist.Backend.GLOO])
def test_ddp_parity(change_train_graph: bool, backend: dist.Backend):
temp_file_name = tempfile.mkstemp()[1]
world_size = torch.cuda.device_count()
backend = dist.Backend.NCCL
mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
mp.spawn(
run_ddp_parity, args=(world_size, backend, temp_file_name, change_train_graph), nprocs=world_size, join=True
)