Unverified Commit 1c8d219d authored by Sean Naren, committed by GitHub

[feat] Add Torch Sync Batchnorm handle in sharded DDP (#265)

* Add function to add handle for sync BN
* Add test to ensure batch norm handles have been added
parent fc1a40e1
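In effect, this change lets a model whose BatchNorm layers have been converted with torch.nn.SyncBatchNorm.convert_sync_batchnorm be wrapped directly in ShardedDataParallel. A minimal usage sketch, assuming a process group has already been initialized on each GPU rank (layer sizes and optimizer settings are illustrative, not taken from this commit):

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Assumes dist.init_process_group(...) has already run on this rank.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 3),
    torch.nn.BatchNorm1d(3),
    torch.nn.Linear(3, 3),
).cuda()

# Swap every BatchNorm layer for its cross-rank synchronized variant.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
ddp_model = ShardedDataParallel(model, optimizer)  # SyncBatchNorm handle is passed during init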
@@ -105,6 +105,9 @@ class ShardedDataParallel(nn.Module):
         self._grad_accs: List[Callable] = []
         self._setup_backward_hooks()
 
+        # passing a handle to torch.nn.SyncBatchNorm layer
+        self._passing_sync_batchnorm_handle(self.module)
+
         # Make sure that all ranks start with the same model
         if sync_models_at_startup:
             self._sync_params_and_buffers()
@@ -250,3 +253,15 @@ class ShardedDataParallel(nn.Module):
         ]
         _ = list(map(lambda x: x.wait(), work_handles))
+
+    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
+        """
+        Passes the handle required by ``torch.nn.SyncBatchNorm``.
+        Adapted from ``torch.nn.parallel.DistributedDataParallel``.
+        """
+        for layer in module.modules():
+            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
+                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
+                # device_id logic has not been handled; assume single-process single-device,
+                # since SyncBatchNorm only supports DDP with single-process single-device anyway
+                layer._specify_ddp_gpu_num(1)
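For context, _specify_ddp_gpu_num is a private PyTorch hook (removed in later releases) that records the per-process GPU count on the layer; SyncBatchNorm's forward refuses to run in training mode under DDP unless this handle has been set. A minimal sketch of its effect, assuming one of the PyTorch releases this commit targets (the ddp_gpu_size attribute name comes from those releases and should be treated as illustrative, since the API is private):

import torch

bn = torch.nn.SyncBatchNorm(3)
# Records that each DDP process drives exactly one GPU; in the targeted
# PyTorch releases this simply sets bn.ddp_gpu_size = 1.
bn._specify_ddp_gpu_num(1)
assert bn.ddp_gpu_size == 1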
@@ -316,7 +316,8 @@ def test_ddp_attributes():
     # - device_type
     url = "file://" + tempfile.mkstemp()[1]
-    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
+    backend = dist.Backend.NCCL
+    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
     model = Sequential(Linear(2, 3), Linear(3, 3))
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
@@ -327,6 +328,37 @@ def test_ddp_attributes():
     dist.destroy_process_group()
 
 
+def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+
+    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
+    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    assert isinstance(model[1], torch.nn.SyncBatchNorm)
+    # Ensures sync batch norm handles have been added
+    ddp_model(torch.randn(2, 2).to(device))
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_ddp_sync_batch_norm():
+    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
+    world_size = 2
+    backend = "gloo"
+    temp_file_name = tempfile.mkstemp()[1]
+    device = "cuda"
+
+    mp.spawn(
+        run_test_ddp_sync_batch_norm,
+        args=(world_size, backend, device, temp_file_name),
+        nprocs=world_size,
+        join=True,
+    )
+
+
 def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
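For reference, the spawn-based harness used by these tests follows the standard torch.multiprocessing rendezvous pattern. A minimal standalone sketch of that pattern, with the worker body, backend, and world size as placeholders:

import tempfile

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size, init_file):
    # Every spawned process joins the same group through a shared file.
    dist.init_process_group(
        init_method="file://" + init_file, backend="gloo", rank=rank, world_size=world_size
    )
    # ... per-rank assertions go here ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    init_file = tempfile.mkstemp()[1]
    mp.spawn(worker, args=(world_size, init_file), nprocs=world_size, join=True)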