builder.py 2.76 KB
Newer Older
1
import warnings
2
from mmcv.utils import Registry
3

4
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
5
                                  ROI_EXTRACTORS, SHARED_HEADS, build)
6
from mmseg.models.builder import SEGMENTORS
7

8
9
10
# Registries created in this module (the remaining registries used in this
# file are imported from mmdet / mmseg). Each one is handed to ``build`` by
# the matching ``build_*`` function in this file.
VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
FUSION_LAYERS = Registry('fusion_layer')
zhangwenwei's avatar
zhangwenwei committed
11
12
13


def build_backbone(cfg):
    """Build a backbone from ``cfg`` via the ``BACKBONES`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, BACKBONES)`` returns.
    """
    backbone = build(cfg, BACKBONES)
    return backbone
zhangwenwei's avatar
zhangwenwei committed
16
17
18


def build_neck(cfg):
    """Build a neck from ``cfg`` via the ``NECKS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, NECKS)`` returns.
    """
    neck = build(cfg, NECKS)
    return neck
zhangwenwei's avatar
zhangwenwei committed
21
22
23


def build_roi_extractor(cfg):
    """Build an RoI feature extractor via the ``ROI_EXTRACTORS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, ROI_EXTRACTORS)`` returns.
    """
    roi_extractor = build(cfg, ROI_EXTRACTORS)
    return roi_extractor
zhangwenwei's avatar
zhangwenwei committed
26
27
28


def build_shared_head(cfg):
    """Build a detector's shared head via the ``SHARED_HEADS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, SHARED_HEADS)`` returns.
    """
    shared_head = build(cfg, SHARED_HEADS)
    return shared_head
zhangwenwei's avatar
zhangwenwei committed
31
32
33


def build_head(cfg):
    """Build a head from ``cfg`` via the ``HEADS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, HEADS)`` returns.
    """
    head = build(cfg, HEADS)
    return head
zhangwenwei's avatar
zhangwenwei committed
36
37
38


def build_loss(cfg):
    """Build a loss function from ``cfg`` via the ``LOSSES`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, LOSSES)`` returns.
    """
    loss = build(cfg, LOSSES)
    return loss
zhangwenwei's avatar
zhangwenwei committed
41
42
43


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector.

    Args:
        cfg: Model config. Must support ``.get``; its ``train_cfg`` and
            ``test_cfg`` entries are checked against the outer arguments.
        train_cfg: Deprecated outer-level training config. Passing it
            triggers a ``UserWarning``.
        test_cfg: Deprecated outer-level testing config. Passing it
            triggers a ``UserWarning``.

    Returns:
        The result of ``build`` on the ``DETECTORS`` registry, with
        ``train_cfg`` and ``test_cfg`` supplied as default arguments.

    Raises:
        AssertionError: If ``train_cfg`` (or ``test_cfg``) is given both
            inside ``cfg`` and as an outer argument.
    """
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    # Raise explicitly rather than using ``assert`` so the check is not
    # stripped when Python runs with ``-O``. AssertionError is kept so
    # existing callers that catch it keep working.
    if cfg.get('train_cfg') is not None and train_cfg is not None:
        raise AssertionError(
            'train_cfg specified in both outer field and model field ')
    if cfg.get('test_cfg') is not None and test_cfg is not None:
        raise AssertionError(
            'test_cfg specified in both outer field and model field ')
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
zhangwenwei's avatar
zhangwenwei committed
54
55


56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build segmentor.

    Args:
        cfg: Model config. Must support ``.get``; its ``train_cfg`` and
            ``test_cfg`` entries are checked against the outer arguments.
        train_cfg: Deprecated outer-level training config. Passing it
            triggers a ``UserWarning``.
        test_cfg: Deprecated outer-level testing config. Passing it
            triggers a ``UserWarning``.

    Returns:
        The result of ``build`` on the ``SEGMENTORS`` registry, with
        ``train_cfg`` and ``test_cfg`` supplied as default arguments.

    Raises:
        AssertionError: If ``train_cfg`` (or ``test_cfg``) is given both
            inside ``cfg`` and as an outer argument.
    """
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    # Raise explicitly rather than using ``assert`` so the check is not
    # stripped when Python runs with ``-O``. AssertionError is kept so
    # existing callers that catch it keep working.
    if cfg.get('train_cfg') is not None and train_cfg is not None:
        raise AssertionError(
            'train_cfg specified in both outer field and model field ')
    if cfg.get('test_cfg') is not None and test_cfg is not None:
        raise AssertionError(
            'test_cfg specified in both outer field and model field ')
    return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))


def build_model(cfg, train_cfg=None, test_cfg=None):
    """A function wrapper for building 3D detector or segmentor according to
    cfg.

    Should be deprecated in the future.

    Args:
        cfg: Model config. Must support ``.get``; its ``type`` entry
            selects which builder is used.
        train_cfg: Forwarded unchanged to the selected builder.
        test_cfg: Forwarded unchanged to the selected builder.

    Returns:
        The result of ``build_segmentor`` when the config type is
        ``'EncoderDecoder3D'``, otherwise the result of ``build_detector``.
    """
    # Read the type with ``.get`` -- the same accessor build_detector and
    # build_segmentor use on ``cfg`` -- so a plain ``dict`` config is
    # dispatched here instead of failing on attribute access.
    if cfg.get('type') in ('EncoderDecoder3D', ):
        return build_segmentor(cfg, train_cfg=train_cfg, test_cfg=test_cfg)
    return build_detector(cfg, train_cfg=train_cfg, test_cfg=test_cfg)


zhangwenwei's avatar
zhangwenwei committed
81
def build_voxel_encoder(cfg):
    """Build a voxel encoder via the ``VOXEL_ENCODERS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, VOXEL_ENCODERS)`` returns.
    """
    voxel_encoder = build(cfg, VOXEL_ENCODERS)
    return voxel_encoder
zhangwenwei's avatar
zhangwenwei committed
84
85
86


def build_middle_encoder(cfg):
    """Build a middle level encoder via the ``MIDDLE_ENCODERS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, MIDDLE_ENCODERS)`` returns.
    """
    middle_encoder = build(cfg, MIDDLE_ENCODERS)
    return middle_encoder
zhangwenwei's avatar
zhangwenwei committed
89
90
91


def build_fusion_layer(cfg):
    """Build a fusion layer via the ``FUSION_LAYERS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build``.

    Returns:
        Whatever ``build(cfg, FUSION_LAYERS)`` returns.
    """
    fusion_layer = build(cfg, FUSION_LAYERS)
    return fusion_layer