paconv_head.py 2.72 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from typing import Sequence
3

4
from mmcv.cnn.bricks import ConvModule
5
from torch import Tensor
6

7
from mmdet3d.registry import MODELS
8
from mmdet3d.utils.typing_utils import ConfigType
9
10
11
from .pointnet2_head import PointNet2Head


12
@MODELS.register_module()
class PAConvHead(PointNet2Head):
    r"""PAConv decoder head.

    Decoder head used in `PAConv <https://arxiv.org/abs/2103.14635>`_.
    Refer to the `official code <https://github.com/CVMI-Lab/PAConv>`_.

    Compared with :class:`PointNet2Head`, this head differs in two ways:

    - It feeds the first level of ``sa_features`` into the decoder through a
      skip connection. PointNet++ does not.
    - It rebuilds ``pre_seg_conv`` without a bias term.

    Args:
        fp_channels (Sequence[Sequence[int]]): Tuple of mlp channels in FP
            modules. Defaults to ((768, 256, 256), (384, 256, 256),
            (320, 256, 128), (128 + 6, 128, 128, 128)).
        fp_norm_cfg (dict or :obj:`ConfigDict`): Config of norm layers used in
            FP modules. Defaults to dict(type='BN2d').
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            :class:`PointNet2Head`.
    """

    def __init__(self,
                 fp_channels: Sequence[Sequence[int]] = (
                     (768, 256, 256),
                     (384, 256, 256),
                     (320, 256, 128),
                     (128 + 6, 128, 128, 128),
                 ),
                 # NOTE(review): this dict default is shared by every
                 # instance; it is only safe while nothing downstream mutates
                 # it in place -- confirm before relying on that.
                 fp_norm_cfg: ConfigType = dict(type='BN2d'),
                 **kwargs) -> None:
        super().__init__(
            fp_channels=fp_channels, fp_norm_cfg=fp_norm_cfg, **kwargs)

        # https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/pointnet2_paconv_seg.py#L53
        # PointNet++'s decoder conv has a bias while PAConv's does not, so the
        # `pre_seg_conv` built by the parent class is replaced here. Its input
        # width is the last output width of the last FP module.
        self.pre_seg_conv = ConvModule(
            fp_channels[-1][-1],
            self.channels,
            kernel_size=1,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, feat_dict: dict) -> Tensor:
        """Forward pass.

        Args:
            feat_dict (dict): Feature dict from backbone.

        Returns:
            torch.Tensor: Segmentation map of shape [B, num_classes, N].
        """
        sa_xyz, sa_features = self._extract_input(feat_dict)

        # PointNet++ doesn't use the first level of `sa_features` as input
        # while PAConv inputs it through skip-connection
        fp_feature = sa_features[-1]

        for i in range(self.num_fp):
            # Consume the points in a bottom-up manner: FP module `i`
            # propagates features from level -(i + 1) to level -(i + 2),
            # fusing in that level's skip features.
            fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
                                            sa_features[-(i + 2)], fp_feature)

        output = self.pre_seg_conv(fp_feature)
        output = self.cls_seg(output)

        return output