builder.py 1.36 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
from mmcv.runner import obj_from_dict
Kai Chen's avatar
Kai Chen committed
2
3
4
from torch import nn

from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
5
               mask_heads, single_stage_heads)
Kai Chen's avatar
Kai Chen committed
6
7
8

# Public factory functions exported by this module (one per component kind,
# plus the top-level detector factory).
__all__ = [
    'build_backbone',
    'build_neck',
    'build_rpn_head',
    'build_roi_extractor',
    'build_bbox_head',
    'build_mask_head',
    'build_single_stage_head',
    'build_detector',
]


Kai Chen's avatar
Kai Chen committed
14
def _build_module(cfg, parrent=None, default_args=None):
    """Turn a single config entry into a module object.

    Args:
        cfg: Either an already-constructed ``nn.Module`` or a config object
            to hand to ``mmcv.runner.obj_from_dict`` (presumably a dict with
            a ``'type'`` key, per mmcv -- confirm for the installed version).
        parrent: Forwarded unchanged to ``obj_from_dict``; callers in this
            file pass a component submodule here (the misspelled name is
            kept because it is part of the keyword interface).
        default_args: Forwarded unchanged to ``obj_from_dict``.

    Returns:
        ``cfg`` itself when it is already an ``nn.Module``; otherwise
        whatever ``obj_from_dict(cfg, parrent, default_args)`` returns.
    """
    if not isinstance(cfg, nn.Module):
        return obj_from_dict(cfg, parrent, default_args)
    # Pre-built modules are passed straight through (the same object, no copy).
    return cfg
Kai Chen's avatar
Kai Chen committed
17
18


Kai Chen's avatar
Kai Chen committed
19
def build(cfg, parrent=None, default_args=None):
    """Build one module, or a sequential chain of modules, from config.

    Args:
        cfg: A single entry accepted by ``_build_module`` (an ``nn.Module``
            or a config for ``obj_from_dict``), or a ``list`` of such entries.
        parrent: Forwarded to ``_build_module`` for every entry.
        default_args: Forwarded to ``_build_module`` for every entry.

    Returns:
        The result of ``_build_module`` for a single entry, or an
        ``nn.Sequential`` containing the built entries in list order when
        ``cfg`` is a list.
    """
    if not isinstance(cfg, list):
        return _build_module(cfg, parrent, default_args)

    # Every list element is built independently with the same
    # parrent/default_args, then chained in the original order.
    return nn.Sequential(
        *(_build_module(entry, parrent, default_args) for entry in cfg))
Kai Chen's avatar
Kai Chen committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


def build_backbone(cfg):
    """Build a backbone from ``cfg`` via ``build``, passing the
    ``backbones`` submodule as ``parrent``."""
    return build(cfg, parrent=backbones)


def build_neck(cfg):
    """Build a neck from ``cfg`` via ``build``, passing the ``necks``
    submodule as ``parrent``."""
    return build(cfg, parrent=necks)


def build_rpn_head(cfg):
    """Build an RPN head from ``cfg`` via ``build``, passing the
    ``rpn_heads`` submodule as ``parrent``."""
    return build(cfg, parrent=rpn_heads)


def build_roi_extractor(cfg):
    """Build an RoI extractor from ``cfg`` via ``build``, passing the
    ``roi_extractors`` submodule as ``parrent``."""
    return build(cfg, parrent=roi_extractors)


def build_bbox_head(cfg):
    """Build a bbox head from ``cfg`` via ``build``, passing the
    ``bbox_heads`` submodule as ``parrent``."""
    return build(cfg, parrent=bbox_heads)


def build_mask_head(cfg):
    """Build a mask head from ``cfg`` via ``build``, passing the
    ``mask_heads`` submodule as ``parrent``."""
    return build(cfg, parrent=mask_heads)
Kai Chen's avatar
Kai Chen committed
49
50


Kai Chen's avatar
Kai Chen committed
51
52
53
54
def build_single_stage_head(cfg):
    """Build a single-stage head from ``cfg`` via ``build``, passing the
    ``single_stage_heads`` submodule as ``parrent``."""
    return build(cfg, parrent=single_stage_heads)


Kai Chen's avatar
Kai Chen committed
55
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg``.

    Args:
        cfg: Detector config (or list of configs) accepted by ``build``.
        train_cfg: Forwarded to ``build`` inside ``default_args`` under the
            key ``'train_cfg'``.
        test_cfg: Forwarded to ``build`` inside ``default_args`` under the
            key ``'test_cfg'``.

    Returns:
        Whatever ``build`` returns, with the ``detectors`` submodule passed
        as ``parrent``.
    """
    # Imported here rather than at module level; presumably to avoid a
    # circular import between this builder and the detectors package --
    # confirm before hoisting it.
    from . import detectors

    extra_args = {'train_cfg': train_cfg, 'test_cfg': test_cfg}
    return build(cfg, detectors, extra_args)