# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn

from mmdet3d.ops import three_interpolate, three_nn


class PointFPModule(BaseModule):
    """Point feature propagation module used in PointNets.

    Propagate the features from one set to another.

    Args:
        mlp_channels (list[int]): List of mlp channels. Consecutive pairs
            ``(mlp_channels[i], mlp_channels[i + 1])`` define the input and
            output channels of one 1x1 ConvModule, so
            ``len(mlp_channels) - 1`` layers are built.
        norm_cfg (dict, optional): Type of normalization method.
            Default: dict(type='BN2d').
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Default: None.
    """

    def __init__(self,
                 mlp_channels: List[int],
                 norm_cfg: dict = dict(type='BN2d'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Flag consulted by mmcv's ``force_fp32`` decorator on ``forward``.
        self.fp16_enabled = False

        self.mlps = nn.Sequential()
        channel_pairs = zip(mlp_channels[:-1], mlp_channels[1:])
        for i, (in_channels, out_channels) in enumerate(channel_pairs):
            self.mlps.add_module(
                f'layer{i}',
                ConvModule(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    conv_cfg=dict(type='Conv2d'),
                    norm_cfg=norm_cfg))

    @force_fp32()
    def forward(self, target: torch.Tensor, source: Optional[torch.Tensor],
                target_feats: Optional[torch.Tensor],
                source_feats: torch.Tensor) -> torch.Tensor:
        """Propagate ``source_feats`` onto the ``target`` points.

        When ``source`` is given, the features are interpolated with
        ``three_nn`` / ``three_interpolate`` using normalized
        inverse-distance weights. When ``source`` is None,
        ``source_feats`` is broadcast with ``Tensor.expand`` to every target
        point. This requires its last dimension to be 1 (or already ``n``).
        The result is concatenated with ``target_feats`` (if given) along
        the channel dimension and passed through the 1x1 conv MLP.

        Args:
            target (Tensor): (B, n, 3) tensor of the xyz positions of
                the target features.
            source (Tensor | None): (B, m, 3) tensor of the xyz positions of
                the source features, or None to broadcast ``source_feats``.
            target_feats (Tensor | None): (B, C1, n) tensor of the features
                to be propagated to, or None to skip the concatenation.
            source_feats (Tensor): (B, C2, m) tensor of features
                to be propagated.

        Returns:
            Tensor: (B, M, n) tensor of the target features, where
                M = mlp_channels[-1] when at least one MLP layer exists.
        """
        if source is not None:
            dist, idx = three_nn(target, source)
            # Inverse-distance weights normalized over the neighbour axis;
            # the epsilon guards against division by zero for coincident
            # points.
            dist_reciprocal = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_reciprocal, dim=2, keepdim=True)
            weight = dist_reciprocal / norm

            interpolated_feats = three_interpolate(source_feats, idx, weight)
        else:
            # No source coordinates: replicate the source features for every
            # target point, giving (B, C2, n).
            interpolated_feats = source_feats.expand(*source_feats.size()[0:2],
                                                     target.size(1))

        if target_feats is not None:
            new_features = torch.cat([interpolated_feats, target_feats],
                                     dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # The MLP is built from Conv2d layers, so add a trailing singleton
        # dimension for the 1x1 convolutions and drop it afterwards.
        new_features = self.mlps(new_features.unsqueeze(-1))
        return new_features.squeeze(-1)