second.py 2.63 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
2
3
from functools import partial

import torch.nn as nn
4
from mmcv.cnn import build_norm_layer
zhangwenwei's avatar
zhangwenwei committed
5
6
from mmcv.runner import load_checkpoint

zhangwenwei's avatar
zhangwenwei committed
7
from mmdet.models import BACKBONES
zhangwenwei's avatar
zhangwenwei committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22


class Empty(nn.Module):
    """Pass-through module that hands its positional inputs back unchanged.

    Any constructor arguments are accepted and ignored, so the class can be
    dropped in wherever a layer constructor is expected.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        """Return the positional inputs; keyword arguments are ignored.

        Returns:
            None when called without positional arguments, the argument
            itself when exactly one is given, otherwise the tuple of all
            positional arguments.
        """
        if not args:
            return None
        if len(args) > 1:
            return args
        (only,) = args
        return only


23
@BACKBONES.register_module()
zhangwenwei's avatar
zhangwenwei committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class SECOND(nn.Module):
    """Multi-stage 2D convolutional backbone with an arbitrary stage count.

    One stage is built per entry of ``layer_nums``. Stage ``i`` consists of
    a one-pixel zero pad followed by a 3x3 convolution with stride
    ``layer_strides[i]``, then ``layer_nums[i]`` further 3x3 convolutions
    with ``padding=1``. Every convolution is followed by a norm layer (or a
    pass-through ``Empty`` module when ``norm_cfg`` is None) and a ReLU.

    Args:
        in_channels (int): Number of channels of the input feature map.
        layer_nums (list[int]): Number of additional conv layers per stage.
        layer_strides (list[int]): Stride of the first conv of each stage.
        num_filters (list[int]): Number of output channels of each stage.
        norm_cfg (dict | None): Config handed to ``build_norm_layer``. When
            None, no normalization is applied and the convolutions get a
            bias term instead.
    """

    def __init__(self,
                 in_channels=128,
                 layer_nums=[3, 5, 5],
                 layer_strides=[2, 2, 2],
                 num_filters=[128, 128, 256],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01)):
        super(SECOND, self).__init__()
        assert len(layer_strides) == len(layer_nums)
        assert len(num_filters) == len(layer_nums)

        use_norm = norm_cfg is not None
        # A convolution that feeds a norm layer does not get its own bias.
        Conv2d = partial(nn.Conv2d, bias=not use_norm)

        def make_norm(num_features):
            """Build a fresh norm layer, or a pass-through when disabled."""
            if use_norm:
                return build_norm_layer(norm_cfg, num_features)[1]
            # Must be an instance: nn.Sequential rejects the bare class.
            return Empty()

        in_filters = [in_channels, *num_filters[:-1]]
        # note that when stride > 1, conv2d with same padding isn't
        # equal to pad-conv2d. we should use pad-conv2d.
        blocks = []

        for i, layer_num in enumerate(layer_nums):
            out_channels = num_filters[i]
            block = [
                nn.ZeroPad2d(1),
                Conv2d(in_filters[i], out_channels, 3,
                       stride=layer_strides[i]),
                make_norm(out_channels),
                nn.ReLU(inplace=True),
            ]
            for _ in range(layer_num):
                block.extend([
                    Conv2d(out_channels, out_channels, 3, padding=1),
                    make_norm(out_channels),
                    nn.ReLU(inplace=True),
                ])

            blocks.append(nn.Sequential(*block))

        self.blocks = nn.ModuleList(blocks)

    def init_weights(self, pretrained=None):
        """Load weights from a checkpoint when a path is given.

        Args:
            pretrained (str | None): Checkpoint location passed to
                ``load_checkpoint``. Any non-string value leaves the
                module's weights untouched.
        """
        if isinstance(pretrained, str):
            # Imported here rather than at module level, as in the original.
            from mmdet3d.utils import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    def forward(self, x):
        """Run the stages in sequence and collect every stage's output.

        Args:
            x (torch.Tensor): Input feature map accepted by ``nn.Conv2d``.

        Returns:
            tuple[torch.Tensor]: One output per stage, in stage order.
        """
        outs = []
        for block in self.blocks:
            x = block(x)
            outs.append(x)
        return tuple(outs)