Unverified Commit c90f2be0 authored by whcao's avatar whcao Committed by GitHub
Browse files

[Fix] Fix is_module_wrapper (#1900)

* fix is_module_wrapper

* test is_module_wrapper

* fix code style
parent e9f48a4f
...@@ -8,7 +8,8 @@ def is_module_wrapper(module): ...@@ -8,7 +8,8 @@ def is_module_wrapper(module):
The following 3 modules in MMCV (and their subclasses) are regarded as The following 3 modules in MMCV (and their subclasses) are regarded as
module wrappers: DataParallel, DistributedDataParallel, module wrappers: DataParallel, DistributedDataParallel,
MMDistributedDataParallel (the deprecated version). You may add you own MMDistributedDataParallel (the deprecated version). You may add you own
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS. module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
its children registries.
Args: Args:
module (nn.Module): The module to be checked. module (nn.Module): The module to be checked.
...@@ -16,5 +17,14 @@ def is_module_wrapper(module): ...@@ -16,5 +17,14 @@ def is_module_wrapper(module):
Returns: Returns:
bool: True if the input module is a module wrapper. bool: True if the input module is a module wrapper.
""" """
module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
return isinstance(module, module_wrappers) def is_module_in_wrapper(module, module_wrapper):
module_wrappers = tuple(module_wrapper.module_dict.values())
if isinstance(module, module_wrappers):
return True
for child in module_wrapper.children.values():
if is_module_in_wrapper(module, child):
return True
return False
return is_module_in_wrapper(module, MODULE_WRAPPERS)
...@@ -11,6 +11,7 @@ from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel, ...@@ -11,6 +11,7 @@ from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
from mmcv.parallel._functions import Scatter, get_input_device, scatter from mmcv.parallel._functions import Scatter, get_input_device, scatter
from mmcv.parallel.distributed_deprecated import \ from mmcv.parallel.distributed_deprecated import \
MMDistributedDataParallel as DeprecatedMMDDP MMDistributedDataParallel as DeprecatedMMDDP
from mmcv.utils import Registry
def mock(*args, **kwargs): def mock(*args, **kwargs):
...@@ -74,6 +75,36 @@ def test_is_module_wrapper(): ...@@ -74,6 +75,36 @@ def test_is_module_wrapper():
module_wraper = ModuleWrapper(model) module_wraper = ModuleWrapper(model)
assert is_module_wrapper(module_wraper) assert is_module_wrapper(module_wraper)
# test module wrapper registry in downstream repo
MMRAZOR_MODULE_WRAPPERS = Registry(
'mmrazor module wrapper', parent=MODULE_WRAPPERS, scope='mmrazor')
MMPOSE_MODULE_WRAPPERS = Registry(
'mmpose module wrapper', parent=MODULE_WRAPPERS, scope='mmpose')
@MMRAZOR_MODULE_WRAPPERS.register_module()
class ModuleWrapperInRazor(object):
def __init__(self, module):
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
@MMPOSE_MODULE_WRAPPERS.register_module()
class ModuleWrapperInPose(object):
def __init__(self, module):
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
wrapped_module = ModuleWrapperInRazor(model)
assert is_module_wrapper(wrapped_module)
wrapped_module = ModuleWrapperInPose(model)
assert is_module_wrapper(wrapped_module)
def test_get_input_device(): def test_get_input_device():
# if the device is CPU, return -1 # if the device is CPU, return -1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment