import warnings
from abc import ABCMeta

from mmcv.runner import BaseModule


class BasePointNet(BaseModule, metaclass=ABCMeta):
    """Base class for PointNet.

    Handles the shared constructor logic (including the deprecated
    ``pretrained`` argument). Also provides a helper that splits raw point
    tensors into coordinates and extra per-point features.
    """

    def __init__(self, init_cfg=None, pretrained=None):
        """Initialize the backbone.

        Args:
            init_cfg (dict, optional): Initialization config forwarded to
                ``BaseModule``. Must not be given together with
                ``pretrained``. Defaults to None.
            pretrained (str, optional): Deprecated. Path of a checkpoint to
                load. When it is a string, it is converted into
                ``dict(type='Pretrained', checkpoint=pretrained)`` and stored
                as ``self.init_cfg``. Defaults to None.
        """
        super(BasePointNet, self).__init__(init_cfg)
        # NOTE(review): presumably read by mmcv's fp16 helpers
        # (auto_fp16 / force_fp32) — confirm against callers.
        self.fp16_enabled = False
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            # Translate the legacy argument into the new init_cfg form.
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    @staticmethod
    def _split_point_feats(points):
        """Split coordinates and features of input points.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            torch.Tensor: Coordinates of input points, shape (B, N, 3).
            torch.Tensor | None: Features of input points, transposed to
                shape (B, input_feature_dim, N), or ``None`` when the last
                dimension holds only the 3 coordinates.
        """
        xyz = points[..., 0:3].contiguous()
        if points.size(-1) > 3:
            # Move the channel axis before the point axis, giving a
            # channels-first feature layout.
            features = points[..., 3:].transpose(1, 2).contiguous()
        else:
            features = None

        return xyz, features