# base_pointnet.py
1
import warnings
2
from abc import ABCMeta
3
from mmcv.runner import BaseModule
4
5


6
class BasePointNet(BaseModule, metaclass=ABCMeta):
    """Base class for PointNet.

    Converts the deprecated ``pretrained`` argument into an ``init_cfg`` of
    type ``'Pretrained'``. Also provides a helper that splits raw point
    tensors into xyz coordinates and extra per-point features.

    Args:
        init_cfg (optional): Initialization config passed straight to
            ``BaseModule.__init__``. Defaults to None.
        pretrained (str, optional): Deprecated checkpoint path. When it is a
            string, it overrides ``self.init_cfg`` with
            ``dict(type='Pretrained', checkpoint=pretrained)``. It must not
            be given together with a truthy ``init_cfg``. Defaults to None.
    """

    def __init__(self, init_cfg=None, pretrained=None):
        super(BasePointNet, self).__init__(init_cfg)
        # NOTE(review): presumably read by mmcv's fp16 helpers
        # (e.g. auto_fp16); mixed precision starts disabled here.
        self.fp16_enabled = False

        # Both are ways of specifying initialization; refuse ambiguity.
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            # Emitted with warnings.warn's default category (UserWarning);
            # the "DeprecationWarning:" prefix is only part of the message.
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    @staticmethod
    def _split_point_feats(points):
        """Split coordinates and features of input points.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim). Extra leading
                dimensions are also accepted, i.e. (..., N, 3 + C).

        Returns:
            torch.Tensor: Coordinates of input points, shape (B, N, 3).
            torch.Tensor | None: Features of input points in channel-first
                layout, shape (B, input_feature_dim, N), or None when the
                points carry nothing beyond xyz.
        """
        xyz = points[..., 0:3].contiguous()
        if points.size(-1) > 3:
            # Swap the point axis and the feature axis so features come out
            # channel-first. Using the last two axes (rather than fixed axes
            # 1 and 2) matches the ``...`` indexing above, so any number of
            # leading dimensions works; for (B, N, C) input it is identical
            # to transpose(1, 2).
            features = points[..., 3:].transpose(-2, -1).contiguous()
        else:
            features = None

        return xyz, features