Unverified commit 821b3ad6 authored by Kai Chen, committed by GitHub

Fix the BC issue of ddp (#325)

* fix the BC issue of ddp

* minor fix for the docstring
parent 41ca64cb
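
The backward-compatibility (BC) problem being fixed: when an old-style model (one that defines neither train_step() nor val_step()) is wrapped in a parallel wrapper such as MMDataParallel, probing the wrapper with hasattr() reports the methods the wrapper itself provides, so the runner wrongly rejected the batch_processor + wrapped-model combination. A minimal sketch of that false positive, assuming MMDataParallel forwards train_step()/val_step() to the wrapped module (which is the premise of the new test below); the OldStyleModel class here is a stand-in defined only for this sketch, not the fixture from the test file:

import torch.nn as nn

from mmcv.parallel import MMDataParallel


class OldStyleModel(nn.Module):
    # An old-style model: no train_step()/val_step(), it relies on batch_processor.

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x)


wrapped = MMDataParallel(OldStyleModel())
# Probing the wrapper: expected True under the assumption above, because the
# wrapper exposes train_step() even though the underlying model does not.
print(hasattr(wrapped, 'train_step'))
# Probing the wrapped module itself: expected False, which is the information
# the runner actually needs before deciding whether batch_processor is allowed.
print(hasattr(wrapped.module, 'train_step'))
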
@@ -4,8 +4,9 @@ from .data_container import DataContainer
from .data_parallel import MMDataParallel
from .distributed import MMDistributedDataParallel
from .scatter_gather import scatter, scatter_kwargs
+from .utils import is_parallel_module

__all__ = [
    'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
-    'scatter', 'scatter_kwargs'
+    'scatter', 'scatter_kwargs', 'is_parallel_module'
]
# Copyright (c) Open-MMLab. All rights reserved.
from torch.nn.parallel import DataParallel, DistributedDataParallel

from .distributed_deprecated import MMDistributedDataParallel


def is_parallel_module(module):
    """Check if a module is a parallel module.

    The following 3 modules (and their subclasses) are regarded as parallel
    modules: DataParallel, DistributedDataParallel,
    MMDistributedDataParallel (the deprecated version).

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: True if the input module is a parallel module.
    """
    parallels = (DataParallel, DistributedDataParallel,
                 MMDistributedDataParallel)
    if isinstance(module, parallels):
        return True
    else:
        return False
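
The deprecated MMDistributedDataParallel has to be listed explicitly because it appears to wrap the model as a plain nn.Module rather than subclassing torch's DistributedDataParallel, so the isinstance check would otherwise miss it. A typical way to use the helper is to strip the wrapper before inspecting the underlying model, which is exactly what the runner change below does; a small sketch (the _unwrap name is hypothetical, not part of this commit):

# Hypothetical usage sketch: unwrap the real model only when it is hidden
# behind one of the recognized parallel wrappers.
def _unwrap(module):
    return module.module if is_parallel_module(module) else module
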
@@ -7,6 +7,7 @@ from abc import ABCMeta, abstractmethod
import torch

import mmcv
+from ..parallel import is_parallel_module
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook
@@ -56,7 +57,11 @@ class BaseRunner(metaclass=ABCMeta):
                'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
-            if hasattr(model, 'train_step') or hasattr(model, 'val_step'):
+            if is_parallel_module(model):
+                _model = model.module
+            else:
+                _model = model
+            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
from unittest.mock import MagicMock, patch

import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
                           is_parallel_module)
from mmcv.parallel.distributed_deprecated import \
    MMDistributedDataParallel as DeprecatedMMDDP


@patch('torch.distributed._broadcast_coalesced', MagicMock)
@patch('torch.distributed.broadcast', MagicMock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
def test_is_parallel_module():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_parallel_module(model)

    dp = DataParallel(model)
    assert is_parallel_module(dp)

    mmdp = MMDataParallel(model)
    assert is_parallel_module(mmdp)

    ddp = DistributedDataParallel(model, process_group=MagicMock())
    assert is_parallel_module(ddp)

    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
    assert is_parallel_module(mmddp)

    deprecated_mmddp = DeprecatedMMDDP(model)
    assert is_parallel_module(deprecated_mmddp)
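
The patch decorators above appear to stub out the parts of torch.distributed that the (MM)DistributedDataParallel constructors touch, so the test can run in a single process without GPUs or a real process group. For comparison, constructing a real DistributedDataParallel instance would first need an initialized process group, along these lines (a sketch of standard PyTorch setup, not part of the commit; the address and port are illustrative):

import torch.distributed as dist

# Single-process 'gloo' group for illustration only.
dist.init_process_group(backend='gloo',
                        init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)
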
@@ -10,6 +10,7 @@ import pytest
import torch
import torch.nn as nn

+from mmcv.parallel import MMDataParallel
from mmcv.runner import EpochBasedRunner
@@ -77,6 +78,24 @@ def test_epoch_based_runner():
    os.removedirs(work_dir)


+def test_runner_with_parallel():
+
+    def batch_processor():
+        pass
+
+    model = MMDataParallel(OldStyleModel())
+    _ = EpochBasedRunner(model, batch_processor)
+
+    with pytest.raises(RuntimeError):
+        # batch_processor and train_step() cannot be both set
+
+        def batch_processor():
+            pass
+
+        model = MMDataParallel(Model())
+        _ = EpochBasedRunner(model, batch_processor)
+
+
def test_save_checkpoint():
    model = Model()
    runner = EpochBasedRunner(model=model, logger=logging.getLogger())