second_fpn.py

import copy

import torch
import torch.nn as nn
from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
                      is_norm, kaiming_init)

from mmdet.models import NECKS
from .. import builder


@NECKS.register_module()
class SECONDFPN(nn.Module):
    """FPN used in SECOND/PointPillars

    Args:
        in_channels (list[int]): Input channels of multi-scale feature maps
        out_channels (list[int]): Output channels of feature maps
        upsample_strides (list[int]): Strides used to upsample the feature maps
        norm_cfg (dict): Config dict of normalization layers
        upsample_cfg (dict): Config dict of upsample layers
    """

    def __init__(self,
                 in_channels=[128, 128, 256],
                 out_channels=[256, 256, 256],
                 upsample_strides=[1, 2, 4],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                 upsample_cfg=dict(type='deconv', bias=False)):
        # For GroupNorm, use
        # norm_cfg=dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
        super(SECONDFPN, self).__init__()
        assert len(out_channels) == len(upsample_strides) == len(in_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        deblocks = []
        for i, out_channel in enumerate(out_channels):
            norm_layer = build_norm_layer(norm_cfg, out_channel)[1]
            upsample_cfg_ = copy.deepcopy(upsample_cfg)
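            # For the default 'deconv' (nn.ConvTranspose2d) layer, setting
            # kernel_size equal to the stride upsamples the map by exactly
            # upsample_strides[i], with no overlap between output positions.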
            upsample_cfg_.update(
                in_channels=in_channels[i],
                out_channels=out_channel,
                kernel_size=upsample_strides[i],
                stride=upsample_strides[i])
            upsample_layer = build_upsample_layer(upsample_cfg_)
            deblock = nn.Sequential(
                upsample_layer,
                norm_layer,
                nn.ReLU(inplace=True),
            )
            deblocks.append(deblock)
        self.deblocks = nn.ModuleList(deblocks)

    def init_weights(self):
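        # Kaiming-initialize conv layers; give norm layers unit weight and
        # zero bias. nn.ConvTranspose2d is not a subclass of nn.Conv2d, so
        # the deconv layers keep PyTorch's default initialization.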
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif is_norm(m):
                constant_init(m, 1)

    def forward(self, x):
        assert len(x) == len(self.in_channels)
        ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]

        if len(ups) > 1:
            out = torch.cat(ups, dim=1)
        else:
            out = ups[0]
        return [out]
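
# Usage sketch (illustrative, not part of the original module): with the
# default arguments each level is projected to 256 channels, upsampled to the
# resolution of the first level, and the levels are concatenated. The input
# spatial sizes below are made up for the example.
#
#   neck = SECONDFPN()
#   feats = [torch.rand(2, 128, 100, 100),
#            torch.rand(2, 128, 50, 50),
#            torch.rand(2, 256, 25, 25)]
#   neck(feats)[0].shape  # torch.Size([2, 768, 100, 100])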


@NECKS.register_module()
class SECONDFusionFPN(SECONDFPN):
    """FPN used in multi-modality SECOND/PointPillars

    Args:
        in_channels (list[int]): Input channels of multi-scale feature maps
        out_channels (list[int]): Output channels of feature maps
        upsample_strides (list[int]): Strides used to upsample the feature maps
        norm_cfg (dict): Config dict of normalization layers
        upsample_cfg (dict): Config dict of upsample layers
        downsample_rates (list[int]): The downsample rate of feature map in
            comparison to the original voxelization input
        fusion_layer (dict): Config dict of fusion layers
    """

    def __init__(self,
                 in_channels=[128, 128, 256],
                 out_channels=[256, 256, 256],
                 upsample_strides=[1, 2, 4],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                 upsample_cfg=dict(type='deconv', bias=False),
                 downsample_rates=[40, 8, 8],
                 fusion_layer=None):
        super(SECONDFusionFPN,
              self).__init__(in_channels, out_channels, upsample_strides,
                             norm_cfg, upsample_cfg)
        self.fusion_layer = None
        if fusion_layer is not None:
            self.fusion_layer = builder.build_fusion_layer(fusion_layer)
        self.downsample_rates = downsample_rates

    def forward(self,
                x,
                coors=None,
                points=None,
                img_feats=None,
                img_meta=None):
        assert len(x) == len(self.in_channels)
        ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]

        if len(ups) > 1:
            out = torch.cat(ups, dim=1)
        else:
            out = ups[0]
        if (self.fusion_layer is not None and img_feats is not None):
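            # In mmdet3d, voxel coors columns are (batch_idx, z, y, x).
            # Scale them to the resolution of the fused BEV feature map;
            # assigning the divided values back into the integer tensor
            # truncates them (coordinates are non-negative, so this floors).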
            downsample_pts_coors = torch.zeros_like(coors)
            downsample_pts_coors[:, 0] = coors[:, 0]
            downsample_pts_coors[:, 1] = (
                coors[:, 1] / self.downsample_rates[0])
            downsample_pts_coors[:, 2] = (
                coors[:, 2] / self.downsample_rates[1])
            downsample_pts_coors[:, 3] = (
                coors[:, 3] / self.downsample_rates[2])
            # fuse image features into the BEV features at each point
            out = self.fusion_layer(img_feats, points, out,
                                    downsample_pts_coors, img_meta)
        return [out]
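
# Config sketch (hypothetical, values mirror the constructor defaults): the
# fusion neck would be selected from a model config via the NECKS registry:
#
#   neck=dict(
#       type='SECONDFusionFPN',
#       in_channels=[128, 128, 256],
#       out_channels=[256, 256, 256],
#       upsample_strides=[1, 2, 4],
#       downsample_rates=[40, 8, 8],
#       fusion_layer=dict(...))  # built by builder.build_fusion_layer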