Unverified commit 79942100, authored by Harry and committed by GitHub

Add module wrapper registry (#352)

* feat: modify parallel to module wrapper

* fix: unittest

* fix: add registry file

* feat: add module wrapper unittest
parent c6784f4a
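For context, the registry added in this change lets downstream code register custom wrappers so that is_module_wrapper() recognizes them. A minimal usage sketch, mirroring the unit test added at the bottom of this diff (ToyWrapper is a hypothetical example class, not part of mmcv):

import torch.nn as nn

from mmcv.parallel import MODULE_WRAPPERS, is_module_wrapper


# Hypothetical custom wrapper; registering it makes is_module_wrapper()
# treat its instances like DataParallel / DistributedDataParallel.
@MODULE_WRAPPERS.register_module()
class ToyWrapper(nn.Module):

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


model = nn.Conv2d(2, 2, 1)
assert not is_module_wrapper(model)          # a plain module is not a wrapper
assert is_module_wrapper(ToyWrapper(model))  # the registered wrapper is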
@@ -3,10 +3,11 @@ from .collate import collate
 from .data_container import DataContainer
 from .data_parallel import MMDataParallel
 from .distributed import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
 from .scatter_gather import scatter, scatter_kwargs
-from .utils import is_parallel_module
+from .utils import is_module_wrapper
 __all__ = [
     'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
-    'scatter', 'scatter_kwargs', 'is_parallel_module'
+    'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
 ]
@@ -6,9 +6,11 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 from mmcv.utils import TORCH_VERSION
+from .registry import MODULE_WRAPPERS
 from .scatter_gather import scatter_kwargs
+@MODULE_WRAPPERS.register_module()
 class MMDistributedDataParallel(nn.Module):
     def __init__(self,
...
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+from mmcv.utils import Registry
+MODULE_WRAPPERS = Registry('module wrapper')
+MODULE_WRAPPERS.register_module(DataParallel)
+MODULE_WRAPPERS.register_module(DistributedDataParallel)
 # Copyright (c) Open-MMLab. All rights reserved.
-from torch.nn.parallel import DataParallel, DistributedDataParallel
-from .distributed_deprecated import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
-def is_parallel_module(module):
-    """Check if a module is a parallel module.
-    The following 3 modules (and their subclasses) are regarded as parallel
-    modules: DataParallel, DistributedDataParallel,
-    MMDistributedDataParallel (the deprecated version).
+def is_module_wrapper(module):
+    """Check if a module is a module wrapper.
+    The following 3 modules in MMCV (and their subclasses) are regarded as
+    module wrappers: DataParallel, DistributedDataParallel,
+    MMDistributedDataParallel (the deprecated version). You may add your own
+    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
     Args:
         module (nn.Module): The module to be checked.
     Returns:
-        bool: True if the input module is a parallel module.
+        bool: True if the input module is a module wrapper.
     """
-    parallels = (DataParallel, DistributedDataParallel,
-                 MMDistributedDataParallel)
-    return isinstance(module, parallels)
+    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
+    return isinstance(module, module_wrappers)
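The runner and checkpoint hunks below all use is_module_wrapper() for the same unwrap step before touching model attributes. A minimal sketch of that pattern (get_bare_model is an illustrative helper name, not an mmcv API):

from mmcv.parallel import is_module_wrapper


def get_bare_model(model):
    # Unwrap the model if it is wrapped (DataParallel, DistributedDataParallel,
    # or any registered module wrapper) so that attributes such as train_step()
    # and state_dict keys come from the underlying network.
    return model.module if is_module_wrapper(model) else model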
@@ -8,7 +8,7 @@ import torch
 from torch.optim import Optimizer
 import mmcv
-from ..parallel import is_parallel_module
+from ..parallel import is_module_wrapper
 from .checkpoint import load_checkpoint
 from .dist_utils import get_dist_info
 from .hooks import HOOKS, Hook, IterTimerHook
@@ -60,7 +60,7 @@ class BaseRunner(metaclass=ABCMeta):
                           'train_step() and val_step() in the model instead.')
             # raise an error if `batch_processor` is not None and
             # `model.train_step()` exists.
-            if is_parallel_module(model):
+            if is_module_wrapper(model):
                 _model = model.module
             else:
                 _model = model
...
@@ -14,7 +14,7 @@ from torch.utils import model_zoo
 import mmcv
 from ..fileio import load as load_file
-from ..parallel import is_parallel_module
+from ..parallel import is_module_wrapper
 from ..utils import mkdir_or_exist
 from .dist_utils import get_dist_info
@@ -63,7 +63,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
     def load(module, prefix=''):
         # recursively check parallel module in case that the model has a
         # complicated structure, e.g., nn.Module(nn.Module(DDP))
-        if is_parallel_module(module):
+        if is_module_wrapper(module):
             module = module.module
         local_metadata = {} if metadata is None else metadata.get(
             prefix[:-1], {})
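The comment in load() above concerns nested structures; a small illustration (with a toy container, not mmcv code) of why the wrapper check must run at every recursion level rather than only on the top-level model:

import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmcv.parallel import is_module_wrapper


class Outer(nn.Module):
    # Toy example: the wrapper sits one level below the top module, so only a
    # per-submodule check (as load() performs) will find and unwrap it.
    def __init__(self):
        super().__init__()
        self.backbone = DataParallel(nn.Conv2d(2, 2, 1))


outer = Outer()
assert not is_module_wrapper(outer)       # top level is a plain nn.Module
assert is_module_wrapper(outer.backbone)  # the nested submodule is wrapped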
@@ -273,7 +273,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
     meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
     mmcv.mkdir_or_exist(osp.dirname(filename))
-    if is_parallel_module(model):
+    if is_module_wrapper(model):
         model = model.module
     checkpoint = {
...
@@ -3,8 +3,8 @@ from unittest.mock import MagicMock, patch
 import torch.nn as nn
 from torch.nn.parallel import DataParallel, DistributedDataParallel
-from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
-                           is_parallel_module)
+from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
+                           MMDistributedDataParallel, is_module_wrapper)
 from mmcv.parallel.distributed_deprecated import \
     MMDistributedDataParallel as DeprecatedMMDDP
@@ -12,7 +12,7 @@ from mmcv.parallel.distributed_deprecated import \
 @patch('torch.distributed._broadcast_coalesced', MagicMock)
 @patch('torch.distributed.broadcast', MagicMock)
 @patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
-def test_is_parallel_module():
+def test_is_module_wrapper():
     class Model(nn.Module):
@@ -24,19 +24,32 @@ def test_is_parallel_module():
             return self.conv(x)
     model = Model()
-    assert not is_parallel_module(model)
+    assert not is_module_wrapper(model)
     dp = DataParallel(model)
-    assert is_parallel_module(dp)
+    assert is_module_wrapper(dp)
     mmdp = MMDataParallel(model)
-    assert is_parallel_module(mmdp)
+    assert is_module_wrapper(mmdp)
     ddp = DistributedDataParallel(model, process_group=MagicMock())
-    assert is_parallel_module(ddp)
+    assert is_module_wrapper(ddp)
     mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
-    assert is_parallel_module(mmddp)
+    assert is_module_wrapper(mmddp)
     deprecated_mmddp = DeprecatedMMDDP(model)
-    assert is_parallel_module(deprecated_mmddp)
+    assert is_module_wrapper(deprecated_mmddp)
+    # test module wrapper registry
+    @MODULE_WRAPPERS.register_module()
+    class ModuleWrapper(object):
+        def __init__(self, module):
+            self.module = module
+        def forward(self, *args, **kwargs):
+            return self.module(*args, **kwargs)
+    module_wrapper = ModuleWrapper(model)
+    assert is_module_wrapper(module_wrapper)