# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn as nn
from torch.autograd import Function

from . import roipoint_pool3d_ext


class RoIPointPool3d(nn.Module):
    """Pool a fixed number of sampled points (with features) for each 3D RoI.

    This is a thin ``nn.Module`` wrapper around ``RoIPointPool3dFunction``,
    which calls the ``roipoint_pool3d_ext`` extension.

    Args:
        num_sampled_points (int, optional): Number of points sampled in each
            RoI. Default: 512.
    """

    def __init__(self, num_sampled_points=512):
        super().__init__()
        self.num_sampled_points = num_sampled_points

    def forward(self, points, point_features, boxes3d):
        """Pool points and their features for every box.

        Args:
            points (torch.Tensor): Input points whose shape is (B, N, 3).
            point_features (torch.Tensor): Features of the input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7), laid out as [x, y, z, dx, dy, dz, heading].

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - pooled_features: shape (B, M, num_sampled_points, 3 + C).
                - pooled_empty_flag: shape (B, M).
        """
        return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
                                            self.num_sampled_points)


class RoIPointPool3dFunction(Function):
    """Autograd function wrapping the ``roipoint_pool3d_ext`` forward op.

    Only the forward pass is implemented; ``backward`` raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
        """Allocate the outputs and run the pooling extension.

        Args:
            ctx: Autograd context (unused).
            points (torch.Tensor): Input points whose shape is (B, N, 3).
            point_features (torch.Tensor): Input point features whose shape
                is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).
            num_sampled_points (int, optional): Number of points sampled in
                each RoI. Default: 512.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - pooled_features: shape (B, M, num_sampled_points, 3 + C),
                  with the dtype and device of ``point_features``.
                - pooled_empty_flag: int32 tensor of shape (B, M).

        Raises:
            AssertionError: If the input shapes are inconsistent.
        """
        assert points.dim() == 3 and points.shape[2] == 3, (
            f'points must have shape (B, N, 3), got {tuple(points.shape)}')
        batch_size = points.shape[0]

        # The extension reads features per point, so B and N must agree with
        # ``points``; otherwise the native kernel could index out of range.
        assert point_features.dim() == 3 and \
            point_features.shape[:2] == points.shape[:2], (
                'point_features must have shape (B, N, C) matching points '
                f'{tuple(points.shape)}, got {tuple(point_features.shape)}')
        feature_len = point_features.shape[2]

        assert boxes3d.dim() >= 2 and boxes3d.shape[0] == batch_size, (
            f'boxes3d must have shape (B, M, 7) with B={batch_size}, '
            f'got {tuple(boxes3d.shape)}')
        boxes_num = boxes3d.shape[1]
        # Each box must hold exactly 7 values so that the (B, M, 7) layout
        # the kernel receives lines up with the M used for the outputs.
        assert boxes3d.numel() == batch_size * boxes_num * 7, (
            f'boxes3d must have shape (B, M, 7), got {tuple(boxes3d.shape)}')
        # reshape (rather than view) also accepts non-viewable layouts.
        pooled_boxes3d = boxes3d.reshape(batch_size, boxes_num, 7)

        pooled_features = point_features.new_zeros(
            (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
        pooled_empty_flag = point_features.new_zeros(
            (batch_size, boxes_num)).int()

        # The extension's return value is ignored; it presumably fills the
        # preallocated outputs in place. Inputs are made contiguous since the
        # native op likely assumes a dense layout — TODO confirm.
        roipoint_pool3d_ext.forward(points.contiguous(),
                                    pooled_boxes3d.contiguous(),
                                    point_features.contiguous(),
                                    pooled_features, pooled_empty_flag)

        return pooled_features, pooled_empty_flag

    @staticmethod
    def backward(ctx, grad_out):
        """Backward is not supported for this op."""
        raise NotImplementedError


# NOTE: this module only provides library code; it has no script entry point.