point_fp_module.py 2.94 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import List

wuyuefeng's avatar
wuyuefeng committed
4
5
import torch
from mmcv.cnn import ConvModule
6
from mmcv.ops import three_interpolate, three_nn
7
from mmengine.model import BaseModule
8
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
9
from torch import nn as nn
wuyuefeng's avatar
wuyuefeng committed
10

11
12
from mmdet3d.utils import ConfigType, OptMultiConfig

wuyuefeng's avatar
wuyuefeng committed
13

14
class PointFPModule(BaseModule):
    """Point feature propagation module used in PointNets.

    Propagate the features from one set of points (the source) to another
    set of points (the target), then refine them with a shared MLP built
    from 1x1 convolutions.

    Args:
        mlp_channels (list[int]): List of mlp channels. Consecutive entries
            form the (in_channels, out_channels) pair of each layer, so a
            list of length ``k`` builds ``k - 1`` layers.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN2d').
        init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
            optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 mlp_channels: List[int],
                 norm_cfg: ConfigType = dict(type='BN2d'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.mlps = nn.Sequential()
        # Pair every channel count with its successor: (c0, c1), (c1, c2), ...
        channel_pairs = zip(mlp_channels[:-1], mlp_channels[1:])
        for i, (in_channels, out_channels) in enumerate(channel_pairs):
            self.mlps.add_module(
                f'layer{i}',
                ConvModule(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    conv_cfg=dict(type='Conv2d'),
                    norm_cfg=norm_cfg))

    def forward(self, target: Tensor, source: Tensor, target_feats: Tensor,
                source_feats: Tensor) -> Tensor:
        """Forward.

        Args:
            target (Tensor): (B, n, 3) Tensor of the xyz positions of
                the target features.
            source (Tensor, optional): (B, m, 3) Tensor of the xyz positions
                of the source features. If None, no interpolation is done
                and ``source_feats`` is expanded along its last dimension to
                the number of target points instead.
            target_feats (Tensor, optional): (B, C1, n) Tensor of the
                features to be propagated to. If None, only the interpolated
                source features are fed to the MLP.
            source_feats (Tensor): (B, C2, m) Tensor of features
                to be propagated.

        Return:
            Tensor: (B, M, n) M = mlp_channels[-1], Tensor of the target
            features.
        """
        if source is not None:
            # Inverse-distance weighting over the three nearest source
            # points of every target point.
            dist, idx = three_nn(target, source)
            # The small epsilon avoids division by zero when a target point
            # coincides with a source point.
            dist_reciprocal = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_reciprocal, dim=2, keepdim=True)
            # Normalise so that the weights of each target point sum to 1.
            weight = dist_reciprocal / norm

            interpolated_feats = three_interpolate(source_feats, idx, weight)
        else:
            # No source coordinates: broadcast the source features to all
            # target points (keeps the first two dims, sets the last to n).
            interpolated_feats = source_feats.expand(*source_feats.size()[0:2],
                                                     target.size(1))

        if target_feats is not None:
            new_features = torch.cat([interpolated_feats, target_feats],
                                     dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # Add a trailing singleton dim so the Conv2d-based MLP can be applied,
        # then drop it again on the way out.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlps(new_features)

        return new_features.squeeze(-1)