Unverified Commit 205af8c2 authored by Benjamin Lefaudeux, committed by GitHub

[feat][ShardedDDP] Support multiple groups (#394)

* Adding multiple-group support to ShardedDDP, plus a unit test
* Adding gloo to the backends tested for multiple groups
parent ef7146d5
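
For context on the two dst_rank changes below: torch.distributed collectives take global ranks for their dst/src arguments even when a sub-group is passed via group=, whereas OSS assigns parameter shards by rank within the group. A minimal sketch of the mapping, assuming the private torch.distributed.distributed_c10d._get_global_rank helper of contemporary PyTorch releases (newer releases expose a public dist.get_global_rank); group_local_to_global_rank is a hypothetical name, not the fairscale API:

import torch.distributed as dist

def group_local_to_global_rank(group, group_rank: int) -> int:
    # Hypothetical standalone equivalent of what OSS.get_global_rank is used for
    # in the diff below: translate a rank within `group` into a global rank.
    if group is dist.group.WORLD:
        return group_rank
    return dist.distributed_c10d._get_global_rank(group, group_rank)

# With world_size=4 and a sub-group made of the even ranks [0, 2],
# group-local rank 1 corresponds to global rank 2:
#   sub_group = dist.new_group(ranks=[0, 2])
#   assert group_local_to_global_rank(sub_group, 1) == 2
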
@@ -470,7 +470,7 @@ class ShardedDataParallel(nn.Module):
             p_tmp = param.expand_as(param)
             assert p_tmp.grad_fn is not None
             grad_acc = p_tmp.grad_fn.next_functions[0][0]
-            dst_rank = self._trainable_param_to_rank[param]
+            dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
             grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
             self._grad_accs.append(grad_acc)  # keep this function in scope
@@ -545,7 +545,7 @@ class ShardedDataParallel(nn.Module):
         for param in self._trainable_params:
             device = param.device
-            dst_rank = self._trainable_param_to_rank[param]
+            dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
             if param.device not in self.buckets.keys():
                 self.buckets[param.device] = [
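
Both hunks above feed the translated destination rank into the gradient-reduction path. A hedged sketch of the pattern this enables, assuming the same private _get_global_rank helper as above; reduce_grad_to_shard_owner is a hypothetical illustration, not fairscale's actual reduce code:

import torch
import torch.distributed as dist

def reduce_grad_to_shard_owner(grad: torch.Tensor, owner_group_rank: int, group) -> None:
    # The shard owner is tracked by its rank *within* the sub-group, but
    # dist.reduce expects the owner's *global* rank as `dst`, even with `group=` set.
    dst = dist.distributed_c10d._get_global_rank(group, owner_group_rank)
    dist.reduce(grad, dst=dst, group=group, op=dist.ReduceOp.SUM)
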
@@ -54,6 +54,10 @@
 skip_if_single_gpu = pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
 )
+skip_if_less_four_gpu = pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason="4 GPUs or more required"
+)
 skip_if_py38 = pytest.mark.skipif(
     sys.version_info.major == 3 and sys.version_info.minor == 8, reason="Python3.8 is skipped"
 )
@@ -23,7 +23,14 @@
 from torch.nn.parallel import DistributedDataParallel as DDP

 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler
-from fairscale.utils.testing import GPT2, check_same_model_params, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu
+from fairscale.utils.testing import (
+    GPT2,
+    check_same_model_params,
+    skip_if_less_four_gpu,
+    skip_if_no_cuda,
+    skip_if_py38,
+    skip_if_single_gpu,
+)


 def run_one_step(rank, world_size, backend, device, temp_file_name):
@@ -591,3 +598,80 @@ def test_gpt2():
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
     mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+
+
+def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
+    # Only work with the even ranks, to check that the global_rank indexing is properly used
+    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)
+    sub_group_ranks = [0, 2]
+    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)
+
+    # Make sure that all the ranks get different training data,
+    # so that the sync check between their models is meaningful
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Standard deep learning setup
+    device = "cuda"
+    torch.cuda.set_device(rank)
+    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
+    loss_fn = torch.nn.L1Loss().to(device)
+
+    def check(optimizer, model):
+        # Just run a couple of epochs, check that the model is properly updated
+        for _ in range(epochs):
+            target = torch.rand((batch, target_width), device=device)
+            inputs = torch.rand((batch, input_width), device=device)
+
+            def closure():
+                optimizer.zero_grad()
+                output = model(inputs)
+                loss = loss_fn(output, target)
+                loss.backward()
+                return loss
+
+            _ = optimizer.step(closure=closure)
+
+            # Check that all the params are the same on all ranks
+            for pg in optimizer.param_groups:
+                for p in pg["params"]:
+                    receptacle = [p.clone() for _ in sub_group_ranks]
+                    dist.all_gather(receptacle, p, group=process_group)
+                    if rank == 0:
+                        for sync_p in receptacle[1:]:
+                            assert torch.all(
+                                torch.eq(receptacle[0], sync_p)
+                            ), "Models differ in between ranks {} - {}".format(
+                                torch.norm(receptacle[0]), torch.norm(sync_p)
+                            )
+
+    if rank in sub_group_ranks:
+        # Model not fitting in the broadcast bucket
+        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
+            device
+        )
+
+        # With SGD, momentum is required to get a state to shard
+        optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group)
+        model = ShardedDataParallel(model, optimizer, process_group=process_group)
+        check(optimizer, model)
+
+    dist.destroy_process_group(process_group)
+
+
+@skip_if_less_four_gpu
+def test_multiple_groups():
+    world_size = 4
+    temp_file_name = tempfile.mkstemp()[1]
+
+    for backend in ["gloo", "nccl"]:
+        print("Testing backend ", backend)
+        mp.spawn(
+            run_test_multiple_groups, args=(world_size, temp_file_name, backend), nprocs=world_size, join=True,
+        )
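
As a usage note, a minimal sketch of wiring ShardedDDP to a sub-group after this change, mirroring the test above; setup_sharded_ddp_on_subgroup and the hyper-parameters are hypothetical, only OSS(..., group=...) and ShardedDataParallel(..., process_group=...) come from the diff:

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

def setup_sharded_ddp_on_subgroup(model: torch.nn.Module, subgroup_ranks):
    # new_group() is collective: every rank in the job must call it,
    # including ranks that are not part of the sub-group.
    subgroup = dist.new_group(ranks=subgroup_ranks)

    if dist.get_rank() in subgroup_ranks:
        # Hand the same group to both the sharded optimizer and ShardedDDP so
        # that shard ownership and gradient reduction agree on rank numbering.
        optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99, group=subgroup)
        sharded_model = ShardedDataParallel(model, optimizer, process_group=subgroup)
        return sharded_model, optimizer

    return None, None
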