MMCV commit 082dabfd (unverified)
Authored Mar 26, 2022 by Jiazhen Wang; committed by GitHub on Mar 26, 2022
Parent: 1a2f174f

[Fix] Fix _sync_params was removed in torch1.11.0 (#1816)

* fix pt111 dist
* fix val step
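The core of the fix is a runtime version gate: torch 1.11.0 removed DistributedDataParallel's private _sync_params hook, so the patched code syncs buffers through _check_sync_bufs_pre_fwd()/_sync_buffers() on new torch and falls back to _sync_params() on older releases. A minimal sketch of that gate, using mmcv's TORCH_VERSION and digit_version utilities (the helper name here is illustrative, not from the commit):

    from mmcv.utils import TORCH_VERSION, digit_version

    def ddp_has_sync_params() -> bool:
        # TORCH_VERSION is torch.__version__, or the string 'parrots' on the
        # Parrots framework; digit_version parses a version string into a
        # tuple of ints so versions compare correctly ('1.9.0' < '1.11.0').
        # The 'parrots' check must run first: digit_version cannot parse it.
        return ('parrots' in TORCH_VERSION
                or digit_version(TORCH_VERSION) < digit_version('1.11.0'))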
Showing 1 changed file with 34 additions and 8 deletions.

mmcv/parallel/distributed.py (+34, -8)
@@ -44,8 +44,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
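On torch >= 1.11 the pre-forward pair _check_sync_bufs_pre_fwd()/_sync_buffers() replaces the removed _sync_params(). An alternative guard (a sketch of a different design, not what this commit does) would probe for the attribute instead of comparing versions:

    # Hypothetical attribute-probing variant; `ddp` is assumed to be an
    # MMDistributedDataParallel (or DistributedDataParallel) instance.
    if hasattr(ddp, '_sync_params'):
        ddp._sync_params()                    # torch < 1.11 private hook
    elif ddp._check_sync_bufs_pre_fwd():      # torch >= 1.11 replacement
        ddp._sync_buffers()

The commit prefers the explicit version check, which also keeps the Parrots code path on the old branch.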
@@ -57,8 +64,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.train_step(*inputs, **kwargs)
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
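prepare_for_backward needs a flat list of output tensors so the reducer can detect unused parameters; torch's private _find_tensors does that extraction. A rough, self-contained equivalent (illustrative only; mmcv imports the real helper from torch.nn.parallel.distributed):

    import torch

    def find_tensors(obj):
        # Recursively collect every tensor nested in lists/tuples/dicts,
        # mirroring what torch's private _find_tensors does for DDP outputs.
        if isinstance(obj, torch.Tensor):
            return [obj]
        if isinstance(obj, (list, tuple)):
            return [t for item in obj for t in find_tensors(item)]
        if isinstance(obj, dict):
            return [t for value in obj.values() for t in find_tensors(value)]
        return []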
@@ -86,8 +99,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
@@ -99,8 +119,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.val_step(*inputs, **kwargs)
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
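The patched logic runs inside train_step/val_step, which mmcv runners call instead of forward. A minimal usage sketch, assuming an initialized process group and a module my_model that defines train_step (my_model, data_batch, and optimizer below are placeholders):

    import torch
    from mmcv.parallel import MMDistributedDataParallel

    model = MMDistributedDataParallel(
        my_model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False)

    # The runner drives training through train_step; on torch >= 1.11 this
    # now syncs buffers via _sync_buffers instead of the removed _sync_params.
    outputs = model.train_step(data_batch, optimizer)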