import torch
from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
                      is_norm, kaiming_init)
from torch import nn as nn

from mmdet.models import NECKS


@NECKS.register_module()
class SECONDFPN(nn.Module):
    """FPN used in SECOND/PointPillars/PartA2/MVXNet.

    Args:
        in_channels (list[int]): Input channels of multi-scale feature maps.
        out_channels (list[int]): Output channels of feature maps.
        upsample_strides (list[int]): Strides used to upsample the
            feature maps.
        norm_cfg (dict): Config dict of normalization layers.
        upsample_cfg (dict): Config dict of upsample layers.
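
    Example:
        A minimal usage sketch; the three input feature-map sizes below are
        made up for illustration, and the channel/stride settings are just
        the class defaults.

        >>> import torch
        >>> neck = SECONDFPN()
        >>> feats = [torch.rand(2, 128, 200, 176),
        ...          torch.rand(2, 128, 100, 88),
        ...          torch.rand(2, 256, 50, 44)]
        >>> outs = neck(feats)
        >>> outs[0].shape
        torch.Size([2, 768, 200, 176])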
    """

    def __init__(self,
                 in_channels=[128, 128, 256],
                 out_channels=[256, 256, 256],
                 upsample_strides=[1, 2, 4],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                 upsample_cfg=dict(type='deconv', bias=False)):
        # Note: to use GroupNorm, set norm_cfg to
        # dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
        super(SECONDFPN, self).__init__()
        assert len(out_channels) == len(upsample_strides) == len(in_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        deblocks = []
        for i, out_channel in enumerate(out_channels):
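            # Each deblock upsamples one input level: an upsample layer
            # (deconv by default) whose kernel size and stride both equal the
            # level's upsample stride, followed by a norm layer and ReLU.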
            upsample_layer = build_upsample_layer(
                upsample_cfg,
                in_channels=in_channels[i],
                out_channels=out_channel,
                kernel_size=upsample_strides[i],
                stride=upsample_strides[i])
            deblock = nn.Sequential(upsample_layer,
                                    build_norm_layer(norm_cfg, out_channel)[1],
                                    nn.ReLU(inplace=True))
            deblocks.append(deblock)
        self.deblocks = nn.ModuleList(deblocks)

    def init_weights(self):
        """Initialize weights of FPN."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif is_norm(m):
                constant_init(m, 1)

    def forward(self, x):
        """Forward function.

        Args:
            x (list[torch.Tensor]): Multi-level features, each a 4D Tensor
                in (N, C, H, W) shape.

        Returns:
            list[torch.Tensor]: A single-element list containing the fused
                feature map.
        """
        assert len(x) == len(self.in_channels)
        ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]
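        # All upsampled maps are fused by concatenation along the channel dim.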

        if len(ups) > 1:
            out = torch.cat(ups, dim=1)
        else:
            out = ups[0]
        return [out]