Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
8f75dd3b
"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "bf843c664b8ba0ff49d2921237500c77d82f2d04"
Commit
8f75dd3b
authored
Apr 27, 2020
by
wuyuefeng
Browse files
add docstring for roiaware pool3d
parent
7a872356
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
8 deletions
+18
-8
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
+6
-2
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
+12
-6
No files found.
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
View file @
8f75dd3b
...
@@ -4,12 +4,14 @@ from . import roiaware_pool3d_ext
...
@@ -4,12 +4,14 @@ from . import roiaware_pool3d_ext
def
points_in_boxes_gpu
(
points
,
boxes
):
def
points_in_boxes_gpu
(
points
,
boxes
):
"""
"""find points in boxes (CUDA)
Args:
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7],
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center
(x, y, z) is the bottom center
Returns:
Returns:
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
"""
"""
...
@@ -27,11 +29,13 @@ def points_in_boxes_gpu(points, boxes):
...
@@ -27,11 +29,13 @@ def points_in_boxes_gpu(points, boxes):
def
points_in_boxes_cpu
(
points
,
boxes
):
def
points_in_boxes_cpu
(
points
,
boxes
):
"""
"""find points in boxes (CPU)
Args:
Args:
points (torch.Tensor): [npoints, 3]
points (torch.Tensor): [npoints, 3]
boxes (torch.Tensor): [N, 7], in LiDAR coordinate,
boxes (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center
(x, y, z) is the bottom center
Returns:
Returns:
point_indices (torch.Tensor): (N, npoints)
point_indices (torch.Tensor): (N, npoints)
"""
"""
...
...
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
View file @
8f75dd3b
...
@@ -10,7 +10,8 @@ class RoIAwarePool3d(nn.Module):
...
@@ -10,7 +10,8 @@ class RoIAwarePool3d(nn.Module):
def
__init__
(
self
,
out_size
,
max_pts_per_voxel
=
128
,
mode
=
'max'
):
def
__init__
(
self
,
out_size
,
max_pts_per_voxel
=
128
,
mode
=
'max'
):
super
().
__init__
()
super
().
__init__
()
"""
"""RoIAwarePool3d module
Args:
Args:
out_size (int or tuple): n or [n1, n2, n3]
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
max_pts_per_voxel (int): m
...
@@ -23,12 +24,14 @@ class RoIAwarePool3d(nn.Module):
...
@@ -23,12 +24,14 @@ class RoIAwarePool3d(nn.Module):
self
.
mode
=
pool_method_map
[
mode
]
self
.
mode
=
pool_method_map
[
mode
]
def
forward
(
self
,
rois
,
pts
,
pts_feature
):
def
forward
(
self
,
rois
,
pts
,
pts_feature
):
"""
"""RoIAwarePool3d module forward
Args:
Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
(x, y, z) is the bottom center of rois
(x, y, z) is the bottom center of rois
pts (torch.Tensor): [npoints, 3]
pts (torch.Tensor): [npoints, 3]
pts_feature (torch.Tensor): [npoints, C]
pts_feature (torch.Tensor): [npoints, C]
Returns:
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
"""
...
@@ -43,7 +46,8 @@ class RoIAwarePool3dFunction(Function):
...
@@ -43,7 +46,8 @@ class RoIAwarePool3dFunction(Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
rois
,
pts
,
pts_feature
,
out_size
,
max_pts_per_voxel
,
def
forward
(
ctx
,
rois
,
pts
,
pts_feature
,
out_size
,
max_pts_per_voxel
,
mode
):
mode
):
"""
"""RoIAwarePool3d function forward
Args:
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center of rois
(x, y, z) is the bottom center of rois
...
@@ -52,6 +56,7 @@ class RoIAwarePool3dFunction(Function):
...
@@ -52,6 +56,7 @@ class RoIAwarePool3dFunction(Function):
out_size (int or tuple): n or [n1, n2, n3]
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
max_pts_per_voxel (int): m
mode (int): 0 (max pool) or 1 (average pool)
mode (int): 0 (max pool) or 1 (average pool)
Returns:
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
"""
...
@@ -84,11 +89,12 @@ class RoIAwarePool3dFunction(Function):
...
@@ -84,11 +89,12 @@ class RoIAwarePool3dFunction(Function):
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_out
):
"""
"""RoIAwarePool3d function forward
Args:
Args:
grad_out: [N, out_x, out_y, out_z, C]
grad_out
(torch.Tensor)
: [N, out_x, out_y, out_z, C]
Returns:
Returns:
grad_in: [npoints, C]
grad_in
(torch.Tensor)
: [npoints, C]
"""
"""
ret
=
ctx
.
roiaware_pool3d_for_backward
ret
=
ctx
.
roiaware_pool3d_for_backward
pts_idx_of_voxels
,
argmax
,
mode
,
num_pts
,
num_channels
=
ret
pts_idx_of_voxels
,
argmax
,
mode
,
num_pts
,
num_channels
=
ret
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment