from functools import partial

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer, constant_init, kaiming_init
from torch.nn import Sequential
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models import NECKS

from .. import builder


@NECKS.register_module()
class SECONDFPN(nn.Module):
    """FPN used in SECOND/PointPillars-style detectors.

    Each input feature map is upsampled by a transposed convolution
    (kernel size == stride == the configured upsample stride) followed by
    a norm layer and ReLU; the upsampled maps are concatenated along the
    channel dimension. Supports an arbitrary number of input stages.

    Args:
        use_norm (bool): Unused; kept for config backward compatibility.
        in_channels (list[int]): Channels of each input feature map.
        upsample_strides (list[int]): Upsample factor of each stage.
        num_upsample_filters (list[int]): Output channels of each stage.
        norm_cfg (dict): Normalization layer config. For GroupNorm use
            dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True).
    """

    def __init__(self,
                 use_norm=True,
                 in_channels=[128, 128, 256],
                 upsample_strides=[1, 2, 4],
                 num_upsample_filters=[256, 256, 256],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01)):
        super(SECONDFPN, self).__init__()
        # All three per-stage lists must align; validating in_channels here
        # avoids a cryptic IndexError when building the deblocks below.
        assert len(num_upsample_filters) == len(upsample_strides)
        assert len(in_channels) == len(upsample_strides)
        self.in_channels = in_channels

        # bias=False: each transposed conv is immediately followed by a
        # norm layer, which would cancel the bias anyway.
        ConvTranspose2d = partial(nn.ConvTranspose2d, bias=False)

        deblocks = []

        for i, num_upsample_filter in enumerate(num_upsample_filters):
            norm_layer = build_norm_layer(norm_cfg, num_upsample_filter)[1]
            deblock = Sequential(
                ConvTranspose2d(
                    in_channels[i],
                    num_upsample_filter,
                    upsample_strides[i],
                    stride=upsample_strides[i]),
                norm_layer,
                nn.ReLU(inplace=True),
            )
            deblocks.append(deblock)
        self.deblocks = nn.ModuleList(deblocks)

    def init_weights(self):
        """Initialize conv weights (Kaiming) and norm layers (constant 1).

        Note: the module's convs are ``nn.ConvTranspose2d``, which is NOT a
        subclass of ``nn.Conv2d``; checking only ``nn.Conv2d`` (as the
        original code did) would skip every conv in this neck.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

    def forward(self, x):
        """Upsample each stage and concatenate along channels.

        Args:
            x (list[torch.Tensor]): One feature map per entry of
                ``in_channels``.

        Returns:
            list[torch.Tensor]: Single-element list with the fused map.
        """
        assert len(x) == len(self.in_channels)
        ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]

        if len(ups) > 1:
            out = torch.cat(ups, dim=1)
        else:
            out = ups[0]
        return [out]


@NECKS.register_module()
class SECONDFusionFPN(SECONDFPN):
    """SECOND FPN with optional point-wise image-feature fusion.

    Identical to ``SECONDFPN`` for the upsample-and-concatenate path, but
    when a fusion layer is configured and image features are supplied,
    the concatenated BEV map is additionally fused with the image
    features at downsampled point coordinates.

    Args:
        use_norm (bool): Unused; kept for config backward compatibility.
        in_channels (list[int]): Channels of each input feature map.
        upsample_strides (list[int]): Upsample factor of each stage.
        num_upsample_filters (list[int]): Output channels of each stage.
        norm_cfg (dict): Normalization layer config.
        down_sample_rate (list[int]): Per-axis divisors applied to the
            voxel coordinate columns 1..3 before fusion.
        fusion_layer (dict, optional): Config for the fusion layer; no
            fusion is performed when ``None``.
        cat_points (bool): Stored flag; not used inside this class.
    """

    def __init__(self,
                 use_norm=True,
                 in_channels=[128, 128, 256],
                 upsample_strides=[1, 2, 4],
                 num_upsample_filters=[256, 256, 256],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                 down_sample_rate=[40, 8, 8],
                 fusion_layer=None,
                 cat_points=False):
        super(SECONDFusionFPN, self).__init__(
            use_norm,
            in_channels,
            upsample_strides,
            num_upsample_filters,
            norm_cfg,
        )
        # Only build the fusion layer when a config was actually given.
        if fusion_layer is None:
            self.fusion_layer = None
        else:
            self.fusion_layer = builder.build_fusion_layer(fusion_layer)
        self.cat_points = cat_points
        self.down_sample_rate = down_sample_rate

    def forward(self,
                x,
                coors=None,
                points=None,
                img_feats=None,
                img_meta=None):
        """Upsample, concatenate, and optionally fuse with image features.

        Args:
            x (list[torch.Tensor]): One feature map per ``in_channels``.
            coors (torch.Tensor, optional): Voxel coordinates; column 0 is
                presumably the batch index — TODO confirm against caller.
            points (torch.Tensor, optional): Raw points for the fusion layer.
            img_feats (torch.Tensor, optional): Image features; fusion is
                skipped when absent.
            img_meta (dict, optional): Image meta info for the fusion layer.

        Returns:
            list[torch.Tensor]: Single-element list with the output map.
        """
        assert len(x) == len(self.in_channels)
        upsampled = [layer(feat) for layer, feat in zip(self.deblocks, x)]

        out = torch.cat(upsampled, dim=1) if len(upsampled) > 1 \
            else upsampled[0]

        if self.fusion_layer is not None and img_feats is not None:
            # Scale coordinate columns 1..3 down to the feature-map
            # resolution; column 0 is copied through unchanged.
            downsample_pts_coors = torch.zeros_like(coors)
            downsample_pts_coors[:, 0] = coors[:, 0]
            for dim, rate in enumerate(self.down_sample_rate, start=1):
                downsample_pts_coors[:, dim] = coors[:, dim] / rate
            # Fuse image features into the BEV map for each point.
            out = self.fusion_layer(img_feats, points, out,
                                    downsample_pts_coors, img_meta)
        return [out]