"official/benchmark/models/cifar_preprocessing.py" did not exist on "ca6f90261d1b2264fb0d5b149b621d1a39c2b617"
pointnet2_sa_msg.py 8.03 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Tuple

4
5
6
7
import torch
from mmcv.cnn import ConvModule
from torch import nn as nn

zhangshilong's avatar
zhangshilong committed
8
from mmdet3d.models.layers.pointnet_modules import build_sa_module
9
from mmdet3d.registry import MODELS
10
from mmdet3d.utils import OptConfigType
11
12
from .base_pointnet import BasePointNet

13
14
15
16
# Type aliases for the nested-tuple constructor arguments of PointNet2SAMSG.
# NOTE(review): strictly, ``Tuple[X]`` denotes a one-element tuple, while the
# defaults in this file hold several entries per level; ``Tuple[X, ...]``
# would be the variable-length spelling. Treat these as documentation-grade
# hints rather than exact types.

# Per-SA-module, per-scale MLP channel triples (annotates ``sa_channels``).
ThreeTupleIntType = Tuple[Tuple[Tuple[int, int, int]]]
# Per-SA-module integer tuples (annotates ``num_samples`` and
# ``fps_sample_range_lists``).
TwoTupleIntType = Tuple[Tuple[int, int, int]]
# Per-SA-module tuples of FPS mode names (annotates ``fps_mods``).
TwoTupleStrType = Tuple[Tuple[str]]

17

18
@MODELS.register_module()
class PointNet2SAMSG(BasePointNet):
    """PointNet2 with Multi-scale grouping.

    The backbone is a stack of set abstraction (SA) modules. Each SA module
    may be followed by a 1x1 ``Conv1d`` aggregation MLP that maps the
    concatenated multi-scale features to a fixed channel count.

    Args:
        in_channels (int): Input channels of point cloud (including the
            3 xyz channels, which are subtracted internally).
        num_points (tuple[int]): The number of points which each SA
            module samples.
        radii (tuple[tuple[float]]): Sampling radii of each SA module,
            one radius per scale.
        num_samples (tuple[tuple[int]]): The number of samples for ball
            query in each SA module, one per scale.
        sa_channels (tuple[tuple[tuple[int]]]): Out channels of each mlp in
            SA module, one mlp per scale.
        aggregation_channels (tuple[int], optional): Out channels of
            aggregation multi-scale grouping features. An entry of
            ``None`` (or passing ``None`` for the whole argument) disables
            the aggregation mlp for that SA module.
        fps_mods (Sequence[Tuple[str]]): Mod of FPS for each SA module.
            A bare (non-tuple) entry is wrapped into a one-element list.
        fps_sample_range_lists (tuple[tuple[int]]): Sample range list
            forwarded to each SA module as ``fps_sample_range_list``.
            A bare (non-tuple) entry is wrapped into a one-element list.
        dilated_group (tuple[bool]): Whether to use dilated ball query for
            each SA module.
        out_indices (Sequence[int]): Output from which stages.
        norm_cfg (dict): Config of normalization layer.
        sa_cfg (dict): Config of set abstraction module, which may contain
            the following keys and values:

            - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
            - use_xyz (bool): Whether to use xyz as a part of features.
            - normalize_xyz (bool): Whether to normalize xyz with radii in
              each SA module.
        init_cfg (dict, optional): Initialization config passed to the
            base class.

    Note:
        ``num_points``, ``radii``, ``num_samples`` and ``sa_channels`` must
        all have the same length. The default ``num_points`` has four
        entries while the other defaults have three, so ``num_points`` (or
        the other arguments) must be given explicitly for construction to
        pass the length check.
    """

    def __init__(self,
                 in_channels: int,
                 num_points: Tuple[int] = (2048, 1024, 512, 256),
                 radii: Tuple[Tuple[float, float, float]] = (
                     (0.2, 0.4, 0.8),
                     (0.4, 0.8, 1.6),
                     (1.6, 3.2, 4.8),
                 ),
                 num_samples: TwoTupleIntType = ((32, 32, 64), (32, 32, 64),
                                                 (32, 32, 32)),
                 sa_channels: ThreeTupleIntType = (((16, 16, 32), (16, 16, 32),
                                                    (32, 32, 64)),
                                                   ((64, 64, 128),
                                                    (64, 64, 128), (64, 96,
                                                                    128)),
                                                   ((128, 128, 256),
                                                    (128, 192, 256), (128, 256,
                                                                      256))),
                 aggregation_channels: Tuple[int] = (64, 128, 256),
                 fps_mods: TwoTupleStrType = (('D-FPS'), ('FS'), ('F-FPS',
                                                                  'D-FPS')),
                 fps_sample_range_lists: TwoTupleIntType = ((-1), (-1), (512,
                                                                         -1)),
                 dilated_group: Tuple[bool] = (True, True, True),
                 out_indices: Tuple[int] = (2, ),
                 norm_cfg: dict = dict(type='BN2d'),
                 sa_cfg: dict = dict(
                     type='PointSAModuleMSG',
                     pool_mod='max',
                     use_xyz=True,
                     normalize_xyz=False),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg=init_cfg)
        self.num_sa = len(sa_channels)
        self.out_indices = out_indices
        assert max(out_indices) < self.num_sa, (
            'out_indices must refer to existing SA modules')
        assert len(num_points) == len(radii) == len(num_samples) == len(
            sa_channels), (
                'num_points, radii, num_samples and sa_channels must have '
                'the same length')
        if aggregation_channels is not None:
            assert len(sa_channels) == len(aggregation_channels), (
                'aggregation_channels must have one entry per SA module')
        else:
            aggregation_channels = [None] * len(sa_channels)

        self.SA_modules = nn.ModuleList()
        self.aggregation_mlps = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz

        for sa_index in range(self.num_sa):
            # Prepend the current input width to the mlp of every scale and
            # accumulate the concatenated output width over all scales.
            cur_sa_mlps = list(sa_channels[sa_index])
            sa_out_channel = 0
            for radius_index in range(len(radii[sa_index])):
                cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                    cur_sa_mlps[radius_index])
                sa_out_channel += cur_sa_mlps[radius_index][-1]

            # Entries such as ``('D-FPS')`` are plain strings, not tuples,
            # so bare values are wrapped into a one-element list.
            if isinstance(fps_mods[sa_index], tuple):
                cur_fps_mod = list(fps_mods[sa_index])
            else:
                cur_fps_mod = [fps_mods[sa_index]]

            if isinstance(fps_sample_range_lists[sa_index], tuple):
                cur_fps_sample_range_list = list(
                    fps_sample_range_lists[sa_index])
            else:
                cur_fps_sample_range_list = [fps_sample_range_lists[sa_index]]

            self.SA_modules.append(
                build_sa_module(
                    num_point=num_points[sa_index],
                    radii=radii[sa_index],
                    sample_nums=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    fps_mod=cur_fps_mod,
                    fps_sample_range_list=cur_fps_sample_range_list,
                    dilated_group=dilated_group[sa_index],
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg,
                    bias=True))

            cur_aggregation_channel = aggregation_channels[sa_index]
            if cur_aggregation_channel is None:
                # A ``None`` placeholder keeps aggregation_mlps index-aligned
                # with SA_modules; forward() skips it.
                self.aggregation_mlps.append(None)
                sa_in_channel = sa_out_channel
            else:
                self.aggregation_mlps.append(
                    ConvModule(
                        sa_out_channel,
                        cur_aggregation_channel,
                        conv_cfg=dict(type='Conv1d'),
                        norm_cfg=dict(type='BN1d'),
                        kernel_size=1,
                        bias=True))
                sa_in_channel = cur_aggregation_channel

    def forward(self, points: torch.Tensor) -> dict:
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: The network input followed by
            the outputs of the SA modules selected by ``out_indices``.

                - sa_xyz (list[torch.Tensor]): The coordinates of sa
                    features.
                - sa_features (list[torch.Tensor]): The features from the
                    selected Set Aggregation Layers.
                - sa_indices (list[torch.Tensor]): Indices of the
                    input points.
        """
        xyz, features = self._split_point_feats(points)

        batch, num_points = xyz.shape[:2]
        # Build integer indices directly: going through ``xyz``'s floating
        # dtype would be inexact for half precision or very large N.
        indices = torch.arange(
            num_points, dtype=torch.long,
            device=xyz.device).unsqueeze(0).repeat(batch, 1)

        sa_xyz = [xyz]
        sa_features = [features]
        sa_indices = [indices]

        # The returned lists always start with the (split) network input.
        out_sa_xyz = [xyz]
        out_sa_features = [features]
        out_sa_indices = [indices]

        for i in range(self.num_sa):
            cur_xyz, cur_features, cur_indices = self.SA_modules[i](
                sa_xyz[i], sa_features[i])
            if self.aggregation_mlps[i] is not None:
                cur_features = self.aggregation_mlps[i](cur_features)
            sa_xyz.append(cur_xyz)
            sa_features.append(cur_features)
            # Map this stage's local indices back to indices of the
            # original input points.
            sa_indices.append(
                torch.gather(sa_indices[-1], 1, cur_indices.long()))
            if i in self.out_indices:
                out_sa_xyz.append(sa_xyz[-1])
                out_sa_features.append(sa_features[-1])
                out_sa_indices.append(sa_indices[-1])

        return dict(
            sa_xyz=out_sa_xyz,
            sa_features=out_sa_features,
            sa_indices=out_sa_indices)