import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint

from mmdet.core import auto_fp16
from mmdet.utils import get_root_logger
from ..backbones import ResNet, make_res_layer
from ..registry import SHARED_HEADS


@SHARED_HEADS.register_module
class ResLayer(nn.Module):
    """Shared head made of a single ResNet stage.

    The stage is built with ``make_res_layer`` and registered as the
    submodule ``layer{stage + 1}``; ``forward`` applies that submodule to
    its input.

    Args:
        depth: Key into ``ResNet.arch_settings`` selecting the block class
            and the per-stage block counts.
        stage (int): Index into the per-stage block counts. It also fixes
            the channel widths: ``planes = 64 * 2**stage`` and
            ``inplanes = 64 * 2**(stage - 1) * block.expansion``.
        stride: Forwarded unchanged to ``make_res_layer``.
        dilation: Forwarded unchanged to ``make_res_layer``.
        style: Forwarded unchanged to ``make_res_layer``.
        norm_cfg (dict): Stored as ``self.norm_cfg`` and forwarded to
            ``make_res_layer``.
        norm_eval (bool): If true, ``train()`` keeps every
            ``nn.BatchNorm2d`` submodule in eval mode.
        with_cp: Forwarded unchanged to ``make_res_layer``.
        dcn: Forwarded unchanged to ``make_res_layer``.

    Raises:
        KeyError: If ``depth`` is not a key of ``ResNet.arch_settings``.
    """

    # NOTE: the ``norm_cfg`` default is a single dict shared by all
    # instances; this class never mutates it.
    def __init__(self,
                 depth,
                 stage=3,
                 stride=2,
                 dilation=1,
                 style='pytorch',
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 with_cp=False,
                 dcn=None):
        super(ResLayer, self).__init__()
        # Fail with an explicit message instead of a bare KeyError.
        if depth not in ResNet.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))

        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.stage = stage
        # NOTE(review): presumably consulted by the ``auto_fp16`` decorator
        # on ``forward`` -- confirm against mmdet.core.
        self.fp16_enabled = False

        block, stage_blocks = ResNet.arch_settings[depth]
        stage_block = stage_blocks[stage]
        planes = 64 * 2**stage
        inplanes = 64 * 2**(stage - 1) * block.expansion

        res_layer = make_res_layer(
            block,
            inplanes,
            planes,
            stage_block,
            stride=stride,
            dilation=dilation,
            style=style,
            with_cp=with_cp,
            norm_cfg=self.norm_cfg,
            dcn=dcn)
        # Registered under the same name that ``forward`` looks up.
        self.add_module('layer{}'.format(stage + 1), res_layer)

    def init_weights(self, pretrained=None):
        """Initialise weights from a checkpoint or from scratch.

        Args:
            pretrained (str | None): If a string, it is passed to
                ``load_checkpoint`` (with ``strict=False``). If ``None``,
                every ``nn.Conv2d`` gets ``kaiming_init`` and every
                ``nn.BatchNorm2d`` gets ``constant_init(m, 1)``.

        Raises:
            TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    @auto_fp16()
    def forward(self, x):
        """Apply the wrapped residual stage to ``x`` and return its output."""
        res_layer = getattr(self, 'layer{}'.format(self.stage + 1))
        out = res_layer(x)
        return out

    def train(self, mode=True):
        """Set the training mode, optionally keeping BatchNorm in eval mode.

        Args:
            mode (bool): Passed to ``nn.Module.train``.

        Returns:
            ResLayer: ``self``, matching the ``nn.Module.train`` contract so
            calls can be chained.
        """
        super(ResLayer, self).train(mode)
        # With mode=False the parent call has already put every submodule
        # in eval mode, so the BatchNorm pass is only needed when training.
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        return self