Unverified Commit c9fdf506 authored by Benjamin Lefaudeux, committed by GitHub

[fix][OSS] Adding a hard sync stream barrier before broadcast (#512)

* Add a hard sync barrier before the broadcast; this is mostly useful for Gloo, since NCCL is already synchronized behind the scenes
* Add a proper unit test
* Add a unit test for https://github.com/facebookresearch/fairscale/pull/510
parent 1fa778d7
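For orientation before the diff: a minimal, self-contained sketch of the pattern this commit introduces. This is illustrative only, not fairscale's actual API; the helper name, the flat `per_device_params` mapping and the `src=0` choice are assumptions. The idea is to hard-sync every CUDA device before issuing the parameter broadcasts, then wait on the work handles in a backend-aware way.

```python
import torch
import torch.distributed as dist


def broadcast_params_with_barrier(per_device_params, group, backend):
    """Illustrative sketch: hard-sync devices, then broadcast parameters.

    `per_device_params` is assumed to map a torch.device to the tensors that
    rank 0 should broadcast; the names and the src=0 choice are assumptions
    made for this example only.
    """
    # Hard sync barrier: make sure compute launched earlier (possibly on side
    # streams) has finished before the broadcasts read the parameter buffers.
    # This matters mostly for Gloo; NCCL work is ordered on CUDA streams.
    if torch.cuda.is_available():
        for device in per_device_params.keys():
            torch.cuda.synchronize(device=device)

    # Issue all broadcasts asynchronously.
    work_handles = [
        dist.broadcast(tensor=t, src=0, group=group, async_op=True)
        for tensors in per_device_params.values()
        for t in tensors
    ]

    # NCCL inlines the broadcasts on a single stream, so waiting on the last
    # handle is enough; for other backends (e.g. Gloo), wait on all of them.
    if work_handles and backend == dist.Backend.NCCL:
        work_handles[-1].wait()
    else:
        for handle in work_handles:
            handle.wait()
```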
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 ### Fixed
+- OSS: fix a compatibility problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
 ## [0.3.1] - 2021-03-09
 ### Added
...
@@ -94,6 +94,7 @@ class OSS(Optimizer):
         self.group = group if group is not None else dist.group.WORLD
         self.world_size = dist.get_world_size(self.group)
+        self.backend = dist.get_backend(self.group)
         self.rank = dist.get_rank(self.group)
         self.global_rank = self.get_global_rank(self.group, self.rank)
         self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
@@ -546,6 +547,12 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
+        # if NCCL broadcasts will be done in an independent stream
+        # make sure that prior compute work is complete
+        if torch.device("cuda").type == self._default_device.type:
+            for device in self.per_device_params.keys():
+                torch.cuda.synchronize(device=device)
+
         work_handles = []  # Work handles are consumed within this scope, no callback

         for device in self.buckets.keys():
@@ -558,6 +565,9 @@ class OSS(Optimizer):
             )

-        _ = list(filter(lambda x: x.wait(), work_handles))
+        # Only check on the last handle, they're all inlined on the same CUDA stream
+        if work_handles and self.backend == dist.Backend.NCCL:
+            work_handles[-1].wait()
+        else:
+            _ = list(filter(lambda x: x.wait(), work_handles))

     def _setup_flat_buffers(self) -> None:
...
@@ -548,3 +548,27 @@ class DummyProcessGroup:
     def size(self) -> int:
         return self._size
+
+
+class SGDWithPausingCompute(torch.optim.SGD):
+    def __init__(self, *args, **kwargs) -> None:  # type: ignore
+        self.rank = kwargs["rank"]
+        del kwargs["rank"]
+        super().__init__(*args, **kwargs)
+
+    def step(self, closure: Optional[Any] = None) -> Any:
+        loss = super().step(closure=closure)
+
+        # This is used to make sure that OSS and ShardedDDP enforce a proper stream synchronization
+        # - Add a long cuda wait on a compute stream, non blocking from the CPU perspective
+        with torch.cuda.stream(torch.cuda.Stream()):
+            torch.cuda._sleep(100000000)
+
+        # - optionally change the params on a per rank basis
+        with torch.no_grad():
+            for param_group in self.param_groups:
+                for param in param_group["params"]:
+                    param *= 1.0 + self.rank / 10.0
+
+        return loss
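As a usage note (a hedged sketch mirroring the test changes further down, not additional library API): the long `torch.cuda._sleep` on a side stream means that, without the hard sync added in `_broadcast_params`, a rank could broadcast parameters before its own compute finished, and the per-rank scaling above would then make the replicas diverge. The helper is plugged into OSS roughly like this, with the extra `rank` argument forwarded through the optimizer settings; `build_optimizer`, `model` and `rank` are illustrative names, not test API.

```python
import torch
from fairscale.optim import OSS
from fairscale.utils.testing import SGDWithPausingCompute


def build_optimizer(model: torch.nn.Module, optimizer_type, rank: int) -> OSS:
    optimizer_settings = {"lr": 1e-3, "momentum": 0.99}
    if optimizer_type == SGDWithPausingCompute:
        # SGDWithPausingCompute pops "rank" out of its kwargs before calling SGD.
        optimizer_settings["rank"] = rank
    return OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
```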
@@ -21,11 +21,11 @@ from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
 from fairscale.utils.testing import (
     GPT2,
+    SGDWithPausingCompute,
     available_devices,
     check_same_models_across_ranks,
     skip_if_less_than_four_gpu,
     skip_if_no_cuda,
-    skip_if_py38,
     skip_if_single_gpu,
 )
@@ -46,7 +46,15 @@ class _DoubleInput(torch.nn.Module):
 def run_one_step(
-    rank, world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size,
+    rank,
+    world_size,
+    backend,
+    device,
+    temp_file_name,
+    broadcast_buffers,
+    grad_accumulation,
+    reduce_buffer_size,
+    optimizer_type,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     if device == torch.device("cuda"):
@@ -62,7 +70,11 @@ def run_one_step(
     next(model.parameters()).requires_grad = False  # Test non-trainable parameters

-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    optimizer_settings = {"lr": 1e-3, "momentum": 0.99}
+    if optimizer_type == SGDWithPausingCompute:
+        optimizer_settings["rank"] = rank
+
+    optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
     ddp_model = ShardedDataParallel(
         model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
     )
@@ -85,6 +97,11 @@ def run_one_step(
     # The models should stay the same in between the ranks
     for i in range(5):
         _ = optimizer.step(closure=closure)

+        # For a sync of all the streams
+        if device.type == torch.device("cuda").type:
+            torch.cuda.synchronize(device=device)
+
     # when running on cpu/gloo the "nodes" are not really different
     same_params = device == torch.device("cpu") or not grad_accumulation
     check_same_models_across_ranks(
@@ -94,7 +111,7 @@ def run_one_step(
     dist.destroy_process_group()


-def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
+def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
     temp_file_name = tempfile.mkstemp()[1]

     mp.spawn(
         run_one_step,
@@ -109,21 +126,33 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
 @pytest.mark.parametrize("broadcast_buffers", [True, False])
 @pytest.mark.parametrize("grad_accumulation", [True, False])
 @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
-def test_step_gpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
-    world_size = 2
-    run_test(
-        dist.Backend.NCCL, torch.device("cuda"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
-    )
-
-
-@skip_if_py38
-@pytest.mark.parametrize("broadcast_buffers", [True, False])
-@pytest.mark.parametrize("grad_accumulation", [True, False])
-@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
-def test_step_cpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
-    world_size = 2
-    run_test(
-        dist.Backend.GLOO, torch.device("cpu"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
-    )
+@pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
+@pytest.mark.parametrize(
+    "setup",
+    [
+        [dist.Backend.NCCL, torch.device("cuda")],
+        [dist.Backend.GLOO, torch.device("cpu")],
+        [dist.Backend.GLOO, torch.device("cuda")],
+    ],
+)
+def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, setup):
+    world_size = 2
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(
+        run_one_step,
+        args=(
+            world_size,
+            setup[0],
+            setup[1],
+            temp_file_name,
+            broadcast_buffers,
+            grad_accumulation,
+            reduce_buffer_size,
+            optimizer_type,
+        ),
+        nprocs=world_size,
+        join=True,
+    )
...
@@ -443,6 +443,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     # Load the optimizer state dict
     optimizer.load_state_dict(optimizer_state_dict)

+    # Check that the states are not None, but {}
+    for state in optimizer.state.values():
+        for _, _ in state.items():
+            pass
+
     dist.destroy_process_group()
@@ -792,7 +797,7 @@ def test_state_dict_distributed():
     )


-def run_ddp_parity(rank, world_size, backend, temp_file_name):
+def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
@@ -910,16 +915,18 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
         check_step()

     for opt in [torch.optim.Adam, torch.optim.SGD]:
-        check_optimizer_equivalence(opt, change_train_graph=False)
-        check_optimizer_equivalence(opt, change_train_graph=True)
+        check_optimizer_equivalence(opt, change_train_graph=change_train_graph)

     dist.destroy_process_group()


 @skip_if_no_cuda
 @skip_if_single_gpu
-def test_ddp_parity():
+@pytest.mark.parametrize("change_train_graph", [True, False])
+@pytest.mark.parametrize("backend", [dist.Backend.NCCL, dist.Backend.GLOO])
+def test_ddp_parity(change_train_graph: bool, backend: dist.Backend):
     temp_file_name = tempfile.mkstemp()[1]
     world_size = torch.cuda.device_count()
-    backend = dist.Backend.NCCL
-    mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
+    mp.spawn(
+        run_ddp_parity, args=(world_size, backend, temp_file_name, change_train_graph), nprocs=world_size, join=True
+    )