builder.py 4.21 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
3

4
from mmcv.cnn import MODELS as MMCV_MODELS
5
from mmcv.utils import Registry
6

7
8
9
10
11
12
13
14
from mmdet.models.builder import BACKBONES as MMDET_BACKBONES
from mmdet.models.builder import DETECTORS as MMDET_DETECTORS
from mmdet.models.builder import HEADS as MMDET_HEADS
from mmdet.models.builder import LOSSES as MMDET_LOSSES
from mmdet.models.builder import NECKS as MMDET_NECKS
from mmdet.models.builder import ROI_EXTRACTORS as MMDET_ROI_EXTRACTORS
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
from mmseg.models.builder import LOSSES as MMSEG_LOSSES
15

16
17
# Single root registry for every mmdet3d model component. It is created with
# mmcv's MODELS registry as its parent.
MODELS = Registry('models', parent=MMCV_MODELS)

# Every component-specific name below is an alias of the one MODELS registry
# object, so registering or building through any of them uses MODELS itself.
BACKBONES = NECKS = ROI_EXTRACTORS = SHARED_HEADS = MODELS
HEADS = LOSSES = DETECTORS = MODELS
VOXEL_ENCODERS = MIDDLE_ENCODERS = FUSION_LAYERS = MODELS
SEGMENTORS = MODELS
zhangwenwei's avatar
zhangwenwei committed
29
30
31


def build_backbone(cfg):
    """Build a backbone from ``cfg``.

    Backbone types registered in the local ``BACKBONES`` registry are built
    from it; any other type is delegated to mmdet's ``BACKBONES`` registry.

    Args:
        cfg (dict): Backbone config; ``cfg['type']`` selects the class.

    Returns:
        The object built by the selected registry.
    """
    registered_locally = cfg['type'] in BACKBONES._module_dict
    registry = BACKBONES if registered_locally else MMDET_BACKBONES
    return registry.build(cfg)
zhangwenwei's avatar
zhangwenwei committed
37
38
39


def build_neck(cfg):
    """Build a neck from ``cfg``.

    Neck types registered in the local ``NECKS`` registry are built from it;
    every other type is delegated to mmdet's ``NECKS`` registry.

    Args:
        cfg (dict): Neck config; ``cfg['type']`` selects the class.

    Returns:
        The object built by the selected registry.
    """
    if cfg['type'] not in NECKS._module_dict:
        # Not an mmdet3d neck: fall back to mmdet's implementations.
        return MMDET_NECKS.build(cfg)
    return NECKS.build(cfg)
zhangwenwei's avatar
zhangwenwei committed
45
46
47


def build_roi_extractor(cfg):
    """Build an RoI feature extractor from ``cfg``.

    Types registered in the local ``ROI_EXTRACTORS`` registry are built from
    it; any other type is delegated to mmdet's ``ROI_EXTRACTORS`` registry.

    Args:
        cfg (dict): Extractor config; ``cfg['type']`` selects the class.

    Returns:
        The object built by the selected registry.
    """
    known_here = cfg['type'] in ROI_EXTRACTORS._module_dict
    target = ROI_EXTRACTORS if known_here else MMDET_ROI_EXTRACTORS
    return target.build(cfg)
zhangwenwei's avatar
zhangwenwei committed
53
54
55


def build_shared_head(cfg):
    """Build a detector's shared head from ``cfg``.

    Types registered in the local ``SHARED_HEADS`` registry are built from
    it; any other type is delegated to mmdet's ``SHARED_HEADS`` registry.

    Args:
        cfg (dict): Shared-head config; ``cfg['type']`` selects the class.

    Returns:
        The object built by the selected registry.
    """
    if cfg['type'] not in SHARED_HEADS._module_dict:
        # Not an mmdet3d shared head: fall back to mmdet's implementations.
        return MMDET_SHARED_HEADS.build(cfg)
    return SHARED_HEADS.build(cfg)
zhangwenwei's avatar
zhangwenwei committed
61
62
63


def build_head(cfg):
    """Build a head from ``cfg``.

    Head types registered in the local ``HEADS`` registry are built from it;
    any other type is delegated to mmdet's ``HEADS`` registry.

    Args:
        cfg (dict): Head config; ``cfg['type']`` selects the class.

    Returns:
        The object built by the selected registry.
    """
    chosen = HEADS if cfg['type'] in HEADS._module_dict else MMDET_HEADS
    return chosen.build(cfg)
zhangwenwei's avatar
zhangwenwei committed
69
70
71


def build_loss(cfg):
    """Build a loss function from ``cfg``.

    Registries are consulted in order: the local ``LOSSES`` registry, then
    mmdet's ``LOSSES``. A type found in neither is handed to mmseg's
    ``LOSSES`` registry.

    Args:
        cfg (dict): Loss config; ``cfg['type']`` selects the class.

    Returns:
        The object built by the selected registry.
    """
    loss_type = cfg['type']
    if loss_type in LOSSES._module_dict:
        registry = LOSSES
    elif loss_type in MMDET_LOSSES._module_dict:
        registry = MMDET_LOSSES
    else:
        # Last resort: mmseg's registry (it reports unknown types itself).
        registry = MMSEG_LOSSES
    return registry.build(cfg)
zhangwenwei's avatar
zhangwenwei committed
79
80
81


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg``.

    Passing ``train_cfg`` / ``test_cfg`` as separate arguments is deprecated
    and triggers a ``UserWarning``; they should be placed inside ``cfg``.
    Each of them may be given in at most one of the two places.

    Detector types registered in the local ``DETECTORS`` registry are built
    from it; any other type is delegated to mmdet's ``DETECTORS`` registry.

    Args:
        cfg (dict): Detector config; ``cfg['type']`` selects the class.
        train_cfg (dict, optional): Deprecated outer training config,
            forwarded to the registry as a default argument.
        test_cfg (dict, optional): Deprecated outer testing config,
            forwarded to the registry as a default argument.

    Returns:
        The object built by the selected registry.

    Raises:
        AssertionError: If ``train_cfg`` or ``test_cfg`` is set both as an
            argument and inside ``cfg``.
    """
    outer_cfg_given = train_cfg is not None or test_cfg is not None
    if outer_cfg_given:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '

    use_local = cfg['type'] in DETECTORS._module_dict
    registry = DETECTORS if use_local else MMDET_DETECTORS
    extra_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return registry.build(cfg, default_args=extra_args)
zhangwenwei's avatar
zhangwenwei committed
97
98


99
100
101
102
103
104
105
106
107
108
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build a segmentor from ``cfg`` using the local ``SEGMENTORS`` registry.

    Passing ``train_cfg`` / ``test_cfg`` as separate arguments is deprecated
    and triggers a ``UserWarning``; they should be placed inside ``cfg``.
    Each of them may be given in at most one of the two places. Unlike the
    detector builder, no mmdet fallback registry is consulted.

    Args:
        cfg (dict): Segmentor config.
        train_cfg (dict, optional): Deprecated outer training config,
            forwarded to the registry as a default argument.
        test_cfg (dict, optional): Deprecated outer testing config,
            forwarded to the registry as a default argument.

    Returns:
        The object built by ``SEGMENTORS``.

    Raises:
        AssertionError: If ``train_cfg`` or ``test_cfg`` is set both as an
            argument and inside ``cfg``.
    """
    deprecated_args_used = not (train_cfg is None and test_cfg is None)
    if deprecated_args_used:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '

    extra_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return SEGMENTORS.build(cfg, default_args=extra_args)
111
112
113
114
115
116
117
118
119
120
121
122
123
124


def build_model(cfg, train_cfg=None, test_cfg=None):
    """A function wrapper for building a 3D detector or segmentor from cfg.

    Configs whose ``type`` is ``'EncoderDecoder3D'`` are built with
    ``build_segmentor``; every other type is built with ``build_detector``.

    Should be deprecated in the future.

    Args:
        cfg (dict): Model config. Must contain a ``'type'`` key.
        train_cfg (dict, optional): Deprecated outer training config,
            forwarded unchanged to the chosen builder. Defaults to None.
        test_cfg (dict, optional): Deprecated outer testing config,
            forwarded unchanged to the chosen builder. Defaults to None.

    Returns:
        Whatever ``build_segmentor`` or ``build_detector`` returns.
    """
    segmentor_types = ('EncoderDecoder3D',)
    # Use item access rather than ``cfg.type`` so a plain ``dict`` config
    # works too, consistent with the other ``build_*`` helpers here.
    if cfg['type'] in segmentor_types:
        return build_segmentor(cfg, train_cfg=train_cfg, test_cfg=test_cfg)
    return build_detector(cfg, train_cfg=train_cfg, test_cfg=test_cfg)


zhangwenwei's avatar
zhangwenwei committed
125
def build_voxel_encoder(cfg):
    """Build a voxel encoder from ``cfg``.

    Only the local ``VOXEL_ENCODERS`` registry is used; there is no mmdet
    fallback.

    Args:
        cfg (dict): Voxel-encoder config.

    Returns:
        The object returned by ``VOXEL_ENCODERS.build(cfg)``.
    """
    encoder = VOXEL_ENCODERS.build(cfg)
    return encoder
zhangwenwei's avatar
zhangwenwei committed
128
129
130


def build_middle_encoder(cfg):
    """Build a middle-level encoder from ``cfg``.

    Only the local ``MIDDLE_ENCODERS`` registry is used; there is no mmdet
    fallback.

    Args:
        cfg (dict): Middle-encoder config.

    Returns:
        The object returned by ``MIDDLE_ENCODERS.build(cfg)``.
    """
    encoder = MIDDLE_ENCODERS.build(cfg)
    return encoder
zhangwenwei's avatar
zhangwenwei committed
133
134
135


def build_fusion_layer(cfg):
    """Build a fusion layer from ``cfg``.

    Only the local ``FUSION_LAYERS`` registry is used; there is no mmdet
    fallback.

    Args:
        cfg (dict): Fusion-layer config.

    Returns:
        The object returned by ``FUSION_LAYERS.build(cfg)``.
    """
    layer = FUSION_LAYERS.build(cfg)
    return layer