# roiaware_pool3d.py
import mmcv
import torch
import torch.nn as nn
from torch.autograd import Function

from . import roiaware_pool3d_ext


class RoIAwarePool3d(nn.Module):
    """RoI-aware 3D pooling module.

    Thin ``nn.Module`` wrapper that stores the pooling configuration and
    forwards to :class:`RoIAwarePool3dFunction`.
    """

    def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
        """Build the module.

        Args:
            out_size (int or tuple[int] or list[int]): Output grid size per
                RoI: a single ``n`` used for all three axes, or
                ``[n1, n2, n3]``.
            max_pts_per_voxel (int): Size of the per-voxel point-index
                buffer allocated by the pooling function. Default: 128.
            mode (str): Pooling method, ``'max'`` or ``'avg'``.
                Default: 'max'.
        """
        super().__init__()
        self.out_size = out_size
        self.max_pts_per_voxel = max_pts_per_voxel
        assert mode in ['max', 'avg'], \
            "mode must be 'max' or 'avg', got {!r}".format(mode)
        # RoIAwarePool3dFunction expects an integer pooling code
        # (0 = max pool, 1 = average pool) rather than the string name.
        pool_method_map = {'max': 0, 'avg': 1}
        self.mode = pool_method_map[mode]

    def forward(self, rois, pts, pts_feature):
        """Pool point features inside each RoI.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3].
            pts_feature (torch.Tensor): [npoints, C].

        Returns:
            torch.Tensor: Pooled features of shape
                [N, out_x, out_y, out_z, C].
        """
        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
                                            self.out_size,
                                            self.max_pts_per_voxel, self.mode)


class RoIAwarePool3dFunction(Function):
    """Autograd function for RoI-aware 3D pooling.

    Allocates the output and index buffers, delegates the actual pooling to
    ``roiaware_pool3d_ext.forward`` / ``roiaware_pool3d_ext.backward``, and
    carries the index buffers from the forward pass to the backward pass.
    """

    @staticmethod
    def _parse_out_size(out_size):
        """Normalize ``out_size`` to an ``(out_x, out_y, out_z)`` tuple.

        Args:
            out_size (int or tuple[int] or list[int]): Either a single ``n``
                (used for all three axes) or a sequence ``[n1, n2, n3]``.

        Returns:
            tuple[int]: ``(out_x, out_y, out_z)``.

        Raises:
            AssertionError: If a sequence does not hold exactly three ints.
        """
        if isinstance(out_size, int):
            return out_size, out_size, out_size
        assert len(out_size) == 3, \
            'out_size must have 3 elements, got {!r}'.format(out_size)
        # Accept lists as well as tuples: the documented form is
        # [n1, n2, n3], which a tuple-only check would reject.
        valid = isinstance(out_size, (list, tuple)) and all(
            isinstance(v, int) for v in out_size)
        assert valid, ('out_size must be an int or a list/tuple of 3 ints, '
                       'got {!r}'.format(out_size))
        out_x, out_y, out_z = out_size
        return out_x, out_y, out_z

    @staticmethod
    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
                mode):
        """RoIAwarePool3d function forward.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3].
            pts_feature (torch.Tensor): [npoints, C].
            out_size (int or tuple[int] or list[int]): n or [n1, n2, n3].
            max_pts_per_voxel (int): Size of the per-voxel point-index
                buffer.
            mode (int): 0 (max pool) or 1 (average pool).

        Returns:
            torch.Tensor: Pooled features of shape
                [N, out_x, out_y, out_z, C], with the dtype and device of
                ``pts_feature``.
        """
        out_x, out_y, out_z = RoIAwarePool3dFunction._parse_out_size(out_size)

        num_rois = rois.shape[0]
        num_channels = pts_feature.shape[-1]
        num_pts = pts.shape[0]
        grid_shape = (num_rois, out_x, out_y, out_z)

        # Buffers are zero-initialized on pts_feature's device and filled in
        # place by the extension call below.
        pooled_features = pts_feature.new_zeros(grid_shape + (num_channels, ))
        # Per-voxel, per-channel int index buffer; presumably records the
        # selected point for max pooling. It is handed on to backward.
        argmax = pts_feature.new_zeros(
            grid_shape + (num_channels, ), dtype=torch.int)
        # Per-voxel buffer of max_pts_per_voxel int entries (presumably point
        # indices falling in each voxel). It is handed on to backward.
        pts_idx_of_voxels = pts_feature.new_zeros(
            grid_shape + (max_pts_per_voxel, ), dtype=torch.int)

        roiaware_pool3d_ext.forward(rois, pts, pts_feature, argmax,
                                    pts_idx_of_voxels, pooled_features, mode)

        # These are intermediate buffers created here, not inputs/outputs of
        # this Function, so they are kept as a plain ctx attribute.
        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
                                            num_pts, num_channels)
        return pooled_features

    @staticmethod
    def backward(ctx, grad_out):
        """RoIAwarePool3d function backward.

        Args:
            grad_out (torch.Tensor): [N, out_x, out_y, out_z, C].

        Returns:
            tuple: One gradient per forward input. Only ``pts_feature``
                receives one (``grad_in`` of shape [npoints, C]); all other
                entries are None.
        """
        (pts_idx_of_voxels, argmax, mode, num_pts,
         num_channels) = ctx.roiaware_pool3d_for_backward

        grad_in = grad_out.new_zeros((num_pts, num_channels))
        # grad_in is filled in place by the extension. grad_out is made
        # contiguous first (presumably the kernel assumes contiguous memory).
        roiaware_pool3d_ext.backward(pts_idx_of_voxels, argmax,
                                     grad_out.contiguous(), grad_in, mode)

        return None, None, grad_in, None, None, None


if __name__ == '__main__':
    # This module only defines the pooling op; running it directly does
    # nothing.
    ...