# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

import torch
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import nn as nn

from mmdet3d.registry import MODELS

@MODELS.register_module()
class MultiBackbone(BaseModule):
    """MultiBackbone with different configs.

    Runs ``num_streams`` point-cloud backbones on the same input points and
    aggregates their last feature-propagation outputs with a small MLP
    (implemented as 1x1 ``ConvModule`` layers).

    Args:
        num_streams (int): The number of backbones.
        backbones (list or dict): A list of backbone configs. If a single
            dict is given, it is deep-copied ``num_streams`` times.
        aggregation_mlp_channels (list[int], optional): Specify the mlp
            layers for feature aggregation. The summed output channels of
            all backbones are prepended automatically. Defaults to None,
            in which case ``[C, C // 2, C // num_streams]`` is used, where
            ``C`` is the summed output channel number.
        conv_cfg (dict): Config dict of convolutional layers.
        norm_cfg (dict): Config dict of normalization layers.
        act_cfg (dict): Config dict of activation layers.
        suffixes (tuple[str]): A list of suffixes to rename the return dict
            for each backbone. An empty suffix keeps the original keys.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
        pretrained (str, optional): Deprecated; use ``init_cfg`` with
            ``type='Pretrained'`` instead. Defaults to None.
    """

    def __init__(self,
                 num_streams,
                 backbones,
                 aggregation_mlp_channels=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
                 act_cfg=dict(type='ReLU'),
                 suffixes=('net0', 'net1'),
                 init_cfg=None,
                 pretrained=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(backbones, (dict, list))
        if isinstance(backbones, dict):
            # Replicate a single config for every stream.
            backbones = [copy.deepcopy(backbones) for _ in range(num_streams)]

        assert len(backbones) == num_streams
        assert len(suffixes) == num_streams

        self.backbone_list = nn.ModuleList()
        # Used in forward() to rename each backbone's ret_dict.
        self.suffixes = suffixes

        out_channels = 0
        for backbone_cfg in backbones:
            # The last FP layer's output width of each backbone contributes
            # to the concatenated (aggregated) feature dimension.
            out_channels += backbone_cfg['fp_channels'][-1][-1]
            self.backbone_list.append(MODELS.build(backbone_cfg))

        # Feature aggregation layers.
        if aggregation_mlp_channels is None:
            aggregation_mlp_channels = [
                out_channels, out_channels // 2,
                out_channels // len(self.backbone_list)
            ]
        else:
            # Build a new list instead of insert(0, ...) so the caller's
            # list argument is not mutated in place.
            aggregation_mlp_channels = \
                [out_channels] + list(aggregation_mlp_channels)

        self.aggregation_layers = nn.Sequential()
        for i in range(len(aggregation_mlp_channels) - 1):
            self.aggregation_layers.add_module(
                f'layer{i}',
                ConvModule(
                    aggregation_mlp_channels[i],
                    aggregation_mlp_channels[i + 1],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs from multiple backbones.

                - fp_xyz[suffix] (list[torch.Tensor]): The coordinates of
                  each fp features.
                - fp_features[suffix] (list[torch.Tensor]): The features
                  from each Feature Propagate Layers.
                - fp_indices[suffix] (list[torch.Tensor]): Indices of the
                  input points.
                - hd_feature (torch.Tensor): The aggregation feature
                  from multiple backbones.
        """
        ret = {}
        fp_features = []
        for cur_suffix, backbone in zip(self.suffixes, self.backbone_list):
            cur_ret = backbone(points)
            fp_features.append(cur_ret['fp_features'][-1])
            if cur_suffix != '':
                # Rename keys so outputs of different streams don't clash.
                cur_ret = {
                    f'{k}_{cur_suffix}': v
                    for k, v in cur_ret.items()
                }
            # BUGFIX: previously an empty suffix left the renamed dict empty
            # and the backbone's outputs were silently dropped; now they are
            # kept under their original keys.
            ret.update(cur_ret)

        # Combine the features here.
        hd_feature = torch.cat(fp_features, dim=1)
        hd_feature = self.aggregation_layers(hd_feature)
        ret['hd_feature'] = hd_feature
        return ret