point_fp_module.py 2.7 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
import torch
from mmcv.cnn import ConvModule
3
from mmcv.runner import BaseModule, force_fp32
zhangwenwei's avatar
zhangwenwei committed
4
5
from torch import nn as nn
from typing import List, Optional
wuyuefeng's avatar
wuyuefeng committed
6
7
8
9

from mmdet3d.ops import three_interpolate, three_nn


10
class PointFPModule(BaseModule):
    """Point feature propagation module used in PointNets.

    Propagates features from a source point set onto a target point set by
    inverse-distance-weighted interpolation over the nearest source points
    returned by ``three_nn``. It optionally concatenates the target set's own
    features, then refines the result with a stack of 1x1 convolutions
    applied per point.

    Args:
        mlp_channels (list[int]): Channel sizes of the MLP. Each consecutive
            pair builds one ``ConvModule``, so ``len(mlp_channels) - 1``
            layers are created. ``mlp_channels[0]`` must equal the channel
            count of the features fed to the MLP: C2 + C1, or C2 when
            ``target_feats`` is None in ``forward``.
        norm_cfg (dict): Normalization config passed to every
            ``ConvModule``. Default: dict(type='BN2d').
        init_cfg (optional): Initialization config forwarded to
            ``BaseModule``. Default: None.
    """

    # NOTE(review): ``norm_cfg`` uses a mutable default dict that is shared
    # across all instances and layers. This is safe only as long as
    # ConvModule does not mutate it (confirm). The default is kept as-is
    # because an explicit ``norm_cfg=None`` has its own meaning to
    # ConvModule.
    def __init__(self,
                 mlp_channels: List[int],
                 norm_cfg: dict = dict(type='BN2d'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Presumably consulted by mmcv's ``force_fp32`` decorator on
        # ``forward``. TODO confirm where it is switched on for mixed
        # precision.
        self.fp16_enabled = False

        # 1x1 Conv2d layers over (B, C, n, 1) inputs act as an MLP shared by
        # all points. The ``layer{i}`` names become state_dict keys, so they
        # must remain stable for checkpoint compatibility.
        self.mlps = nn.Sequential()
        channel_pairs = zip(mlp_channels[:-1], mlp_channels[1:])
        for i, (in_channels, out_channels) in enumerate(channel_pairs):
            self.mlps.add_module(
                f'layer{i}',
                ConvModule(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    conv_cfg=dict(type='Conv2d'),
                    norm_cfg=norm_cfg))

    @force_fp32()
    def forward(self, target: torch.Tensor,
                source: Optional[torch.Tensor],
                target_feats: Optional[torch.Tensor],
                source_feats: torch.Tensor) -> torch.Tensor:
        """Propagate ``source_feats`` onto the ``target`` points.

        Args:
            target (Tensor): (B, n, 3) xyz positions of the target points.
            source (Tensor | None): (B, m, 3) xyz positions of the source
                points. If None, ``source_feats`` is broadcast along the
                point dimension to all n target points instead of being
                interpolated. Its last dimension must then be 1 (or already
                n) for ``Tensor.expand`` to succeed.
            target_feats (Tensor | None): (B, C1, n) features of the target
                points. If given, they are concatenated after the
                interpolated features. If None, only the interpolated
                features are used.
            source_feats (Tensor): (B, C2, m) features of the source points
                to be propagated.

        Returns:
            Tensor: (B, M, n) propagated features of the target points,
            where M = mlp_channels[-1]. If fewer than two mlp channels were
            given, no layers exist and M is the input channel count.
        """
        if source is not None:
            dist, idx = three_nn(target, source)
            # Inverse-distance weights, normalized over the neighbour axis
            # (dim 2). The epsilon guards against division by zero when a
            # target point coincides with a source point.
            dist_reciprocal = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_reciprocal, dim=2, keepdim=True)
            weight = dist_reciprocal / norm

            interpolated_feats = three_interpolate(source_feats, idx, weight)
        else:
            # No source positions: presumably source_feats is a single
            # global feature of shape (B, C2, 1) (TODO confirm with callers).
            # Broadcast it, without copying, to every target point.
            batch_size, num_channels = source_feats.size()[0:2]
            interpolated_feats = source_feats.expand(batch_size,
                                                     num_channels,
                                                     target.size(1))

        if target_feats is not None:
            new_features = torch.cat([interpolated_feats, target_feats],
                                     dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # Add a trailing singleton axis so the 1x1 Conv2d layers see
        # (B, C, n, 1), then drop it again after the MLP.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlps(new_features)

        return new_features.squeeze(-1)