multi_backbone.py 4.49 KB
Newer Older
encore-zhou's avatar
encore-zhou committed
1
2
3
import copy
import torch
from mmcv.cnn import ConvModule
4
from mmcv.runner import auto_fp16, load_checkpoint
encore-zhou's avatar
encore-zhou committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from torch import nn as nn

from mmdet.models import BACKBONES, build_backbone


@BACKBONES.register_module()
class MultiBackbone(nn.Module):
    """MultiBackbone with different configs.

    Runs several backbones on the same input points and fuses the last
    feature-propagation output of every backbone with a stack of 1x1
    ``ConvModule`` layers.

    Args:
        num_streams (int): The number of backbones.
        backbones (list[dict] | dict): Backbone configs. A single dict is
            deep-copied ``num_streams`` times. Every config must contain
            ``fp_channels``; ``fp_channels[-1][-1]`` is taken as that
            backbone's output width.
        aggregation_mlp_channels (list[int] | tuple[int], optional): Output
            channels of each aggregation layer. The input width (the sum of
            all backbones' output widths, ``C``) is prepended automatically.
            The given sequence itself is not modified. Defaults to
            ``[C, C // 2, C // num_streams]``.
        conv_cfg (dict): Config dict of convolutional layers.
        norm_cfg (dict): Config dict of normalization layers.
        act_cfg (dict): Config dict of activation layers.
        suffixes (list[str] | tuple[str]): One suffix per backbone. Each key
            of that backbone's output dict is renamed to ``key + '_' +
            suffix``. An empty suffix leaves the keys unchanged, in which
            case later backbones overwrite identical keys of earlier ones.
        **kwargs: Unused.
    """

    def __init__(self,
                 num_streams,
                 backbones,
                 aggregation_mlp_channels=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
                 act_cfg=dict(type='ReLU'),
                 suffixes=('net0', 'net1'),
                 **kwargs):
        super().__init__()
        assert isinstance(backbones, (dict, list))
        if isinstance(backbones, dict):
            # One shared config: deep copies keep each stream's config
            # independent of the others.
            backbones = [
                copy.deepcopy(backbones) for _ in range(num_streams)
            ]

        assert len(backbones) == num_streams
        assert len(suffixes) == num_streams

        # Suffixes used to rename each backbone's returned dict keys.
        self.suffixes = suffixes

        # Summed width of the last FP layer of every backbone; this is the
        # channel count after concatenation in ``forward``.
        out_channels = sum(cfg['fp_channels'][-1][-1] for cfg in backbones)
        self.backbone_list = nn.ModuleList(
            build_backbone(cfg) for cfg in backbones)

        # Feature aggregation layers
        if aggregation_mlp_channels is None:
            aggregation_mlp_channels = [
                out_channels, out_channels // 2,
                out_channels // len(self.backbone_list)
            ]
        else:
            # Build a new list rather than ``insert(0, ...)`` so the caller's
            # config is not mutated; otherwise re-building from the same
            # config would prepend ``out_channels`` a second time.
            aggregation_mlp_channels = [out_channels] + list(
                aggregation_mlp_channels)

        self.aggregation_layers = nn.Sequential()
        for i, (in_channels, layer_out_channels) in enumerate(
                zip(aggregation_mlp_channels[:-1],
                    aggregation_mlp_channels[1:])):
            self.aggregation_layers.add_module(
                f'layer{i}',
                ConvModule(
                    in_channels,
                    layer_out_channels,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))

    def init_weights(self, pretrained=None):
        """Initialize the weights of the multi-backbone.

        Args:
            pretrained (str, optional): Checkpoint to load. If it is a
                string, it is loaded into this module with ``strict=False``.
                Otherwise this method does nothing.
        """
        # Do not initialize the conv layers
        # to follow the original implementation
        if isinstance(pretrained, str):
            from mmdet3d.utils import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    @auto_fp16()
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, torch.Tensor | list[torch.Tensor]]: Outputs from
            multiple backbones.

                - ``{key}_{suffix}``: every entry returned by the backbone
                  with that suffix (for PointNet++-style backbones these
                  are e.g. ``fp_xyz``, ``fp_features`` and ``fp_indices``).
                  Keys are left unchanged when the suffix is ``''``.
                - hd_feature (torch.Tensor): The last ``fp_features`` of all
                  backbones, concatenated along dim 1 and passed through
                  the aggregation layers.
        """
        ret = {}
        fp_features = []
        for backbone, suffix in zip(self.backbone_list, self.suffixes):
            cur_ret = backbone(points)
            fp_features.append(cur_ret['fp_features'][-1])
            if suffix != '':
                # Build a renamed copy instead of popping/inserting while
                # iterating ``cur_ret.keys()``, which raises RuntimeError on
                # Python 3.8+ (and re-renames new keys on older versions).
                cur_ret = {f'{k}_{suffix}': v for k, v in cur_ret.items()}
            ret.update(cur_ret)

        # Combine the features here. Dim 1 is presumably the channel axis,
        # given the default Conv1d aggregation layers.
        hd_feature = torch.cat(fp_features, dim=1)
        ret['hd_feature'] = self.aggregation_layers(hd_feature)
        return ret