Commit 8f75dd3b authored by wuyuefeng's avatar wuyuefeng
Browse files

add docstring for roiaware pool3d

parent 7a872356
...@@ -4,12 +4,14 @@ from . import roiaware_pool3d_ext ...@@ -4,12 +4,14 @@ from . import roiaware_pool3d_ext
def points_in_boxes_gpu(points, boxes): def points_in_boxes_gpu(points, boxes):
""" """find points in boxes (CUDA)
Args: Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7], boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate, num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center (x, y, z) is the bottom center
Returns: Returns:
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1 box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
""" """
...@@ -27,11 +29,13 @@ def points_in_boxes_gpu(points, boxes): ...@@ -27,11 +29,13 @@ def points_in_boxes_gpu(points, boxes):
def points_in_boxes_cpu(points, boxes): def points_in_boxes_cpu(points, boxes):
""" """find points in boxes (CPU)
Args: Args:
points (torch.Tensor): [npoints, 3] points (torch.Tensor): [npoints, 3]
boxes (torch.Tensor): [N, 7], in LiDAR coordinate, boxes (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center (x, y, z) is the bottom center
Returns: Returns:
point_indices (torch.Tensor): (N, npoints) point_indices (torch.Tensor): (N, npoints)
""" """
......
...@@ -10,7 +10,8 @@ class RoIAwarePool3d(nn.Module): ...@@ -10,7 +10,8 @@ class RoIAwarePool3d(nn.Module):
def __init__(self, out_size, max_pts_per_voxel=128, mode='max'): def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
super().__init__() super().__init__()
""" """RoIAwarePool3d module
Args: Args:
out_size (int or tuple): n or [n1, n2, n3] out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m max_pts_per_voxel (int): m
...@@ -23,12 +24,14 @@ class RoIAwarePool3d(nn.Module): ...@@ -23,12 +24,14 @@ class RoIAwarePool3d(nn.Module):
self.mode = pool_method_map[mode] self.mode = pool_method_map[mode]
def forward(self, rois, pts, pts_feature): def forward(self, rois, pts, pts_feature):
""" """RoIAwarePool3d module forward
Args: Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate, rois (torch.Tensor): [N, 7],in LiDAR coordinate,
(x, y, z) is the bottom center of rois (x, y, z) is the bottom center of rois
pts (torch.Tensor): [npoints, 3] pts (torch.Tensor): [npoints, 3]
pts_feature (torch.Tensor): [npoints, C] pts_feature (torch.Tensor): [npoints, C]
Returns: Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
""" """
...@@ -43,7 +46,8 @@ class RoIAwarePool3dFunction(Function): ...@@ -43,7 +46,8 @@ class RoIAwarePool3dFunction(Function):
@staticmethod @staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel, def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode): mode):
""" """RoIAwarePool3d function forward
Args: Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate, rois (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center of rois (x, y, z) is the bottom center of rois
...@@ -52,6 +56,7 @@ class RoIAwarePool3dFunction(Function): ...@@ -52,6 +56,7 @@ class RoIAwarePool3dFunction(Function):
out_size (int or tuple): n or [n1, n2, n3] out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m max_pts_per_voxel (int): m
mode (int): 0 (max pool) or 1 (average pool) mode (int): 0 (max pool) or 1 (average pool)
Returns: Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
""" """
...@@ -84,11 +89,12 @@ class RoIAwarePool3dFunction(Function): ...@@ -84,11 +89,12 @@ class RoIAwarePool3dFunction(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
""" """RoIAwarePool3d function forward
Args: Args:
grad_out: [N, out_x, out_y, out_z, C] grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]
Returns: Returns:
grad_in: [npoints, C] grad_in (torch.Tensor): [npoints, C]
""" """
ret = ctx.roiaware_pool3d_for_backward ret = ctx.roiaware_pool3d_for_backward
pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment