Unverified Commit 082dabfd authored by Jiazhen Wang, committed by GitHub

[Fix] Fix _sync_params, which was removed in torch 1.11.0 (#1816)

* fix pt111 dist

* fix val step
parent 1a2f174f
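
PyTorch 1.11 removed DistributedDataParallel._sync_params, so train_step and val_step now branch on the installed torch version: on 1.11+ only the buffer-sync helpers (_check_sync_bufs_pre_fwd / _sync_buffers) are called, while older torch (and parrots) keep the original _sync_params path. Below is a minimal sketch of that pre-forward gate; the free-standing helper name _sync_before_forward is illustrative and not part of the patch itself.

```python
# Minimal sketch of the version gate this commit introduces; the helper name
# `_sync_before_forward` is illustrative and not part of the actual patch.
from mmcv.utils import TORCH_VERSION, digit_version


def _sync_before_forward(ddp_module):
    """Sync parameters/buffers across ranks before the forward pass.

    On torch >= 1.11.0, ``DistributedDataParallel._sync_params`` no longer
    exists, so only buffers are synchronized via the helpers that replaced
    it; older torch (and parrots) keep the original call.
    """
    if ('parrots' not in TORCH_VERSION
            and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
        if ddp_module._check_sync_bufs_pre_fwd():
            ddp_module._sync_buffers()
    else:
        if getattr(ddp_module, 'require_forward_param_sync', False):
            ddp_module._sync_params()
```
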
@@ -44,8 +44,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
 
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
@@ -57,8 +64,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.train_step(*inputs, **kwargs)
 
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
@@ -86,8 +99,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
 
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
@@ -99,8 +119,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.val_step(*inputs, **kwargs)
 
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
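
The post-forward hunks apply the same version guard before preparing the reducer for backward. The sketch below mirrors, rather than reproduces, the patched train_step/val_step bodies; the helper name _sync_after_forward is illustrative only, and _find_tensors is the private torch helper already imported by mmcv's distributed.py.

```python
# Illustrative sketch of the patched post-forward logic; the helper name
# `_sync_after_forward` is not part of the actual commit.
import torch
from torch.nn.parallel.distributed import _find_tensors

from mmcv.utils import TORCH_VERSION, digit_version


def _sync_after_forward(ddp_module, output):
    # On torch >= 1.11 the removed `_sync_params` is replaced by explicit
    # buffer synchronization; parrots and older torch skip this block.
    if ('parrots' not in TORCH_VERSION
            and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
        if ddp_module._check_sync_bufs_post_fwd():
            ddp_module._sync_buffers()

    # Prepare the reducer for backward only when gradient sync is required.
    if (torch.is_grad_enabled()
            and getattr(ddp_module, 'require_backward_grad_sync', False)):
        if ddp_module.find_unused_parameters:
            ddp_module.reducer.prepare_for_backward(
                list(_find_tensors(output)))
        else:
            ddp_module.reducer.prepare_for_backward([])
```
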