# Factory helpers that build detector components from configs via mmcv's
# obj_from_dict, passing ready-made nn.Module instances through unchanged.
from mmcv.runner import obj_from_dict
from torch import nn

from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
               mask_heads)

# Public builder API of this module; the generic helpers ``build`` and
# ``_build_module`` are not exported.
__all__ = [
    'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
    'build_bbox_head', 'build_mask_head', 'build_detector'
]


Kai Chen's avatar
Kai Chen committed
13
def _build_module(cfg, parrent=None, default_args=None):
Kai Chen's avatar
Kai Chen committed
14
    return cfg if isinstance(cfg, nn.Module) else obj_from_dict(
Kai Chen's avatar
Kai Chen committed
15
        cfg, parrent, default_args)
Kai Chen's avatar
Kai Chen committed
16
17


def build(cfg, parrent=None, default_args=None):
    """Build one module, or an ``nn.Sequential`` of modules, from ``cfg``.

    Args:
        cfg: A single config / ``nn.Module``, or a ``list`` of them. Only a
            ``list`` is expanded; any other value (including a tuple) is
            treated as a single config.
        parrent: Forwarded unchanged to ``_build_module`` for every entry.
        default_args: Forwarded unchanged to ``_build_module`` for every
            entry.

    Returns:
        For a list, an ``nn.Sequential`` wrapping the built modules in list
        order (empty for an empty list); otherwise the single built module.
    """
    if isinstance(cfg, list):
        modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    return _build_module(cfg, parrent, default_args)


def build_backbone(cfg):
    """Build the backbone described by ``cfg``, with ``backbones`` as parent."""
    return build(cfg, parrent=backbones)


def build_neck(cfg):
    """Build the neck described by ``cfg``, with ``necks`` as parent."""
    return build(cfg, parrent=necks)


def build_rpn_head(cfg):
    """Build the RPN head described by ``cfg``, with ``rpn_heads`` as parent."""
    return build(cfg, parrent=rpn_heads)


def build_roi_extractor(cfg):
    """Build the RoI extractor in ``cfg``, with ``roi_extractors`` as parent."""
    return build(cfg, parrent=roi_extractors)


def build_bbox_head(cfg):
    """Build the bbox head described by ``cfg``, with ``bbox_heads`` as parent."""
    return build(cfg, parrent=bbox_heads)


def build_mask_head(cfg):
    """Build the mask head described by ``cfg``, with ``mask_heads`` as parent."""
    return build(cfg, mask_heads)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg``, with ``detectors`` as parent.

    Args:
        cfg: Detector config, an already-built ``nn.Module``, or a list of
            either (see ``build``).
        train_cfg: Passed to ``build`` inside ``default_args`` under the
            key ``'train_cfg'``.
        test_cfg: Passed to ``build`` inside ``default_args`` under the
            key ``'test_cfg'``.

    Returns:
        Whatever ``build`` produces for ``cfg``.
    """
    # Imported here rather than at module top. NOTE(review): presumably to
    # avoid a circular import (``detectors`` likely imports this builder);
    # confirm before hoisting.
    from . import detectors
    return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))