multi_backbone.py 4.9 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
encore-zhou's avatar
encore-zhou committed
2
import copy
3
import warnings
4
from typing import Dict, List, Optional, Sequence, Tuple, Union
5
6

import torch
encore-zhou's avatar
encore-zhou committed
7
from mmcv.cnn import ConvModule
8
from mmengine.model import BaseModule
9
from torch import Tensor, nn
encore-zhou's avatar
encore-zhou committed
10

11
from mmdet3d.registry import MODELS
12
from mmdet3d.utils import ConfigType, OptMultiConfig
encore-zhou's avatar
encore-zhou committed
13
14


15
@MODELS.register_module()
class MultiBackbone(BaseModule):
    """MultiBackbone with different configs.

    Builds one backbone per stream and runs each of them on the same input.
    Every backbone's output dict is re-exported under suffixed keys. The last
    ``fp_features`` entry of each backbone is concatenated along dim 1 and
    passed through a stack of kernel-size-1 ``ConvModule`` layers to produce
    ``hd_feature``.

    Args:
        num_streams (int): The number of backbones.
        backbones (list[dict] | tuple[dict] | dict): Backbone configs, one
            per stream. A single dict is deep-copied ``num_streams`` times.
            Each config must contain ``fp_channels``. The last element of its
            last entry is taken as that backbone's output channel count.
        aggregation_mlp_channels (Sequence[int], optional): Output channels
            of the aggregation layers. The summed backbone output channels
            ``C`` are prepended as the input width. The caller's sequence is
            not modified. Defaults to ``[C, C // 2, C // num_streams]``.
        conv_cfg (dict): Config dict of convolutional layers.
            Defaults to ``dict(type='Conv1d')``.
        norm_cfg (dict): Config dict of normalization layers.
            Defaults to ``dict(type='BN1d', eps=1e-5, momentum=0.01)``.
        act_cfg (dict): Config dict of activation layers.
            Defaults to ``dict(type='ReLU')``.
        suffixes (Sequence[str]): One suffix per backbone, used to rename
            that backbone's output keys. Must have ``num_streams`` entries.
            Defaults to ``('net0', 'net1')``.
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
        pretrained (str, optional): Deprecated checkpoint path. It is
            converted to ``dict(type='Pretrained', checkpoint=pretrained)``
            and must not be combined with ``init_cfg``. Defaults to None.
        **kwargs: Accepted but unused by this class.
    """

    def __init__(self,
                 num_streams: int,
                 backbones: Union[List[dict], Tuple[dict, ...], Dict],
                 aggregation_mlp_channels: Optional[Sequence[int]] = None,
                 conv_cfg: ConfigType = dict(type='Conv1d'),
                 norm_cfg: ConfigType = dict(
                     type='BN1d', eps=1e-5, momentum=0.01),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 suffixes: Sequence[str] = ('net0', 'net1'),
                 init_cfg: OptMultiConfig = None,
                 pretrained: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)
        # Fail fast, before any backbone is built.
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'

        assert isinstance(backbones, (dict, list, tuple)), \
            f'backbones must be a dict or a list of dicts, ' \
            f'got {type(backbones)}'
        if isinstance(backbones, dict):
            # One shared config: give every stream its own deep copy so the
            # built backbones never share mutable config state.
            backbones = [copy.deepcopy(backbones) for _ in range(num_streams)]

        assert len(backbones) == num_streams, \
            f'expected {num_streams} backbone configs, got {len(backbones)}'
        assert len(suffixes) == num_streams, \
            f'expected {num_streams} suffixes, got {len(suffixes)}'

        self.backbone_list = nn.ModuleList()
        # Suffixes used in forward() to rename each backbone's output keys.
        self.suffixes = suffixes

        out_channels = 0
        for backbone_cfg in backbones:
            # The aggregation input width is the sum of the channel counts
            # read from each backbone's last ``fp_channels`` entry.
            out_channels += backbone_cfg['fp_channels'][-1][-1]
            self.backbone_list.append(MODELS.build(backbone_cfg))

        # Feature aggregation layers.
        if aggregation_mlp_channels is None:
            mlp_channels = [
                out_channels, out_channels // 2,
                out_channels // len(self.backbone_list)
            ]
        else:
            # Build a fresh list instead of inserting into the caller's
            # sequence. In-place insertion would mutate a possibly reused
            # config (growing it on every build) and fails for tuples.
            mlp_channels = [out_channels, *aggregation_mlp_channels]

        self.aggregation_layers = nn.Sequential()
        # One layer per consecutive (in, out) channel pair. A list shorter
        # than 2 yields no layers, so concatenated features pass through.
        for i, (in_ch, out_ch) in enumerate(
                zip(mlp_channels[:-1], mlp_channels[1:])):
            self.aggregation_layers.add_module(
                f'layer{i}',
                ConvModule(
                    in_ch,
                    out_ch,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))

        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    def forward(self, points: Tensor) -> dict:
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict: Outputs from multiple backbones.

                - ``f'{key}_{suffix}'``: every entry ``key`` returned by the
                  backbone paired with ``suffix`` (presumably ``fp_xyz``,
                  ``fp_features``, ``fp_indices`` for PointNet2-style
                  backbones). Backbones whose suffix is ``''`` contribute
                  no entries.
                - hd_feature (torch.Tensor): The aggregation feature
                  from multiple backbones.
        """
        ret = {}
        fp_features = []
        for backbone, suffix in zip(self.backbone_list, self.suffixes):
            cur_ret = backbone(points)
            # Only the last feature-propagation output is aggregated.
            fp_features.append(cur_ret['fp_features'][-1])
            # NOTE(review): an empty suffix drops this backbone's outputs from
            # ``ret`` entirely rather than exporting them unrenamed. It still
            # feeds hd_feature. Confirm that this is intended.
            if suffix != '':
                ret.update({f'{k}_{suffix}': v for k, v in cur_ret.items()})

        # Concatenate along dim 1 (presumably the channel axis of
        # (B, C_i, N) features; TODO confirm), then fuse.
        hd_feature = torch.cat(fp_features, dim=1)
        ret['hd_feature'] = self.aggregation_layers(hd_feature)
        return ret