second.py 2.62 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from functools import partial

import torch.nn as nn
from mmcv.runner import load_checkpoint

from ..registry import BACKBONES
from ..utils import build_norm_layer


class Empty(nn.Module):
    """Parameter-free module that hands its positional inputs straight back.

    Constructor arguments of any kind are accepted and ignored. In this file
    it stands in for the normalization layer when normalization is disabled.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        """Pass the positional inputs through untouched.

        Keyword arguments are accepted and ignored.

        Returns:
            None if no positional argument was given, the lone argument
            itself if exactly one was given, otherwise the tuple of all
            positional arguments.
        """
        arg_count = len(args)
        if arg_count == 0:
            return None
        if arg_count == 1:
            return args[0]
        return args


@BACKBONES.register_module
class SECOND(nn.Module):
    """Multi-stage 2D convolutional backbone.

    Compared with RPN / RPNV2, this backbone supports an arbitrary number of
    stages. Stage ``i`` consists of:

    * one-pixel zero padding followed by a 3x3 convolution with stride
      ``layer_strides[i]`` mapping the previous stage's channels to
      ``num_filters[i]`` (the first stage reads ``in_channels``);
    * ``layer_nums[i]`` further 3x3, stride-1, padding-1 convolutions at
      ``num_filters[i]`` channels.

    Every convolution is followed by a normalization layer and an in-place
    ReLU.

    Args:
        in_channels (int): Number of channels of the input feature map.
        layer_nums (Sequence[int]): Number of extra stride-1 convolutions in
            each stage, on top of the stage's first (strided) convolution.
        layer_strides (Sequence[int]): Stride of the first convolution of
            each stage.
        num_filters (Sequence[int]): Output channels of each stage.
        norm_cfg (dict | None): Config forwarded to ``build_norm_layer``.
            When None, normalization is replaced by an identity ``Empty``
            module and the convolutions use a bias term instead.
    """

    def __init__(self,
                 in_channels=128,
                 layer_nums=(3, 5, 5),
                 layer_strides=(2, 2, 2),
                 num_filters=(128, 128, 256),
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01)):
        super(SECOND, self).__init__()
        assert len(layer_strides) == len(layer_nums)
        assert len(num_filters) == len(layer_nums)

        # A conv bias is redundant when a normalization layer follows it, so
        # it is only enabled when normalization is disabled.
        Conv2d = partial(nn.Conv2d, bias=norm_cfg is None)

        def make_norm(num_channels):
            # Build a fresh normalization module for ``num_channels``. When
            # normalization is disabled, return an *instance* of Empty:
            # nn.Sequential rejects a bare class, which the previous code
            # passed and which raised TypeError.
            if norm_cfg is None:
                return Empty()
            return build_norm_layer(norm_cfg, num_channels)[1]

        in_filters = [in_channels, *num_filters[:-1]]
        # note that when stride > 1, conv2d with same padding isn't
        # equal to pad-conv2d. we should use pad-conv2d.
        blocks = []
        for i, layer_num in enumerate(layer_nums):
            out_ch = num_filters[i]
            # The norm layer is built before its conv to keep the original
            # module construction order.
            norm_layer = make_norm(out_ch)
            block = [
                nn.ZeroPad2d(1),
                Conv2d(in_filters[i], out_ch, 3, stride=layer_strides[i]),
                norm_layer,
                nn.ReLU(inplace=True),
            ]
            for _ in range(layer_num):
                norm_layer = make_norm(out_ch)
                block.append(Conv2d(out_ch, out_ch, 3, padding=1))
                block.append(norm_layer)
                block.append(nn.ReLU(inplace=True))
            blocks.append(nn.Sequential(*block))

        self.blocks = nn.ModuleList(blocks)

    def init_weights(self, pretrained=None):
        """Optionally load weights from a checkpoint.

        Args:
            pretrained (str, optional): Path or URL of a checkpoint. When a
                string is given it is loaded non-strictly (missing or
                unexpected keys are tolerated). Any other value leaves the
                current weights untouched.
        """
        if isinstance(pretrained, str):
            from mmdet3d.apis import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    def forward(self, x):
        """Run all stages in sequence.

        Args:
            x (torch.Tensor): Input feature map with ``in_channels``
                channels, in a layout accepted by ``nn.Conv2d``
                (e.g. (N, C, H, W)).

        Returns:
            tuple[torch.Tensor]: Output of every stage, in stage order.
        """
        outs = []
        for block in self.blocks:
            x = block(x)
            outs.append(x)
        return tuple(outs)