builder.py 1015 Bytes
Newer Older
Kai Chen's avatar
Kai Chen committed
1
import mmcv
pangjm's avatar
pangjm committed
2
from mmcv import torchpack
Kai Chen's avatar
Kai Chen committed
3
4
5
6
7
8
9
10
11
12
13
14
from torch import nn

from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
               mask_heads)

# Public builder entry points exported by ``from ... import *``.
__all__ = [
    'build_backbone',
    'build_neck',
    'build_rpn_head',
    'build_roi_extractor',
    'build_bbox_head',
    'build_mask_head',
]


def _build_module(cfg, parrent=None):
    """Return ``cfg`` unchanged if it is already a module, else build it.

    Args:
        cfg: Either an already-constructed ``nn.Module``, which is returned
            as-is, or a config that is handed to ``torchpack.obj_from_dict``.
        parrent: Forwarded to ``torchpack.obj_from_dict`` as its parent
            argument (presumably the namespace in which the config's type
            is looked up; confirm against torchpack). The misspelled name is
            kept so existing keyword callers keep working.

    Returns:
        ``cfg`` itself, or whatever ``torchpack.obj_from_dict`` constructs.
    """
    if isinstance(cfg, nn.Module):
        # Already-built modules pass straight through; no config lookup needed.
        return cfg
    return torchpack.obj_from_dict(cfg, parrent)


def build(cfg, parrent=None):
    """Build a single module, or an ``nn.Sequential`` of modules, from ``cfg``.

    Args:
        cfg: One config/module, or a list of them.
        parrent: Forwarded unchanged to ``_build_module`` for every entry.

    Returns:
        For a list ``cfg``, an ``nn.Sequential`` holding each built entry in
        order; otherwise the single object built from ``cfg``.
    """
    if not isinstance(cfg, list):
        return _build_module(cfg, parrent)
    # Every entry is built before the Sequential is constructed.
    return nn.Sequential(*(_build_module(item, parrent) for item in cfg))


def build_backbone(cfg):
    """Build a backbone from ``cfg`` using ``backbones`` as the lookup parent."""
    return build(cfg, parrent=backbones)


def build_neck(cfg):
    """Build a neck from ``cfg`` using ``necks`` as the lookup parent."""
    return build(cfg, parrent=necks)


def build_rpn_head(cfg):
    """Build an RPN head from ``cfg`` using ``rpn_heads`` as the lookup parent."""
    return build(cfg, parrent=rpn_heads)


def build_roi_extractor(cfg):
    """Build an RoI extractor from ``cfg`` using ``roi_extractors`` as the lookup parent."""
    return build(cfg, parrent=roi_extractors)


def build_bbox_head(cfg):
    """Build a bbox head from ``cfg`` using ``bbox_heads`` as the lookup parent."""
    return build(cfg, parrent=bbox_heads)


def build_mask_head(cfg):
    """Build a mask head from ``cfg`` using ``mask_heads`` as the lookup parent."""
    return build(cfg, parrent=mask_heads)