Commit 32bf3499 (unverified), authored by Rui Xu, committed by GitHub
Browse files

[Bug] Fix ddp bug when pytorch>=1.7 (#769)



* fix ddp bug when pytorch>=1.7

* fix ddp bug when pytorch>=1.7

* support pat (parrots)

* fix docs

* use print_log instead of get_logger
Co-authored-by: nbei <631557085@qq.com>
parent 080474b9
@@ -3,6 +3,7 @@ import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv import print_log
from mmcv.utils import TORCH_VERSION
from .scatter_gather import scatter_kwargs
@@ -28,6 +29,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (TORCH_VERSION >= '1.7' and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
@@ -60,6 +70,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (TORCH_VERSION >= '1.7' and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment