import mmcv
import torch
from torch import nn as nn
from torch.autograd import Function

from . import roiaware_pool3d_ext


class RoIAwarePool3d(nn.Module):
    """nn.Module wrapper around :class:`RoIAwarePool3dFunction`.

    Stores the pooling configuration (output grid size, per-voxel point
    capacity and pooling mode) and forwards it to the autograd function.
    """

    # Integer codes passed to the extension for each pooling mode.
    _POOL_METHOD_MAP = {'max': 0, 'avg': 1}

    def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
        """RoIAwarePool3d module.

        Args:
            out_size (int or tuple): n or [n1, n2, n3]; an int n is used as
                the size along all three axes.
            max_pts_per_voxel (int): m, the size of the last dimension of the
                per-voxel point-index buffer. Defaults to 128.
            mode (str): 'max' or 'avg'. Defaults to 'max'.
        """
        super().__init__()
        self.out_size = out_size
        self.max_pts_per_voxel = max_pts_per_voxel
        assert mode in self._POOL_METHOD_MAP, \
            "mode must be 'max' or 'avg', got {!r}".format(mode)
        # The extension takes an integer code: 0 = max pool, 1 = avg pool.
        self.mode = self._POOL_METHOD_MAP[mode]

    def forward(self, rois, pts, pts_feature):
        """RoIAwarePool3d module forward.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3]
            pts_feature (torch.Tensor): [npoints, C]

        Returns:
            torch.Tensor: pooled_features of shape
                [N, out_x, out_y, out_z, C].
        """
        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
                                            self.out_size,
                                            self.max_pts_per_voxel, self.mode)


class RoIAwarePool3dFunction(Function):
    """Autograd function calling ``roiaware_pool3d_ext`` forward/backward."""

    @staticmethod
    def _parse_out_size(out_size):
        """Normalize ``out_size`` to an ``(out_x, out_y, out_z)`` tuple.

        Args:
            out_size (int or list or tuple): n, or a length-3 sequence
                [n1, n2, n3] of ints.

        Returns:
            tuple[int, int, int]: Output grid size along x, y and z.
        """
        if isinstance(out_size, int):
            return out_size, out_size, out_size
        # Accept lists as well as tuples: the documented form is
        # [n1, n2, n3], which a tuple-only check would reject.
        assert isinstance(out_size, (list, tuple)) and len(out_size) == 3, \
            'out_size must be an int or a length-3 list/tuple, got {!r}'.format(
                out_size)
        assert all(isinstance(v, int) for v in out_size), \
            'out_size entries must be ints, got {!r}'.format(out_size)
        out_x, out_y, out_z = out_size
        return out_x, out_y, out_z

    @staticmethod
    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
                mode):
        """RoIAwarePool3d function forward.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3]
            pts_feature (torch.Tensor): [npoints, C]
            out_size (int or tuple): n or [n1, n2, n3]
            max_pts_per_voxel (int): m
            mode (int): 0 (max pool) or 1 (average pool)

        Returns:
            torch.Tensor: pooled_features of shape
                [N, out_x, out_y, out_z, C].
        """
        out_x, out_y, out_z = RoIAwarePool3dFunction._parse_out_size(out_size)

        num_rois = rois.shape[0]
        num_channels = pts_feature.shape[-1]
        num_pts = pts.shape[0]

        pooled_features = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels))
        # Integer buffers handed to the extension alongside pooled_features;
        # presumably filled by the kernel and later read by the backward
        # kernel to route gradients (NOTE(review): confirm against the ext).
        argmax = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
        pts_idx_of_voxels = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, max_pts_per_voxel),
            dtype=torch.int)

        roiaware_pool3d_ext.forward(rois, pts, pts_feature, argmax,
                                    pts_idx_of_voxels, pooled_features, mode)

        # Everything backward needs: the two index buffers, the pooling mode
        # and the shape of the gradient w.r.t. pts_feature.
        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
                                            num_pts, num_channels)
        return pooled_features

    @staticmethod
    def backward(ctx, grad_out):
        """RoIAwarePool3d function backward.

        Args:
            grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]

        Returns:
            tuple: One entry per forward input. Only ``pts_feature`` receives
                a gradient, grad_in of shape [npoints, C]; all other entries
                are None.
        """
        ret = ctx.roiaware_pool3d_for_backward
        pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret

        grad_in = grad_out.new_zeros((num_pts, num_channels))
        roiaware_pool3d_ext.backward(pts_idx_of_voxels, argmax,
                                     grad_out.contiguous(), grad_in, mode)

        # Order matches forward(rois, pts, pts_feature, out_size,
        # max_pts_per_voxel, mode).
        return None, None, grad_in, None, None, None


if __name__ == '__main__':
    # Nothing to run when executed directly; the module only defines the op.
    ...