# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
zhangwenwei's avatar
zhangwenwei committed
4
from torch import nn as nn
wuyuefeng's avatar
wuyuefeng committed
5
6
7
8
9
10
11
12
13
from torch.autograd import Function

from . import roiaware_pool3d_ext


class RoIAwarePool3d(nn.Module):
    """RoIAwarePool3d module.

    Stores the pooling configuration and, on ``forward``, delegates the
    actual work to ``RoIAwarePool3dFunction``.

    Args:
        out_size (int or tuple): n or [n1, n2, n3]. An int is used for all
            three output dimensions.
        max_pts_per_voxel (int): m. Defaults to 128.
        mode (str): 'max' or 'avg'. Defaults to 'max'.
    """

    def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
        super().__init__()
        self.out_size = out_size
        self.max_pts_per_voxel = max_pts_per_voxel
        assert mode in ['max', 'avg'], \
            f"mode must be 'max' or 'avg', got {mode!r}"
        # The autograd function takes the pooling method as an integer
        # code (0: max pool, 1: average pool), so translate the name once.
        pool_method_map = {'max': 0, 'avg': 1}
        self.mode = pool_method_map[mode]

    def forward(self, rois, pts, pts_feature):
        """RoIAwarePool3d module forward.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois
            pts (torch.Tensor): [npoints, 3]
            pts_feature (torch.Tensor): [npoints, C]

        Returns:
            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
        """

        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
                                            self.out_size,
                                            self.max_pts_per_voxel, self.mode)


class RoIAwarePool3dFunction(Function):
    """Autograd function wrapping the ``roiaware_pool3d_ext`` extension.

    ``forward`` allocates the output/index buffers and calls the
    extension's forward routine; ``backward`` calls the extension's
    backward routine and returns a gradient for ``pts_feature`` only.
    """

    @staticmethod
    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
                mode):
        """RoIAwarePool3d function forward.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois
            pts (torch.Tensor): [npoints, 3]
            pts_feature (torch.Tensor): [npoints, C]
            out_size (int or tuple): n or [n1, n2, n3]. A non-int value
                must have length 3 and pass ``mmcv.is_tuple_of(..., int)``.
            max_pts_per_voxel (int): m
            mode (int): 0 (max pool) or 1 (average pool)

        Returns:
            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
        """

        if isinstance(out_size, int):
            out_x = out_y = out_z = out_size
        else:
            assert len(out_size) == 3
            assert mmcv.is_tuple_of(out_size, int)
            out_x, out_y, out_z = out_size

        num_rois = rois.shape[0]
        num_channels = pts_feature.shape[-1]
        num_pts = pts.shape[0]

        # All buffers are created with ``pts_feature.new_zeros`` so they
        # share its device; the two index buffers are int32.
        pooled_features = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels))
        argmax = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
        pts_idx_of_voxels = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, max_pts_per_voxel),
            dtype=torch.int)

        # The extension's return value is unused; it is presumably expected
        # to fill ``argmax``, ``pts_idx_of_voxels`` and ``pooled_features``
        # in place -- confirm against the extension source.
        roiaware_pool3d_ext.forward(rois, pts, pts_feature, argmax,
                                    pts_idx_of_voxels, pooled_features, mode)

        # Keep everything the backward pass needs on the context object.
        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
                                            num_pts, num_channels)
        return pooled_features

    @staticmethod
    def backward(ctx, grad_out):
        """RoIAwarePool3d function backward.

        Args:
            grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]

        Returns:
            tuple: One entry per ``forward`` input after ``ctx``. Only the
                third entry, the gradient w.r.t. ``pts_feature``
                ([npoints, C]), is a tensor; the others are ``None``.
        """
        ret = ctx.roiaware_pool3d_for_backward
        pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret

        # Zero-initialised gradient buffer handed to the extension, which is
        # presumably expected to write into it in place (its return value
        # is ignored) -- confirm against the extension source.
        grad_in = grad_out.new_zeros((num_pts, num_channels))
        roiaware_pool3d_ext.backward(pts_idx_of_voxels, argmax,
                                     grad_out.contiguous(), grad_in, mode)

        return None, None, grad_in, None, None, None


# Executing this module as a script performs no action.
if __name__ == '__main__':
    pass