Unverified Commit 84a3bdbe authored by Benjamin Lefaudeux, committed by GitHub

[fix] Typo in ShardedDDP unit test (#282)

* fix typo, backend for CPU test
parent 1c8d219d
@@ -129,9 +129,9 @@ class ShardedDataParallel(nn.Module):
         return self.module(*inputs, **kwargs)
 
     def reduce(self) -> None:
-        """ .. deprecated:: 0.0.4
+        """.. deprecated:: 0.0.4
             This does not need to be called, the gradient reduction is done automatically during the BW pass
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
@@ -157,8 +157,7 @@ class ShardedDataParallel(nn.Module):
         self.should_accumulate_grads = old_should_accumulate_grads
 
     def _clear_counters(self) -> None:
-        """ Reset all the grad reduce and call counters
-        """
+        """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
@@ -254,14 +253,14 @@ class ShardedDataParallel(nn.Module):
         _ = list(map(lambda x: x.wait(), work_handles))
 
-    def _passing_sync_batchnorm_handle(self, module):
+    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
         Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
         Adapted from ``torch.nn.distributed.DistributedDataParallel``.
         """
         for layer in module.modules():
             if isinstance(layer, torch.nn.modules.SyncBatchNorm):
-                assert self.device_type != 'cpu', "SyncBatchNorm layers only work with GPU modules"
+                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                 # device_id logic has not been handled, assume single-process single-device
                 # SyncBatchNorm only supports DDP with single-process single-device anyway'
-                layer._specify_ddp_gpu_num(1)
+                layer._specify_ddp_gpu_num(1)  # type: ignore
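The assert above encodes the constraint that SyncBatchNorm layers only work for GPU modules. As a hedged illustration (the actual test body is not part of this diff), a GPU test would typically build such a model with PyTorch's stock convert_sync_batchnorm helper before wrapping it:

import torch

# swap BatchNorm layers for SyncBatchNorm, then move the module to a GPU
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.BatchNorm1d(3))
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
# wrapping this module (e.g. with ShardedDataParallel) triggers _passing_sync_batchnorm_handle(),
# which walks module.modules() and would assert if the wrapper were built for CPU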
@@ -316,8 +316,7 @@ def test_ddp_attributes():
     # - device_type
     url = "file://" + tempfile.mkstemp()[1]
-    backend = dist.Backend.NCCL
-    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
     model = Sequential(Linear(2, 3), Linear(3, 3))
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
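This is the backend fix named in the commit message: dist.Backend.NCCL only runs on CUDA devices, while gloo also works on CPU, so a CPU-only test has to rendezvous over gloo. A minimal, self-contained sketch of the single-process, file-based setup used above:

import tempfile

import torch.distributed as dist

url = "file://" + tempfile.mkstemp()[1]  # file-based rendezvous, no free TCP port required
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
assert dist.get_world_size() == 1
dist.destroy_process_group()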
@@ -352,10 +351,7 @@ def test_ddp_sync_batch_norm():
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
     mp.spawn(
-        run_test_ddp_sync_batch_norm,
-        args=(world_size, backend, device, temp_file_name),
-        nprocs=world_size,
-        join=True
+        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
     )
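For context, torch.multiprocessing.spawn launches nprocs worker processes and calls the target as fn(rank, *args), with join=True blocking until every worker exits. A hedged sketch of that pattern; the worker body here is illustrative and not the real run_test_ddp_sync_batch_norm:

import tempfile

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, backend: str, tmp_file: str) -> None:
    # every spawned process joins the same file-based rendezvous under its own rank
    dist.init_process_group(init_method="file://" + tmp_file, backend=backend, rank=rank, world_size=world_size)
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size, "gloo", tempfile.mkstemp()[1]), nprocs=world_size, join=True)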