Unverified commit 79942100, authored by Harry, committed by GitHub

Add module wrapper registry (#352)

* feat: switch the parallel-module check to a module wrapper registry

* fix: unit tests

* fix: add registry file

* feat: add module wrapper unit tests
parent c6784f4a
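
The pattern this commit standardizes recurs in every file below: instead of checking models against a hard-coded tuple of parallel classes, wrapper detection is now driven by a registry, so the usual "model = model.module" unwrapping keeps working for wrappers mmcv does not ship itself. A minimal sketch of that workflow, assuming an mmcv build that already contains this commit; the MyWrapper class is hypothetical and used only for illustration:

# Minimal sketch, assuming an mmcv version that includes this commit.
# MyWrapper is a hypothetical wrapper class used only for illustration.
import torch.nn as nn

from mmcv.parallel import MODULE_WRAPPERS, is_module_wrapper


@MODULE_WRAPPERS.register_module()  # opt the custom wrapper into mmcv's handling
class MyWrapper(nn.Module):

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


wrapped = MyWrapper(nn.Linear(4, 2))

# The same unwrapping idiom used by BaseRunner and the checkpoint helpers:
model = wrapped
if is_module_wrapper(model):
    model = model.module
assert isinstance(model, nn.Linear)
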
@@ -3,10 +3,11 @@ from .collate import collate
 from .data_container import DataContainer
 from .data_parallel import MMDataParallel
 from .distributed import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
 from .scatter_gather import scatter, scatter_kwargs
-from .utils import is_parallel_module
+from .utils import is_module_wrapper

 __all__ = [
     'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
-    'scatter', 'scatter_kwargs', 'is_parallel_module'
+    'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
 ]
@@ -6,9 +6,11 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)

 from mmcv.utils import TORCH_VERSION
+from .registry import MODULE_WRAPPERS
 from .scatter_gather import scatter_kwargs


+@MODULE_WRAPPERS.register_module()
 class MMDistributedDataParallel(nn.Module):

     def __init__(self,
......
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from mmcv.utils import Registry
+
+MODULE_WRAPPERS = Registry('module wrapper')
+MODULE_WRAPPERS.register_module(DataParallel)
+MODULE_WRAPPERS.register_module(DistributedDataParallel)
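
For reference, the Registry used here is a plain name-to-class mapping: register_module can be called with a class directly, as in the new registry file above, or used as a decorator, as in the MMDistributedDataParallel change. A short sketch with a throwaway registry (the WRAPPERS registry and NullWrapper class are illustrative, not part of mmcv):

# Illustrative sketch: a throwaway registry, not the real MODULE_WRAPPERS.
from torch.nn.parallel import DataParallel

from mmcv.utils import Registry

WRAPPERS = Registry('example wrapper')

# Call-style registration, as used above for DataParallel and
# DistributedDataParallel.
WRAPPERS.register_module(DataParallel)


# Decorator-style registration, as used for MMDistributedDataParallel.
@WRAPPERS.register_module()
class NullWrapper:

    def __init__(self, module):
        self.module = module


# The registry is just a name -> class mapping; is_module_wrapper builds
# its isinstance() tuple from module_dict.values().
assert WRAPPERS.module_dict == {
    'DataParallel': DataParallel,
    'NullWrapper': NullWrapper,
}
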
 # Copyright (c) Open-MMLab. All rights reserved.
-from torch.nn.parallel import DataParallel, DistributedDataParallel
-
-from .distributed_deprecated import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS


-def is_parallel_module(module):
-    """Check if a module is a parallel module.
+def is_module_wrapper(module):
+    """Check if a module is a module wrapper.

-    The following 3 modules (and their subclasses) are regarded as parallel
-    modules: DataParallel, DistributedDataParallel,
-    MMDistributedDataParallel (the deprecated version).
+    The following 3 modules in MMCV (and their subclasses) are regarded as
+    module wrappers: DataParallel, DistributedDataParallel,
+    MMDistributedDataParallel (the deprecated version). You may add your own
+    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.

     Args:
         module (nn.Module): The module to be checked.

     Returns:
-        bool: True if the input module is a parallel module.
+        bool: True if the input module is a module wrapper.
     """
-    parallels = (DataParallel, DistributedDataParallel,
-                 MMDistributedDataParallel)
-    return isinstance(module, parallels)
+    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
+    return isinstance(module, module_wrappers)
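
Because the check is an isinstance() test over the registered classes, subclasses of a registered wrapper count as wrappers too, and since the tuple is rebuilt from MODULE_WRAPPERS.module_dict on every call, classes registered later are picked up immediately. A quick sketch, assuming an mmcv build that includes this commit (CPU-only DataParallel is enough for the check itself):

# Sketch assuming an mmcv version that includes this commit.
import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmcv.parallel import MMDataParallel, is_module_wrapper

model = nn.Linear(4, 2)
assert not is_module_wrapper(model)  # plain modules are not wrappers

dp = DataParallel(model)
assert is_module_wrapper(dp)         # DataParallel is pre-registered

mmdp = MMDataParallel(model)
assert is_module_wrapper(mmdp)       # subclasses of registered wrappers count too
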
@@ -8,7 +8,7 @@ import torch
 from torch.optim import Optimizer

 import mmcv
-from ..parallel import is_parallel_module
+from ..parallel import is_module_wrapper
 from .checkpoint import load_checkpoint
 from .dist_utils import get_dist_info
 from .hooks import HOOKS, Hook, IterTimerHook
@@ -60,7 +60,7 @@ class BaseRunner(metaclass=ABCMeta):
                           'train_step() and val_step() in the model instead.')
             # raise an error if `batch_processor` is not None and
             # `model.train_step()` exists.
-            if is_parallel_module(model):
+            if is_module_wrapper(model):
                 _model = model.module
             else:
                 _model = model
......
@@ -14,7 +14,7 @@ from torch.utils import model_zoo

 import mmcv
 from ..fileio import load as load_file
-from ..parallel import is_parallel_module
+from ..parallel import is_module_wrapper
 from ..utils import mkdir_or_exist
 from .dist_utils import get_dist_info

@@ -63,7 +63,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
     def load(module, prefix=''):
         # recursively check parallel module in case that the model has a
         # complicated structure, e.g., nn.Module(nn.Module(DDP))
-        if is_parallel_module(module):
+        if is_module_wrapper(module):
             module = module.module
         local_metadata = {} if metadata is None else metadata.get(
             prefix[:-1], {})
@@ -273,7 +273,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
     meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

     mmcv.mkdir_or_exist(osp.dirname(filename))
-    if is_parallel_module(model):
+    if is_module_wrapper(model):
         model = model.module

     checkpoint = {
......
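
The unwrapping matters for checkpoints because wrappers such as DataParallel keep the real network under a module attribute, so their state_dict keys carry a 'module.' prefix; saving the inner module keeps checkpoints loadable by the bare model. A small sketch, assuming an mmcv build with this commit; the checkpoint path is arbitrary:

# Sketch assuming an mmcv version that includes this commit; the checkpoint
# path is arbitrary.
import torch
import torch.nn as nn

from mmcv.runner import save_checkpoint

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

# Wrapper state_dicts prefix every key with 'module.' ...
assert all(k.startswith('module.') for k in wrapped.state_dict())

# ... but save_checkpoint unwraps registered wrappers first, so the stored
# state_dict matches the bare model and loads without key mangling.
save_checkpoint(wrapped, '/tmp/linear.pth')
ckpt = torch.load('/tmp/linear.pth')
assert set(ckpt['state_dict']) == set(model.state_dict())
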
@@ -3,8 +3,8 @@ from unittest.mock import MagicMock, patch
 import torch.nn as nn
 from torch.nn.parallel import DataParallel, DistributedDataParallel

-from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
-                           is_parallel_module)
+from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
+                           MMDistributedDataParallel, is_module_wrapper)
 from mmcv.parallel.distributed_deprecated import \
     MMDistributedDataParallel as DeprecatedMMDDP
@@ -12,7 +12,7 @@ from mmcv.parallel.distributed_deprecated import \
 @patch('torch.distributed._broadcast_coalesced', MagicMock)
 @patch('torch.distributed.broadcast', MagicMock)
 @patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
-def test_is_parallel_module():
+def test_is_module_wrapper():

     class Model(nn.Module):
@@ -24,19 +24,32 @@ def test_is_parallel_module():
             return self.conv(x)

     model = Model()
-    assert not is_parallel_module(model)
+    assert not is_module_wrapper(model)

     dp = DataParallel(model)
-    assert is_parallel_module(dp)
+    assert is_module_wrapper(dp)

     mmdp = MMDataParallel(model)
-    assert is_parallel_module(mmdp)
+    assert is_module_wrapper(mmdp)

     ddp = DistributedDataParallel(model, process_group=MagicMock())
-    assert is_parallel_module(ddp)
+    assert is_module_wrapper(ddp)

     mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
-    assert is_parallel_module(mmddp)
+    assert is_module_wrapper(mmddp)

     deprecated_mmddp = DeprecatedMMDDP(model)
-    assert is_parallel_module(deprecated_mmddp)
+    assert is_module_wrapper(deprecated_mmddp)
+
+    # test module wrapper registry
+    @MODULE_WRAPPERS.register_module()
+    class ModuleWrapper(object):
+
+        def __init__(self, module):
+            self.module = module
+
+        def forward(self, *args, **kwargs):
+            return self.module(*args, **kwargs)
+
+    module_wrapper = ModuleWrapper(model)
+    assert is_module_wrapper(module_wrapper)