builder.py 1.22 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
from mmcv import torchpack as tp
Kai Chen's avatar
Kai Chen committed
2
3
4
from torch import nn

from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
Kai Chen's avatar
Kai Chen committed
5
               mask_heads, detectors)
Kai Chen's avatar
Kai Chen committed
6
7
8

# Public builder entry points exported by this module (same names, same order).
__all__ = [
    'build_backbone',
    'build_neck',
    'build_rpn_head',
    'build_roi_extractor',
    'build_bbox_head',
    'build_mask_head',
    'build_detector',
]


Kai Chen's avatar
Kai Chen committed
13
14
15
def _build_module(cfg, parrent=None, default_args=None):
    """Turn a single config entry into a module instance.

    Args:
        cfg: Either an already-constructed ``nn.Module`` (returned as-is) or
            a config that is handed to ``tp.obj_from_dict``.
        parrent: Forwarded unchanged as the second argument of
            ``tp.obj_from_dict``. Callers in this file pass one of the
            sub-packages imported via ``from . import ...`` (e.g.
            ``backbones``). The misspelled name is kept so keyword callers
            keep working.
        default_args: Forwarded unchanged as the third argument of
            ``tp.obj_from_dict``.

    Returns:
        ``cfg`` itself when it is an ``nn.Module``; otherwise the result of
        ``tp.obj_from_dict(cfg, parrent, default_args)``.
    """
    if isinstance(cfg, nn.Module):
        # Already built: pass straight through without touching the config path.
        return cfg
    return tp.obj_from_dict(cfg, parrent, default_args)
Kai Chen's avatar
Kai Chen committed
16
17


Kai Chen's avatar
Kai Chen committed
18
def build(cfg, parrent=None, default_args=None):
    """Build one module, or a chain of modules, from config.

    Args:
        cfg: A single config (or pre-built ``nn.Module``), or a ``list`` of
            such entries.
        parrent: Forwarded to ``_build_module`` for every entry.
        default_args: Forwarded to ``_build_module`` for every entry.

    Returns:
        When ``cfg`` is a list, an ``nn.Sequential`` wrapping the module built
        from each entry, in list order. Otherwise, the single module built
        from ``cfg``.
    """
    if not isinstance(cfg, list):
        return _build_module(cfg, parrent, default_args)
    # Every list entry is built with the same parent and default_args.
    built = []
    for entry in cfg:
        built.append(_build_module(entry, parrent, default_args))
    return nn.Sequential(*built)
Kai Chen's avatar
Kai Chen committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


def build_backbone(cfg):
    """Build a backbone from ``cfg`` via ``build``, with ``backbones`` as parent.

    ``cfg`` may be a single config, a pre-built ``nn.Module``, or a list of
    them (a list yields an ``nn.Sequential``).
    """
    return build(cfg, parrent=backbones)


def build_neck(cfg):
    """Build a neck from ``cfg`` via ``build``, with ``necks`` as parent.

    ``cfg`` may be a single config, a pre-built ``nn.Module``, or a list of
    them (a list yields an ``nn.Sequential``).
    """
    return build(cfg, parrent=necks)


def build_rpn_head(cfg):
    """Build an RPN head from ``cfg`` via ``build``, with ``rpn_heads`` as parent.

    ``cfg`` may be a single config, a pre-built ``nn.Module``, or a list of
    them (a list yields an ``nn.Sequential``).
    """
    return build(cfg, parrent=rpn_heads)


def build_roi_extractor(cfg):
    """Build an RoI extractor from ``cfg`` via ``build``, with ``roi_extractors`` as parent.

    ``cfg`` may be a single config, a pre-built ``nn.Module``, or a list of
    them (a list yields an ``nn.Sequential``).
    """
    return build(cfg, parrent=roi_extractors)


def build_bbox_head(cfg):
    """Build a bbox head from ``cfg`` via ``build``, with ``bbox_heads`` as parent.

    ``cfg`` may be a single config, a pre-built ``nn.Module``, or a list of
    them (a list yields an ``nn.Sequential``).
    """
    return build(cfg, parrent=bbox_heads)


def build_mask_head(cfg):
    """Build a mask head from ``cfg`` via ``build``, with ``mask_heads`` as parent.

    ``cfg`` may be a single config, a pre-built ``nn.Module``, or a list of
    them (a list yields an ``nn.Sequential``).
    """
    return build(cfg, parrent=mask_heads)
Kai Chen's avatar
Kai Chen committed
48
49
50
51


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg`` via ``build``, with ``detectors`` as parent.

    Args:
        cfg: Detector config, pre-built ``nn.Module``, or a list of them.
        train_cfg: Bundled into ``default_args`` under key ``'train_cfg'``.
        test_cfg: Bundled into ``default_args`` under key ``'test_cfg'``.

    Returns:
        Whatever ``build`` returns for these arguments.
    """
    extra_args = {'train_cfg': train_cfg, 'test_cfg': test_cfg}
    return build(cfg, detectors, default_args=extra_args)