# base_pointnet.py (1.48 KB)
# Last committed by dingchang.
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
3
from abc import ABCMeta
4
from typing import Optional, Tuple
5

6
from mmengine.model import BaseModule
7
8
9
from torch import Tensor

from mmdet3d.utils import OptMultiConfig
10
11


12
class BasePointNet(BaseModule, metaclass=ABCMeta):
    """Base class for PointNet.

    Provides backward-compatible handling of the deprecated ``pretrained``
    argument and a shared helper that splits raw input points into
    coordinates and extra per-point features.

    Args:
        init_cfg (OptMultiConfig): Initialization config forwarded to
            ``BaseModule.__init__``. Defaults to None.
        pretrained (str, optional): Deprecated. Checkpoint path; when it is
            a string it is converted into
            ``dict(type='Pretrained', checkpoint=pretrained)`` and stored as
            ``self.init_cfg``. Defaults to None.

    Raises:
        AssertionError: If both ``init_cfg`` and ``pretrained`` are truthy.
    """

    def __init__(self,
                 init_cfg: OptMultiConfig = None,
                 pretrained: Optional[str] = None):
        super().__init__(init_cfg)
        # ``pretrained`` is the legacy way to request checkpoint loading;
        # accepting it together with an explicit ``init_cfg`` would be
        # ambiguous, so the combination is rejected.
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            # The text mentions "DeprecationWarning", but the call uses the
            # default ``UserWarning`` category, which the default warning
            # filters do not hide.
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            # Translate the legacy argument into the equivalent init_cfg,
            # overriding whatever ``BaseModule.__init__`` stored above.
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    @staticmethod
    def _split_point_feats(points: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Split coordinates and features of input points.

        The first three channels of the last dimension are taken as xyz
        coordinates; any remaining channels are taken as per-point features.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]:

            - xyz: Contiguous coordinates with shape (B, N, 3).
            - features: Contiguous features transposed to shape
              (B, input_feature_dim, N), or None when ``points`` has no
              channels beyond the first three.
        """
        xyz = points[..., 0:3].contiguous()
        if points.size(-1) <= 3:
            # Coordinates only: there are no extra features to return.
            return xyz, None
        # Move the feature channels ahead of the point dimension. transpose
        # returns a strided view, so copy it into contiguous memory.
        features = points[..., 3:].transpose(1, 2).contiguous()
        return xyz, features