# dgcnn_head.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Tuple

4
from mmcv.cnn.bricks import ConvModule
5
from torch import Tensor
6

# (blame annotation) committed by zhangshilong
7
from mmdet3d.models.layers import DGCNNFPModule
8
from mmdet3d.registry import MODELS
9
10
11
from .decode_head import Base3DDecodeHead


12
@MODELS.register_module()
class DGCNNHead(Base3DDecodeHead):
    r"""DGCNN decoder head.

    Decoder head used in `DGCNN <https://arxiv.org/abs/1801.07829>`_.
    Refer to the
    `reimplementation code <https://github.com/AnTao97/dgcnn.pytorch>`_.

    The head runs a feature propagation (FP) module over the backbone's
    ``'fa_points'`` output. It projects the result to ``self.channels``
    with a 1x1 ``ConvModule``, then classifies it with ``self.cls_seg``.
    ``self.channels`` and ``self.cls_seg`` are provided by
    ``Base3DDecodeHead``.

    Args:
        fp_channels (tuple[int], optional): Tuple of mlp channels in feature
            propagation (FP) modules. The last entry is also the number of
            input channels of the pre-segmentation conv.
            Defaults to (1216, 512).
        **kwargs: Forwarded unchanged to ``Base3DDecodeHead.__init__``.
    """

    def __init__(self,
                 fp_channels: Tuple[int, ...] = (1216, 512),
                 **kwargs) -> None:
        super().__init__(**kwargs)

        # Feature propagation MLP. It reuses the activation config that
        # the base head set up in ``super().__init__``.
        self.FP_module = DGCNNFPModule(
            mlp_channels=fp_channels, act_cfg=self.act_cfg)

        # 1x1 conv mapping the FP output (``fp_channels[-1]`` channels) to
        # ``self.channels`` before the classifier, following
        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
        self.pre_seg_conv = ConvModule(
            fp_channels[-1],
            self.channels,
            kernel_size=1,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _extract_input(self, feat_dict: dict) -> Tensor:
        """Extract inputs from features dictionary.

        Args:
            feat_dict (dict): Feature dict from backbone. It must contain
                the key ``'fa_points'``; a missing key raises ``KeyError``.

        Returns:
            torch.Tensor: points for decoder, i.e. ``feat_dict['fa_points']``.
        """
        return feat_dict['fa_points']

    def forward(self, feat_dict: dict) -> Tensor:
        """Forward pass.

        Args:
            feat_dict (dict): Feature dict from backbone. Only its
                ``'fa_points'`` entry is used.

        Returns:
            torch.Tensor: Segmentation map of shape [B, num_classes, N].
        """
        fa_points = self._extract_input(feat_dict)

        fp_points = self.FP_module(fa_points)
        # Swap dims 1 and 2. Presumably FP_module returns (B, N, C) and the
        # conv layers expect channels-first (B, C, N); TODO confirm.
        # ``contiguous()`` materializes the transposed memory layout.
        fp_points = fp_points.transpose(1, 2).contiguous()
        output = self.pre_seg_conv(fp_points)
        output = self.cls_seg(output)

        return output