# pointnet2_head.py (3.32 KB), last committed by dingchang.
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from typing import List, Sequence, Tuple
3

4
from mmcv.cnn.bricks import ConvModule
5
from torch import Tensor
6
7
from torch import nn as nn

# The following lines were last committed by zhangshilong.
8
from mmdet3d.models.layers import PointFPModule
9
from mmdet3d.registry import MODELS
10
from mmdet3d.utils.typing_utils import ConfigType
11
12
13
from .decode_head import Base3DDecodeHead


14
@MODELS.register_module()
class PointNet2Head(Base3DDecodeHead):
    r"""PointNet2 decoder head.

    Decoder head used in `PointNet++ <https://arxiv.org/abs/1706.02413>`_.
    Refer to the `official code <https://github.com/charlesq34/pointnet2>`_.

    The head chains one feature-propagation (FP) module per entry of
    ``fp_channels`` over the multi-level set-abstraction (SA) outputs of the
    backbone, starting from the last (coarsest) level. It then applies a 1x1
    ``pre_seg_conv`` followed by ``cls_seg`` (inherited from the base class).

    Args:
        fp_channels (Sequence[Sequence[int]]): Tuple of mlp channels in FP
            modules. One FP module is built per entry. The last channel of
            the last entry is used as the input channels of
            ``pre_seg_conv``. Defaults to ((768, 256, 256), (384, 256, 256),
            (320, 256, 128), (128, 128, 128, 128)).
        fp_norm_cfg (dict or :obj:`ConfigDict`): Config of norm layers used
            in FP modules. Defaults to dict(type='BN2d').
        **kwargs: Forwarded unchanged to :class:`Base3DDecodeHead`.
    """

    def __init__(self,
                 fp_channels: Sequence[Sequence[int]] = ((768, 256, 256),
                                                         (384, 256, 256),
                                                         (320, 256, 128),
                                                         (128, 128, 128, 128)),
                 fp_norm_cfg: ConfigType = dict(type='BN2d'),
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.num_fp = len(fp_channels)
        # One FP module per entry, applied in the given order in forward().
        self.FP_modules = nn.ModuleList([
            PointFPModule(mlp_channels=cur_fp_mlps, norm_cfg=fp_norm_cfg)
            for cur_fp_mlps in fp_channels
        ])

        # ``self.channels`` / ``conv_cfg`` / ``norm_cfg`` / ``act_cfg`` are not
        # set in this class; presumably they come from Base3DDecodeHead.
        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
        self.pre_seg_conv = ConvModule(
            fp_channels[-1][-1],
            self.channels,
            kernel_size=1,
            bias=True,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _extract_input(self,
                       feat_dict: dict) -> Tuple[List[Tensor], List[Tensor]]:
        """Extract inputs from features dictionary.

        Args:
            feat_dict (dict): Feature dict from backbone. Must provide the
                keys ``'sa_xyz'`` and ``'sa_features'`` with the same length.

        Returns:
            Tuple[List[Tensor], List[Tensor]]: Coordinates and features of
            multiple levels of points, returned as stored in ``feat_dict``
            (not copied).
        """
        sa_xyz = feat_dict['sa_xyz']
        sa_features = feat_dict['sa_features']
        assert len(sa_xyz) == len(sa_features)

        return sa_xyz, sa_features

    def forward(self, feat_dict: dict) -> Tensor:
        """Forward pass.

        Args:
            feat_dict (dict): Feature dict from backbone, holding
                ``'sa_xyz'`` and ``'sa_features'`` (one entry per SA level;
                presumably index 0 is the raw input level — confirm against
                the backbone). The dict and its lists are not modified.

        Returns:
            Tensor: Segmentation map of shape [B, num_classes, N].
        """
        sa_xyz, sa_features = self._extract_input(feat_dict)

        # Work on a shallow copy so the caller's list (e.g. the backbone
        # output that other heads may read afterwards) keeps its level-0
        # features; this also makes tuple inputs work.
        sa_features = list(sa_features)
        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L24
        sa_features[0] = None

        fp_feature = sa_features[-1]

        for i, fp_module in enumerate(self.FP_modules):
            # Consume the points in a bottom-up manner: features at level
            # -(i + 1) are propagated onto the points of level -(i + 2).
            fp_feature = fp_module(sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
                                   sa_features[-(i + 2)], fp_feature)
        output = self.pre_seg_conv(fp_feature)
        output = self.cls_seg(output)

        return output