Unverified commit 84a3bdbe authored by Benjamin Lefaudeux, committed by GitHub

[fix] Typo in ShardedDDP unit test (#282)

* fix typo, backend for CPU test
parent 1c8d219d
@@ -129,7 +129,7 @@ class ShardedDataParallel(nn.Module):
         return self.module(*inputs, **kwargs)
     def reduce(self) -> None:
-        """ .. deprecated:: 0.0.4
+        """.. deprecated:: 0.0.4
         This does not need to be called, the gradient reduction is done automatically during the BW pass
         """
@@ -157,8 +157,7 @@ class ShardedDataParallel(nn.Module):
         self.should_accumulate_grads = old_should_accumulate_grads
     def _clear_counters(self) -> None:
-        """ Reset all the grad reduce and call counters
-        """
+        """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
@@ -254,14 +253,14 @@ class ShardedDataParallel(nn.Module):
         _ = list(map(lambda x: x.wait(), work_handles))
-    def _passing_sync_batchnorm_handle(self, module):
+    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
         Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
         Adapted from ``torch.nn.distributed.DistributedDataParallel``.
         """
         for layer in module.modules():
             if isinstance(layer, torch.nn.modules.SyncBatchNorm):
-                assert self.device_type != 'cpu', "SyncBatchNorm layers only work with GPU modules"
+                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                 # device_id logic has not been handled, assume single-process single-device
                 # SyncBatchNorm only supports DDP with single-process single-device anyway'
-                layer._specify_ddp_gpu_num(1)
+                layer._specify_ddp_gpu_num(1)  # type: ignore
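The assert above guards against CPU modules: SyncBatchNorm layers only run on GPU under single-process single-device DDP. A hedged illustration of how a model would typically acquire such layers before being wrapped, using the standard PyTorch conversion helper:

import torch

# Convert regular BatchNorm layers to SyncBatchNorm; the converted module must live on a GPU.
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.BatchNorm1d(3)).cuda()
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)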
@@ -316,8 +316,7 @@ def test_ddp_attributes():
     # - device_type
     url = "file://" + tempfile.mkstemp()[1]
-    backend = dist.Backend.NCCL
-    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
     model = Sequential(Linear(2, 3), Linear(3, 3))
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
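The backend switch above is the actual fix: NCCL needs CUDA devices, while this test builds a CPU-only model, so Gloo is the backend that can serve it. A minimal single-process sketch mirroring the fixed test setup (the clean-up call is an addition for completeness):

import tempfile
import torch.distributed as dist

# Gloo runs on CPU; NCCL would fail without GPUs.
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
# ... CPU-side assertions go here ...
dist.destroy_process_group()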
@@ -352,10 +351,7 @@ def test_ddp_sync_batch_norm():
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
     mp.spawn(
-        run_test_ddp_sync_batch_norm,
-        args=(world_size, backend, device, temp_file_name),
-        nprocs=world_size,
-        join=True
+        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
     )