Commit 8f75dd3b authored by wuyuefeng's avatar wuyuefeng
Browse files

add docstring for roiaware pool3d

parent 7a872356
......@@ -4,12 +4,14 @@ from . import roiaware_pool3d_ext
def points_in_boxes_gpu(points, boxes):
"""
"""find points in boxes (CUDA)
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
"""
......@@ -27,11 +29,13 @@ def points_in_boxes_gpu(points, boxes):
def points_in_boxes_cpu(points, boxes):
"""
"""find points in boxes (CPU)
Args:
points (torch.Tensor): [npoints, 3]
boxes (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
point_indices (torch.Tensor): (N, npoints)
"""
......
......@@ -10,7 +10,8 @@ class RoIAwarePool3d(nn.Module):
def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
super().__init__()
"""
"""RoIAwarePool3d module
Args:
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
......@@ -23,12 +24,14 @@ class RoIAwarePool3d(nn.Module):
self.mode = pool_method_map[mode]
def forward(self, rois, pts, pts_feature):
"""
"""RoIAwarePool3d module forward
Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
(x, y, z) is the bottom center of rois
pts (torch.Tensor): [npoints, 3]
pts_feature (torch.Tensor): [npoints, C]
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
......@@ -43,7 +46,8 @@ class RoIAwarePool3dFunction(Function):
@staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode):
"""
"""RoIAwarePool3d function forward
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center of rois
......@@ -52,6 +56,7 @@ class RoIAwarePool3dFunction(Function):
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
mode (int): 0 (max pool) or 1 (average pool)
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
......@@ -84,11 +89,12 @@ class RoIAwarePool3dFunction(Function):
@staticmethod
def backward(ctx, grad_out):
"""
"""RoIAwarePool3d function forward
Args:
grad_out: [N, out_x, out_y, out_z, C]
grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]
Returns:
grad_in: [npoints, C]
grad_in (torch.Tensor): [npoints, C]
"""
ret = ctx.roiaware_pool3d_for_backward
pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment