Unverified commit 5cddaea4 authored by Min Xu, committed by GitHub

[fix] check before calling _specify_ddp_gpu_num (#626)



- this function is being removed in PyTorch
- we only need to call it when working with an older PyTorch (see the sketch below)
Co-authored-by: Min Xu <min.xu@acm.org>
parent d3b86d65
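
In effect, the fix wraps the legacy call in a hasattr() guard so it becomes a no-op on newer PyTorch. A minimal standalone sketch of the pattern, assuming a hypothetical helper name guard_legacy_sync_bn (the real changes are in the two hunks below):

import torch

def guard_legacy_sync_bn(module: torch.nn.Module) -> None:
    # _specify_ddp_gpu_num is a private DDP hook that PyTorch removed in 1.9,
    # so only call it when the attribute still exists (i.e. on older PyTorch).
    for layer in module.modules():
        if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
            layer._specify_ddp_gpu_num(1)  # type: ignore  # "1" = GPUs per DDP worker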
@@ -539,10 +539,11 @@ class ShardedDataParallel(nn.Module):
         Adapted from ``torch.nn.distributed.DistributedDataParallel``.
         """
         for layer in module.modules():
-            if isinstance(layer, torch.nn.modules.SyncBatchNorm):
+            if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
                 assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                 # device_id logic has not been handled, assume single-process single-device
                 # SyncBatchNorm only supports DDP with single-process single-device anyway'
+                # This function is removed from pytorch since 1.9.
                 layer._specify_ddp_gpu_num(1)  # type: ignore

     def _setup_bucket_strategy(self) -> None:
@@ -50,8 +50,9 @@ def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
     is happily running even without DDP. E.g. this is used by FSDP.
     """
     for layer in module.modules():
-        if isinstance(layer, torch.nn.modules.SyncBatchNorm):
+        if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
             # Number "1" below meant to be the number of GPUs for each DDP worker.
             # (i.e. "device_ids" in DDP. As far as I see, the value is not actually
             # used, but this call needs to be made to avoid an exception.
+            # This function is removed from pytorch since 1.9.
             layer._specify_ddp_gpu_num(1)  # type: ignore
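
As a quick sanity check of the guard across PyTorch versions, one can probe a SyncBatchNorm layer directly (a hedged example; num_features=8 is arbitrary):

import torch

bn = torch.nn.SyncBatchNorm(num_features=8)
if hasattr(bn, "_specify_ddp_gpu_num"):
    # PyTorch < 1.9: the legacy hook still exists and can be called.
    bn._specify_ddp_gpu_num(1)  # type: ignore
else:
    # PyTorch >= 1.9: the hook is gone, so the guarded code above skips it.
    print(f"torch {torch.__version__}: _specify_ddp_gpu_num no longer exists")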