Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
c90f2be0
Unverified
Commit
c90f2be0
authored
May 25, 2022
by
whcao
Committed by
GitHub
May 25, 2022
Browse files
[Fix] Fix is_module_wrapper (#1900)
* fix is_module_wrapper * test is_module_wrapper * fix code style
parent
e9f48a4f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
3 deletions
+44
-3
mmcv/parallel/utils.py
mmcv/parallel/utils.py
+13
-3
tests/test_parallel.py
tests/test_parallel.py
+31
-0
No files found.
mmcv/parallel/utils.py
View file @
c90f2be0
...
@@ -8,7 +8,8 @@ def is_module_wrapper(module):
...
@@ -8,7 +8,8 @@ def is_module_wrapper(module):
The following 3 modules in MMCV (and their subclasses) are regarded as
The following 3 modules in MMCV (and their subclasses) are regarded as
module wrappers: DataParallel, DistributedDataParallel,
module wrappers: DataParallel, DistributedDataParallel,
MMDistributedDataParallel (the deprecated version). You may add you own
MMDistributedDataParallel (the deprecated version). You may add you own
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
its children registries.
Args:
Args:
module (nn.Module): The module to be checked.
module (nn.Module): The module to be checked.
...
@@ -16,5 +17,14 @@ def is_module_wrapper(module):
...
@@ -16,5 +17,14 @@ def is_module_wrapper(module):
Returns:
Returns:
bool: True if the input module is a module wrapper.
bool: True if the input module is a module wrapper.
"""
"""
module_wrappers
=
tuple
(
MODULE_WRAPPERS
.
module_dict
.
values
())
return
isinstance
(
module
,
module_wrappers
)
def
is_module_in_wrapper
(
module
,
module_wrapper
):
module_wrappers
=
tuple
(
module_wrapper
.
module_dict
.
values
())
if
isinstance
(
module
,
module_wrappers
):
return
True
for
child
in
module_wrapper
.
children
.
values
():
if
is_module_in_wrapper
(
module
,
child
):
return
True
return
False
return
is_module_in_wrapper
(
module
,
MODULE_WRAPPERS
)
tests/test_parallel.py
View file @
c90f2be0
...
@@ -11,6 +11,7 @@ from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
...
@@ -11,6 +11,7 @@ from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
from
mmcv.parallel._functions
import
Scatter
,
get_input_device
,
scatter
from
mmcv.parallel._functions
import
Scatter
,
get_input_device
,
scatter
from
mmcv.parallel.distributed_deprecated
import
\
from
mmcv.parallel.distributed_deprecated
import
\
MMDistributedDataParallel
as
DeprecatedMMDDP
MMDistributedDataParallel
as
DeprecatedMMDDP
from
mmcv.utils
import
Registry
def
mock
(
*
args
,
**
kwargs
):
def
mock
(
*
args
,
**
kwargs
):
...
@@ -74,6 +75,36 @@ def test_is_module_wrapper():
...
@@ -74,6 +75,36 @@ def test_is_module_wrapper():
module_wraper
=
ModuleWrapper
(
model
)
module_wraper
=
ModuleWrapper
(
model
)
assert
is_module_wrapper
(
module_wraper
)
assert
is_module_wrapper
(
module_wraper
)
# test module wrapper registry in downstream repo
MMRAZOR_MODULE_WRAPPERS
=
Registry
(
'mmrazor module wrapper'
,
parent
=
MODULE_WRAPPERS
,
scope
=
'mmrazor'
)
MMPOSE_MODULE_WRAPPERS
=
Registry
(
'mmpose module wrapper'
,
parent
=
MODULE_WRAPPERS
,
scope
=
'mmpose'
)
@
MMRAZOR_MODULE_WRAPPERS
.
register_module
()
class
ModuleWrapperInRazor
(
object
):
def
__init__
(
self
,
module
):
self
.
module
=
module
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
module
(
*
args
,
**
kwargs
)
@
MMPOSE_MODULE_WRAPPERS
.
register_module
()
class
ModuleWrapperInPose
(
object
):
def
__init__
(
self
,
module
):
self
.
module
=
module
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
module
(
*
args
,
**
kwargs
)
wrapped_module
=
ModuleWrapperInRazor
(
model
)
assert
is_module_wrapper
(
wrapped_module
)
wrapped_module
=
ModuleWrapperInPose
(
model
)
assert
is_module_wrapper
(
wrapped_module
)
def
test_get_input_device
():
def
test_get_input_device
():
# if the device is CPU, return -1
# if the device is CPU, return -1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment