# base_pointnet.py
from abc import ABCMeta
from mmcv.runner import load_checkpoint
from torch import nn as nn


class BasePointNet(nn.Module, metaclass=ABCMeta):
    """Base class for PointNet backbones.

    Provides helpers shared by subclasses: optional checkpoint loading in
    :meth:`init_weights` and splitting raw point tensors into coordinates
    and per-point features in :meth:`_split_point_feats`.
    """

    def __init__(self):
        super().__init__()
        # Mixed-precision flag; presumably toggled/read by mmcv's fp16
        # utilities (e.g. ``auto_fp16``) -- confirm against callers.
        self.fp16_enabled = False

    def init_weights(self, pretrained=None):
        """Initialize the weights of PointNet backbone.

        Args:
            pretrained (str, optional): Path to a checkpoint file. When a
                string is given, the checkpoint is loaded into this module
                with ``strict=False``. Any other value, including ``None``,
                leaves the current weights untouched.
        """
        # Do not initialize the conv layers
        # to follow the original implementation
        if isinstance(pretrained, str):
            # Imported lazily, presumably to avoid a circular import between
            # the backbone and utils packages -- confirm before hoisting.
            from mmdet3d.utils import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    @staticmethod
    def _split_point_feats(points):
        """Split coordinates and features of input points.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim). Additional leading
                dimensions are also accepted.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]:
                - xyz: Contiguous coordinates of input points,
                  shape (B, N, 3).
                - features: Contiguous features in channel-first layout,
                  shape (B, input_feature_dim, N), or ``None`` when
                  ``points`` contains only the 3 coordinate channels.
        """
        xyz = points[..., 0:3].contiguous()
        if points.size(-1) > 3:
            # Swap the last two axes (points <-> channels). Negative axes keep
            # this consistent with the ``...`` slicing above and are identical
            # to ``transpose(1, 2)`` for the documented 3-D input.
            features = points[..., 3:].transpose(-2, -1).contiguous()
        else:
            features = None

        return xyz, features