utils.py 735 Bytes
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) Open-MMLab. All rights reserved.
from torch.nn.parallel import DataParallel, DistributedDataParallel

from .distributed_deprecated import MMDistributedDataParallel


def is_parallel_module(module):
    """Check if a module is a parallel module.

    The following 3 modules (and their subclasses) are regarded as parallel
    modules: DataParallel, DistributedDataParallel,
    MMDistributedDataParallel (the deprecated version).

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: True if the input module is a parallel module.
    """
    # isinstance() with a tuple of types matches an instance of any listed
    # class or of any subclass of one, so wrappers derived from these types
    # are also reported as parallel modules.
    parallels = (DataParallel, DistributedDataParallel,
                 MMDistributedDataParallel)
    return isinstance(module, parallels)