Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
dec83ff9
Commit
dec83ff9
authored
Apr 26, 2020
by
wuyuefeng
Committed by
zhangwenwei
Apr 26, 2020
Browse files
add ops about roiaware_pool3d
parent
6d71b439
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
900 additions
and
1 deletion
+900
-1
.isort.cfg
.isort.cfg
+1
-1
mmdet3d/ops/roiaware_pool3d/__init__.py
mmdet3d/ops/roiaware_pool3d/__init__.py
+4
-0
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
+47
-0
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
+104
-0
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cpu.cpp
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cpu.cpp
+68
-0
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
+116
-0
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
+109
-0
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
+323
-0
setup.py
setup.py
+11
-0
tests/test_roiaware_pool3d.py
tests/test_roiaware_pool3d.py
+117
-0
No files found.
.isort.cfg
View file @
dec83ff9
...
@@ -3,6 +3,6 @@ line_length = 79
...
@@ -3,6 +3,6 @@ line_length = 79
multi_line_output = 0
multi_line_output = 0
known_standard_library = setuptools
known_standard_library = setuptools
known_first_party = mmdet,mmdet3d
known_first_party = mmdet,mmdet3d
known_third_party = cv2,mmcv,numba,numpy,nuscenes,pycocotools,pyquaternion,shapely,six,skimage,torch,torchvision
known_third_party = cv2,mmcv,numba,numpy,nuscenes,pycocotools,pyquaternion,
pytest,
shapely,six,skimage,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
default_section = THIRDPARTY
mmdet3d/ops/roiaware_pool3d/__init__.py
0 → 100644
View file @
dec83ff9
from
.points_in_boxes
import
points_in_boxes_cpu
,
points_in_boxes_gpu
from
.roiaware_pool3d
import
RoIAwarePool3d
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
]
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
0 → 100644
View file @
dec83ff9
import
torch
from
.
import
roiaware_pool3d_ext
def
points_in_boxes_gpu
(
points
,
boxes
):
"""
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
"""
assert
boxes
.
shape
[
0
]
==
points
.
shape
[
0
]
assert
boxes
.
shape
[
2
]
==
7
batch_size
,
num_points
,
_
=
points
.
shape
box_idxs_of_pts
=
points
.
new_zeros
((
batch_size
,
num_points
),
dtype
=
torch
.
int
).
fill_
(
-
1
)
roiaware_pool3d_ext
.
points_in_boxes_gpu
(
boxes
.
contiguous
(),
points
.
contiguous
(),
box_idxs_of_pts
)
return
box_idxs_of_pts
def
points_in_boxes_cpu
(
points
,
boxes
):
"""
Args:
points (torch.Tensor): [npoints, 3]
boxes (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
point_indices (torch.Tensor): (N, npoints)
"""
assert
boxes
.
shape
[
1
]
==
7
assert
points
.
shape
[
1
]
==
3
point_indices
=
points
.
new_zeros
((
boxes
.
shape
[
0
],
points
.
shape
[
0
]),
dtype
=
torch
.
int
)
roiaware_pool3d_ext
.
points_in_boxes_cpu
(
boxes
.
float
().
contiguous
(),
points
.
float
().
contiguous
(),
point_indices
)
return
point_indices
mmdet3d/ops/roiaware_pool3d/roiaware_pool3d.py
0 → 100644
View file @
dec83ff9
import
mmcv
import
torch
import
torch.nn
as
nn
from
torch.autograd
import
Function
from
.
import
roiaware_pool3d_ext
class
RoIAwarePool3d
(
nn
.
Module
):
def
__init__
(
self
,
out_size
,
max_pts_per_voxel
=
128
,
mode
=
'max'
):
super
().
__init__
()
"""
Args:
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
mode (str): 'max' or 'avg'
"""
self
.
out_size
=
out_size
self
.
max_pts_per_voxel
=
max_pts_per_voxel
assert
mode
in
[
'max'
,
'avg'
]
pool_method_map
=
{
'max'
:
0
,
'avg'
:
1
}
self
.
mode
=
pool_method_map
[
mode
]
def
forward
(
self
,
rois
,
pts
,
pts_feature
):
"""
Args:
rois (torch.Tensor): [N, 7],in LiDAR coordinate,
(x, y, z) is the bottom center of rois
pts (torch.Tensor): [npoints, 3]
pts_feature (torch.Tensor): [npoints, C]
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
return
RoIAwarePool3dFunction
.
apply
(
rois
,
pts
,
pts_feature
,
self
.
out_size
,
self
.
max_pts_per_voxel
,
self
.
mode
)
class
RoIAwarePool3dFunction
(
Function
):
@
staticmethod
def
forward
(
ctx
,
rois
,
pts
,
pts_feature
,
out_size
,
max_pts_per_voxel
,
mode
):
"""
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center of rois
pts (torch.Tensor): [npoints, 3]
pts_feature (torch.Tensor): [npoints, C]
out_size (int or tuple): n or [n1, n2, n3]
max_pts_per_voxel (int): m
mode (int): 0 (max pool) or 1 (average pool)
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
if
isinstance
(
out_size
,
int
):
out_x
=
out_y
=
out_z
=
out_size
else
:
assert
len
(
out_size
)
==
3
assert
mmcv
.
is_tuple_of
(
out_size
,
int
)
out_x
,
out_y
,
out_z
=
out_size
num_rois
=
rois
.
shape
[
0
]
num_channels
=
pts_feature
.
shape
[
-
1
]
num_pts
=
pts
.
shape
[
0
]
pooled_features
=
pts_feature
.
new_zeros
(
(
num_rois
,
out_x
,
out_y
,
out_z
,
num_channels
))
argmax
=
pts_feature
.
new_zeros
(
(
num_rois
,
out_x
,
out_y
,
out_z
,
num_channels
),
dtype
=
torch
.
int
)
pts_idx_of_voxels
=
pts_feature
.
new_zeros
(
(
num_rois
,
out_x
,
out_y
,
out_z
,
max_pts_per_voxel
),
dtype
=
torch
.
int
)
roiaware_pool3d_ext
.
forward
(
rois
,
pts
,
pts_feature
,
argmax
,
pts_idx_of_voxels
,
pooled_features
,
mode
)
ctx
.
roiaware_pool3d_for_backward
=
(
pts_idx_of_voxels
,
argmax
,
mode
,
num_pts
,
num_channels
)
return
pooled_features
@
staticmethod
def
backward
(
ctx
,
grad_out
):
"""
Args:
grad_out: [N, out_x, out_y, out_z, C]
Returns:
grad_in: [npoints, C]
"""
ret
=
ctx
.
roiaware_pool3d_for_backward
pts_idx_of_voxels
,
argmax
,
mode
,
num_pts
,
num_channels
=
ret
grad_in
=
grad_out
.
new_zeros
((
num_pts
,
num_channels
))
roiaware_pool3d_ext
.
backward
(
pts_idx_of_voxels
,
argmax
,
grad_out
.
contiguous
(),
grad_in
,
mode
)
return
None
,
None
,
grad_in
,
None
,
None
,
None
if
__name__
==
'__main__'
:
pass
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cpu.cpp
0 → 100644
View file @
dec83ff9
//Modified from
//https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
//Points in boxes cpu
//Written by Shaoshuai Shi
//All Rights Reserved 2019.
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <assert.h>
#include <math.h>
#include <stdio.h>
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// #define DEBUG
inline
void
lidar_to_local_coords_cpu
(
float
shift_x
,
float
shift_y
,
float
rz
,
float
&
local_x
,
float
&
local_y
){
// should rotate pi/2 + alpha to translate LiDAR to local
float
rot_angle
=
rz
+
M_PI
/
2
;
float
cosa
=
cos
(
rot_angle
),
sina
=
sin
(
rot_angle
);
local_x
=
shift_x
*
cosa
+
shift_y
*
(
-
sina
);
local_y
=
shift_x
*
sina
+
shift_y
*
cosa
;
}
inline
int
check_pt_in_box3d_cpu
(
const
float
*
pt
,
const
float
*
box3d
,
float
&
local_x
,
float
&
local_y
){
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, w, l, h, rz) in LiDAR coordinate, cz in the bottom center
float
x
=
pt
[
0
],
y
=
pt
[
1
],
z
=
pt
[
2
];
float
cx
=
box3d
[
0
],
cy
=
box3d
[
1
],
cz
=
box3d
[
2
];
float
w
=
box3d
[
3
],
l
=
box3d
[
4
],
h
=
box3d
[
5
],
rz
=
box3d
[
6
];
cz
+=
h
/
2.0
;
// shift to the center since cz in box3d is the bottom center
if
(
fabsf
(
z
-
cz
)
>
h
/
2.0
)
return
0
;
lidar_to_local_coords_cpu
(
x
-
cx
,
y
-
cy
,
rz
,
local_x
,
local_y
);
float
in_flag
=
(
local_x
>
-
l
/
2.0
)
&
(
local_x
<
l
/
2.0
)
&
(
local_y
>
-
w
/
2.0
)
&
(
local_y
<
w
/
2.0
);
return
in_flag
;
}
int
points_in_boxes_cpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
pts_indices_tensor
){
// params boxes: (N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is the bottom center, each box DO NOT overlaps
// params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
// params pts_indices: (N, npoints)
CHECK_CONTIGUOUS
(
boxes_tensor
);
CHECK_CONTIGUOUS
(
pts_tensor
);
CHECK_CONTIGUOUS
(
pts_indices_tensor
);
int
boxes_num
=
boxes_tensor
.
size
(
0
);
int
pts_num
=
pts_tensor
.
size
(
0
);
const
float
*
boxes
=
boxes_tensor
.
data
<
float
>
();
const
float
*
pts
=
pts_tensor
.
data
<
float
>
();
int
*
pts_indices
=
pts_indices_tensor
.
data
<
int
>
();
float
local_x
=
0
,
local_y
=
0
;
for
(
int
i
=
0
;
i
<
boxes_num
;
i
++
){
for
(
int
j
=
0
;
j
<
pts_num
;
j
++
){
int
cur_in_flag
=
check_pt_in_box3d_cpu
(
pts
+
j
*
3
,
boxes
+
i
*
7
,
local_x
,
local_y
);
pts_indices
[
i
*
pts_num
+
j
]
=
cur_in_flag
;
}
}
return
1
;
}
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
0 → 100644
View file @
dec83ff9
//Modified from
//https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
//Points in boxes gpu
//Written by Shaoshuai Shi
//All Rights Reserved 2019.
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <assert.h>
#include <math.h>
#include <stdio.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
// #define DEBUG
__device__
inline
void
lidar_to_local_coords
(
float
shift_x
,
float
shift_y
,
float
rz
,
float
&
local_x
,
float
&
local_y
){
// should rotate pi/2 + alpha to translate LiDAR to local
float
rot_angle
=
rz
+
M_PI
/
2
;
float
cosa
=
cos
(
rot_angle
),
sina
=
sin
(
rot_angle
);
local_x
=
shift_x
*
cosa
+
shift_y
*
(
-
sina
);
local_y
=
shift_x
*
sina
+
shift_y
*
cosa
;
}
__device__
inline
int
check_pt_in_box3d
(
const
float
*
pt
,
const
float
*
box3d
,
float
&
local_x
,
float
&
local_y
){
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, w, l, h, rz) in LiDAR coordinate, cz in the bottom center
float
x
=
pt
[
0
],
y
=
pt
[
1
],
z
=
pt
[
2
];
float
cx
=
box3d
[
0
],
cy
=
box3d
[
1
],
cz
=
box3d
[
2
];
float
w
=
box3d
[
3
],
l
=
box3d
[
4
],
h
=
box3d
[
5
],
rz
=
box3d
[
6
];
cz
+=
h
/
2.0
;
// shift to the center since cz in box3d is the bottom center
if
(
fabsf
(
z
-
cz
)
>
h
/
2.0
)
return
0
;
lidar_to_local_coords
(
x
-
cx
,
y
-
cy
,
rz
,
local_x
,
local_y
);
float
in_flag
=
(
local_x
>
-
l
/
2.0
)
&
(
local_x
<
l
/
2.0
)
&
(
local_y
>
-
w
/
2.0
)
&
(
local_y
<
w
/
2.0
);
return
in_flag
;
}
__global__
void
points_in_boxes_kernel
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
){
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is the bottom center, each box DO NOT overlaps
// params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
// params boxes_idx_of_points: (B, npoints), default -1
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
batch_size
||
pt_idx
>=
pts_num
)
return
;
boxes
+=
bs_idx
*
boxes_num
*
7
;
pts
+=
bs_idx
*
pts_num
*
3
+
pt_idx
*
3
;
box_idx_of_points
+=
bs_idx
*
pts_num
+
pt_idx
;
float
local_x
=
0
,
local_y
=
0
;
int
cur_in_flag
=
0
;
for
(
int
k
=
0
;
k
<
boxes_num
;
k
++
){
cur_in_flag
=
check_pt_in_box3d
(
pts
,
boxes
+
k
*
7
,
local_x
,
local_y
);
if
(
cur_in_flag
){
box_idx_of_points
[
0
]
=
k
;
break
;
}
}
}
void
points_in_boxes_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
){
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is the bottom center, each box DO NOT overlaps
// params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
// params boxes_idx_of_points: (B, npoints), default -1
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
pts_num
,
THREADS_PER_BLOCK
),
batch_size
);
dim3
threads
(
THREADS_PER_BLOCK
);
points_in_boxes_kernel
<<<
blocks
,
threads
>>>
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
){
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is the bottom center, each box DO NOT overlaps
// params pts: (B, npoints, 3) [x, y, z] in LiDAR coordinate
// params boxes_idx_of_points: (B, npoints), default -1
CHECK_INPUT
(
boxes_tensor
);
CHECK_INPUT
(
pts_tensor
);
CHECK_INPUT
(
box_idx_of_points_tensor
);
int
batch_size
=
boxes_tensor
.
size
(
0
);
int
boxes_num
=
boxes_tensor
.
size
(
1
);
int
pts_num
=
pts_tensor
.
size
(
1
);
const
float
*
boxes
=
boxes_tensor
.
data
<
float
>
();
const
float
*
pts
=
pts_tensor
.
data
<
float
>
();
int
*
box_idx_of_points
=
box_idx_of_points_tensor
.
data
<
int
>
();
points_in_boxes_launcher
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
return
1
;
}
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
0 → 100644
View file @
dec83ff9
//Modified from
//https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
//RoI-aware point cloud feature pooling
//Written by Shaoshuai Shi
//All Rights Reserved 2019.
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <assert.h>
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
void
roiaware_pool3d_launcher
(
int
boxes_num
,
int
pts_num
,
int
channels
,
int
max_pts_each_voxel
,
int
out_x
,
int
out_y
,
int
out_z
,
const
float
*
rois
,
const
float
*
pts
,
const
float
*
pts_feature
,
int
*
argmax
,
int
*
pts_idx_of_voxels
,
float
*
pooled_features
,
int
pool_method
);
void
roiaware_pool3d_backward_launcher
(
int
boxes_num
,
int
out_x
,
int
out_y
,
int
out_z
,
int
channels
,
int
max_pts_each_voxel
,
const
int
*
pts_idx_of_voxels
,
const
int
*
argmax
,
const
float
*
grad_out
,
float
*
grad_in
,
int
pool_method
);
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
pooled_features
,
int
pool_method
);
int
roiaware_pool3d_gpu_backward
(
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
argmax
,
at
::
Tensor
grad_out
,
at
::
Tensor
grad_in
,
int
pool_method
);
int
points_in_boxes_cpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
pts_indices_tensor
);
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
pooled_features
,
int
pool_method
){
// params rois: (N, 7) [x, y, z, w, l, h, ry] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
// params pts_feature: (npoints, C)
// params argmax: (N, out_x, out_y, out_z, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params pooled_features: (N, out_x, out_y, out_z, C)
// params pool_method: 0: max_pool 1: avg_pool
CHECK_INPUT
(
rois
);
CHECK_INPUT
(
pts
);
CHECK_INPUT
(
pts_feature
);
CHECK_INPUT
(
argmax
);
CHECK_INPUT
(
pts_idx_of_voxels
);
CHECK_INPUT
(
pooled_features
);
int
boxes_num
=
rois
.
size
(
0
);
int
pts_num
=
pts
.
size
(
0
);
int
channels
=
pts_feature
.
size
(
1
);
int
max_pts_each_voxel
=
pts_idx_of_voxels
.
size
(
4
);
// index 0 is the counter
int
out_x
=
pts_idx_of_voxels
.
size
(
1
);
int
out_y
=
pts_idx_of_voxels
.
size
(
2
);
int
out_z
=
pts_idx_of_voxels
.
size
(
3
);
assert
((
out_x
<
256
)
&&
(
out_y
<
256
)
&&
(
out_z
<
256
));
// we encode index with 8bit
const
float
*
rois_data
=
rois
.
data
<
float
>
();
const
float
*
pts_data
=
pts
.
data
<
float
>
();
const
float
*
pts_feature_data
=
pts_feature
.
data
<
float
>
();
int
*
argmax_data
=
argmax
.
data
<
int
>
();
int
*
pts_idx_of_voxels_data
=
pts_idx_of_voxels
.
data
<
int
>
();
float
*
pooled_features_data
=
pooled_features
.
data
<
float
>
();
roiaware_pool3d_launcher
(
boxes_num
,
pts_num
,
channels
,
max_pts_each_voxel
,
out_x
,
out_y
,
out_z
,
rois_data
,
pts_data
,
pts_feature_data
,
argmax_data
,
pts_idx_of_voxels_data
,
pooled_features_data
,
pool_method
);
return
1
;
}
int
roiaware_pool3d_gpu_backward
(
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
argmax
,
at
::
Tensor
grad_out
,
at
::
Tensor
grad_in
,
int
pool_method
){
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool 1: avg_pool
CHECK_INPUT
(
pts_idx_of_voxels
);
CHECK_INPUT
(
argmax
);
CHECK_INPUT
(
grad_out
);
CHECK_INPUT
(
grad_in
);
int
boxes_num
=
pts_idx_of_voxels
.
size
(
0
);
int
out_x
=
pts_idx_of_voxels
.
size
(
1
);
int
out_y
=
pts_idx_of_voxels
.
size
(
2
);
int
out_z
=
pts_idx_of_voxels
.
size
(
3
);
int
max_pts_each_voxel
=
pts_idx_of_voxels
.
size
(
4
);
// index 0 is the counter
int
channels
=
grad_out
.
size
(
4
);
const
int
*
pts_idx_of_voxels_data
=
pts_idx_of_voxels
.
data
<
int
>
();
const
int
*
argmax_data
=
argmax
.
data
<
int
>
();
const
float
*
grad_out_data
=
grad_out
.
data
<
float
>
();
float
*
grad_in_data
=
grad_in
.
data
<
float
>
();
roiaware_pool3d_backward_launcher
(
boxes_num
,
out_x
,
out_y
,
out_z
,
channels
,
max_pts_each_voxel
,
pts_idx_of_voxels_data
,
argmax_data
,
grad_out_data
,
grad_in_data
,
pool_method
);
return
1
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
roiaware_pool3d_gpu
,
"roiaware pool3d forward (CUDA)"
);
m
.
def
(
"backward"
,
&
roiaware_pool3d_gpu_backward
,
"roiaware pool3d backward (CUDA)"
);
m
.
def
(
"points_in_boxes_gpu"
,
&
points_in_boxes_gpu
,
"points_in_boxes_gpu forward (CUDA)"
);
m
.
def
(
"points_in_boxes_cpu"
,
&
points_in_boxes_cpu
,
"points_in_boxes_cpu forward (CPU)"
);
}
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
0 → 100644
View file @
dec83ff9
//Modified from
//https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
//RoI-aware point cloud feature pooling
//Written by Shaoshuai Shi
//All Rights Reserved 2019.
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <assert.h>
#include <math.h>
#include <stdio.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
// #define DEBUG
__device__
inline
void
lidar_to_local_coords
(
float
shift_x
,
float
shift_y
,
float
rz
,
float
&
local_x
,
float
&
local_y
){
// should rotate pi/2 + alpha to translate LiDAR to local
float
rot_angle
=
rz
+
M_PI
/
2
;
float
cosa
=
cos
(
rot_angle
),
sina
=
sin
(
rot_angle
);
local_x
=
shift_x
*
cosa
+
shift_y
*
(
-
sina
);
local_y
=
shift_x
*
sina
+
shift_y
*
cosa
;
}
__device__
inline
int
check_pt_in_box3d
(
const
float
*
pt
,
const
float
*
box3d
,
float
&
local_x
,
float
&
local_y
){
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, w, l, h, rz) in LiDAR coordinate, cz in the bottom center
float
x
=
pt
[
0
],
y
=
pt
[
1
],
z
=
pt
[
2
];
float
cx
=
box3d
[
0
],
cy
=
box3d
[
1
],
cz
=
box3d
[
2
];
float
w
=
box3d
[
3
],
l
=
box3d
[
4
],
h
=
box3d
[
5
],
rz
=
box3d
[
6
];
cz
+=
h
/
2.0
;
// shift to the center since cz in box3d is the bottom center
if
(
fabsf
(
z
-
cz
)
>
h
/
2.0
)
return
0
;
lidar_to_local_coords
(
x
-
cx
,
y
-
cy
,
rz
,
local_x
,
local_y
);
float
in_flag
=
(
local_x
>
-
l
/
2.0
)
&
(
local_x
<
l
/
2.0
)
&
(
local_y
>
-
w
/
2.0
)
&
(
local_y
<
w
/
2.0
);
return
in_flag
;
}
__global__
void
generate_pts_mask_for_box3d
(
int
boxes_num
,
int
pts_num
,
int
out_x
,
int
out_y
,
int
out_z
,
const
float
*
rois
,
const
float
*
pts
,
int
*
pts_mask
){
// params rois: (N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z]
// params pts_mask: (N, npoints): -1 means point doesnot in this box, otherwise: encode (x_idxs, y_idxs, z_idxs) by binary bit
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
box_idx
=
blockIdx
.
y
;
if
(
pt_idx
>=
pts_num
||
box_idx
>=
boxes_num
)
return
;
pts
+=
pt_idx
*
3
;
rois
+=
box_idx
*
7
;
pts_mask
+=
box_idx
*
pts_num
+
pt_idx
;
float
local_x
=
0
,
local_y
=
0
;
int
cur_in_flag
=
check_pt_in_box3d
(
pts
,
rois
,
local_x
,
local_y
);
pts_mask
[
0
]
=
-
1
;
if
(
cur_in_flag
>
0
){
float
local_z
=
pts
[
2
]
-
rois
[
2
];
float
w
=
rois
[
3
],
l
=
rois
[
4
],
h
=
rois
[
5
];
float
x_res
=
l
/
out_x
;
float
y_res
=
w
/
out_y
;
float
z_res
=
h
/
out_z
;
unsigned
int
x_idx
=
int
((
local_x
+
l
/
2
)
/
x_res
);
unsigned
int
y_idx
=
int
((
local_y
+
w
/
2
)
/
y_res
);
unsigned
int
z_idx
=
int
(
local_z
/
z_res
);
x_idx
=
min
(
max
(
x_idx
,
0
),
out_x
-
1
);
y_idx
=
min
(
max
(
y_idx
,
0
),
out_y
-
1
);
z_idx
=
min
(
max
(
z_idx
,
0
),
out_z
-
1
);
unsigned
int
idx_encoding
=
(
x_idx
<<
16
)
+
(
y_idx
<<
8
)
+
z_idx
;
#ifdef DEBUG
printf
(
"mask: pts_%d(%.3f, %.3f, %.3f), local(%.3f, %.3f, %.3f), idx(%d, %d, %d), res(%.3f, %.3f, %.3f), idx_encoding=%x
\n
"
,
pt_idx
,
pts
[
0
],
pts
[
1
],
pts
[
2
],
local_x
,
local_y
,
local_z
,
x_idx
,
y_idx
,
z_idx
,
x_res
,
y_res
,
z_res
,
idx_encoding
);
#endif
pts_mask
[
0
]
=
idx_encoding
;
}
}
__global__
void
collect_inside_pts_for_box3d
(
int
boxes_num
,
int
pts_num
,
int
max_pts_each_voxel
,
int
out_x
,
int
out_y
,
int
out_z
,
const
int
*
pts_mask
,
int
*
pts_idx_of_voxels
){
// params pts_mask: (N, npoints) 0 or 1
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
int
box_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
box_idx
>=
boxes_num
)
return
;
int
max_num_pts
=
max_pts_each_voxel
-
1
;
// index 0 is the counter
pts_idx_of_voxels
+=
box_idx
*
out_x
*
out_y
*
out_z
*
max_pts_each_voxel
;
for
(
int
k
=
0
;
k
<
pts_num
;
k
++
){
if
(
pts_mask
[
box_idx
*
pts_num
+
k
]
!=
-
1
){
unsigned
int
idx_encoding
=
pts_mask
[
box_idx
*
pts_num
+
k
];
unsigned
int
x_idx
=
(
idx_encoding
>>
16
)
&
0xFF
;
unsigned
int
y_idx
=
(
idx_encoding
>>
8
)
&
0xFF
;
unsigned
int
z_idx
=
idx_encoding
&
0xFF
;
unsigned
int
base_offset
=
x_idx
*
out_y
*
out_z
*
max_pts_each_voxel
+
y_idx
*
out_z
*
max_pts_each_voxel
+
z_idx
*
max_pts_each_voxel
;
unsigned
int
cnt
=
pts_idx_of_voxels
[
base_offset
];
if
(
cnt
<
max_num_pts
){
pts_idx_of_voxels
[
base_offset
+
cnt
+
1
]
=
k
;
pts_idx_of_voxels
[
base_offset
]
++
;
}
#ifdef DEBUG
printf
(
"collect: pts_%d, idx(%d, %d, %d), idx_encoding=%x
\n
"
,
k
,
x_idx
,
y_idx
,
z_idx
,
idx_encoding
);
#endif
}
}
}
__global__
void
roiaware_maxpool3d
(
int
boxes_num
,
int
pts_num
,
int
channels
,
int
max_pts_each_voxel
,
int
out_x
,
int
out_y
,
int
out_z
,
const
float
*
pts_feature
,
const
int
*
pts_idx_of_voxels
,
float
*
pooled_features
,
int
*
argmax
){
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), index 0 is the counter
// params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C)
int
box_idx
=
blockIdx
.
z
;
int
channel_idx
=
blockIdx
.
y
;
int
voxel_idx_flat
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
x_idx
=
voxel_idx_flat
/
(
out_y
*
out_z
);
int
y_idx
=
(
voxel_idx_flat
-
x_idx
*
(
out_y
*
out_z
))
/
out_z
;
int
z_idx
=
voxel_idx_flat
%
out_z
;
if
(
box_idx
>=
boxes_num
||
channel_idx
>=
channels
||
x_idx
>=
out_x
||
y_idx
>=
out_y
||
z_idx
>=
out_z
)
return
;
#ifdef DEBUG
printf
(
"src pts_idx_of_voxels: (%p, ), argmax: %p
\n
"
,
pts_idx_of_voxels
,
argmax
);
#endif
int
offset_base
=
x_idx
*
out_y
*
out_z
+
y_idx
*
out_z
+
z_idx
;
pts_idx_of_voxels
+=
box_idx
*
out_x
*
out_y
*
out_z
*
max_pts_each_voxel
+
offset_base
*
max_pts_each_voxel
;
pooled_features
+=
box_idx
*
out_x
*
out_y
*
out_z
*
channels
+
offset_base
*
channels
+
channel_idx
;
argmax
+=
box_idx
*
out_x
*
out_y
*
out_z
*
channels
+
offset_base
*
channels
+
channel_idx
;
int
argmax_idx
=
-
1
;
float
max_val
=
-
1e50
;
int
total_pts
=
pts_idx_of_voxels
[
0
];
for
(
int
k
=
1
;
k
<=
total_pts
;
k
++
){
if
(
pts_feature
[
pts_idx_of_voxels
[
k
]
*
channels
+
channel_idx
]
>
max_val
){
max_val
=
pts_feature
[
pts_idx_of_voxels
[
k
]
*
channels
+
channel_idx
];
argmax_idx
=
pts_idx_of_voxels
[
k
];
}
}
if
(
argmax_idx
!=
-
1
){
pooled_features
[
0
]
=
max_val
;
}
argmax
[
0
]
=
argmax_idx
;
#ifdef DEBUG
printf
(
"channel_%d idx(%d, %d, %d), argmax_idx=(%d, %.3f), total=%d, after pts_idx: %p, argmax: (%p, %d)
\n
"
,
channel_idx
,
x_idx
,
y_idx
,
z_idx
,
argmax_idx
,
max_val
,
total_pts
,
pts_idx_of_voxels
,
argmax
,
argmax_idx
);
#endif
}
__global__
void
roiaware_avgpool3d
(
int
boxes_num
,
int
pts_num
,
int
channels
,
int
max_pts_each_voxel
,
int
out_x
,
int
out_y
,
int
out_z
,
const
float
*
pts_feature
,
const
int
*
pts_idx_of_voxels
,
float
*
pooled_features
){
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), index 0 is the counter
// params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C)
int
box_idx
=
blockIdx
.
z
;
int
channel_idx
=
blockIdx
.
y
;
int
voxel_idx_flat
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
x_idx
=
voxel_idx_flat
/
(
out_y
*
out_z
);
int
y_idx
=
(
voxel_idx_flat
-
x_idx
*
(
out_y
*
out_z
))
/
out_z
;
int
z_idx
=
voxel_idx_flat
%
out_z
;
if
(
box_idx
>=
boxes_num
||
channel_idx
>=
channels
||
x_idx
>=
out_x
||
y_idx
>=
out_y
||
z_idx
>=
out_z
)
return
;
int
offset_base
=
x_idx
*
out_y
*
out_z
+
y_idx
*
out_z
+
z_idx
;
pts_idx_of_voxels
+=
box_idx
*
out_x
*
out_y
*
out_z
*
max_pts_each_voxel
+
offset_base
*
max_pts_each_voxel
;
pooled_features
+=
box_idx
*
out_x
*
out_y
*
out_z
*
channels
+
offset_base
*
channels
+
channel_idx
;
float
sum_val
=
0
;
int
total_pts
=
pts_idx_of_voxels
[
0
];
for
(
int
k
=
1
;
k
<=
total_pts
;
k
++
){
sum_val
+=
pts_feature
[
pts_idx_of_voxels
[
k
]
*
channels
+
channel_idx
];
}
if
(
total_pts
>
0
){
pooled_features
[
0
]
=
sum_val
/
total_pts
;
}
}
void
roiaware_pool3d_launcher
(
int
boxes_num
,
int
pts_num
,
int
channels
,
int
max_pts_each_voxel
,
int
out_x
,
int
out_y
,
int
out_z
,
const
float
*
rois
,
const
float
*
pts
,
const
float
*
pts_feature
,
int
*
argmax
,
int
*
pts_idx_of_voxels
,
float
*
pooled_features
,
int
pool_method
){
// params rois: (N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate
// params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
// params pts_feature: (npoints, C)
// params argmax: (N, out_x, out_y, out_z, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params pooled_features: (N, out_x, out_y, out_z, C)
// params pool_method: 0: max_pool 1: avg_pool
int
*
pts_mask
=
NULL
;
cudaMalloc
(
&
pts_mask
,
boxes_num
*
pts_num
*
sizeof
(
int
));
// (N, M)
cudaMemset
(
pts_mask
,
-
1
,
boxes_num
*
pts_num
*
sizeof
(
int
));
dim3
blocks_mask
(
DIVUP
(
pts_num
,
THREADS_PER_BLOCK
),
boxes_num
);
dim3
threads
(
THREADS_PER_BLOCK
);
generate_pts_mask_for_box3d
<<<
blocks_mask
,
threads
>>>
(
boxes_num
,
pts_num
,
out_x
,
out_y
,
out_z
,
rois
,
pts
,
pts_mask
);
// TODO: Merge the collect and pool functions, SS
dim3
blocks_collect
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK
));
collect_inside_pts_for_box3d
<<<
blocks_collect
,
threads
>>>
(
boxes_num
,
pts_num
,
max_pts_each_voxel
,
out_x
,
out_y
,
out_z
,
pts_mask
,
pts_idx_of_voxels
);
dim3
blocks_pool
(
DIVUP
(
out_x
*
out_y
*
out_z
,
THREADS_PER_BLOCK
),
channels
,
boxes_num
);
if
(
pool_method
==
0
){
roiaware_maxpool3d
<<<
blocks_pool
,
threads
>>>
(
boxes_num
,
pts_num
,
channels
,
max_pts_each_voxel
,
out_x
,
out_y
,
out_z
,
pts_feature
,
pts_idx_of_voxels
,
pooled_features
,
argmax
);
}
else
if
(
pool_method
==
1
){
roiaware_avgpool3d
<<<
blocks_pool
,
threads
>>>
(
boxes_num
,
pts_num
,
channels
,
max_pts_each_voxel
,
out_x
,
out_y
,
out_z
,
pts_feature
,
pts_idx_of_voxels
,
pooled_features
);
}
cudaFree
(
pts_mask
);
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
__global__
void
roiaware_maxpool3d_backward
(
int
boxes_num
,
int
channels
,
int
out_x
,
int
out_y
,
int
out_z
,
const
int
*
argmax
,
const
float
*
grad_out
,
float
*
grad_in
){
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int
box_idx
=
blockIdx
.
z
;
int
channel_idx
=
blockIdx
.
y
;
int
voxel_idx_flat
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
x_idx
=
voxel_idx_flat
/
(
out_y
*
out_z
);
int
y_idx
=
(
voxel_idx_flat
-
x_idx
*
(
out_y
*
out_z
))
/
out_z
;
int
z_idx
=
voxel_idx_flat
%
out_z
;
if
(
box_idx
>=
boxes_num
||
channel_idx
>=
channels
||
x_idx
>=
out_x
||
y_idx
>=
out_y
||
z_idx
>=
out_z
)
return
;
int
offset_base
=
x_idx
*
out_y
*
out_z
+
y_idx
*
out_z
+
z_idx
;
argmax
+=
box_idx
*
out_x
*
out_y
*
out_z
*
channels
+
offset_base
*
channels
+
channel_idx
;
grad_out
+=
box_idx
*
out_x
*
out_y
*
out_z
*
channels
+
offset_base
*
channels
+
channel_idx
;
if
(
argmax
[
0
]
==
-
1
)
return
;
atomicAdd
(
grad_in
+
argmax
[
0
]
*
channels
+
channel_idx
,
grad_out
[
0
]
*
1
);
}
__global__
void
roiaware_avgpool3d_backward
(
int
boxes_num
,
int
channels
,
int
out_x
,
int
out_y
,
int
out_z
,
int
max_pts_each_voxel
,
const
int
*
pts_idx_of_voxels
,
const
float
*
grad_out
,
float
*
grad_in
){
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int
box_idx
=
blockIdx
.
z
;
int
channel_idx
=
blockIdx
.
y
;
int
voxel_idx_flat
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
x_idx
=
voxel_idx_flat
/
(
out_y
*
out_z
);
int
y_idx
=
(
voxel_idx_flat
-
x_idx
*
(
out_y
*
out_z
))
/
out_z
;
int
z_idx
=
voxel_idx_flat
%
out_z
;
if
(
box_idx
>=
boxes_num
||
channel_idx
>=
channels
||
x_idx
>=
out_x
||
y_idx
>=
out_y
||
z_idx
>=
out_z
)
return
;
int
offset_base
=
x_idx
*
out_y
*
out_z
+
y_idx
*
out_z
+
z_idx
;
pts_idx_of_voxels
+=
box_idx
*
out_x
*
out_y
*
out_z
*
max_pts_each_voxel
+
offset_base
*
max_pts_each_voxel
;
grad_out
+=
box_idx
*
out_x
*
out_y
*
out_z
*
channels
+
offset_base
*
channels
+
channel_idx
;
int
total_pts
=
pts_idx_of_voxels
[
0
];
float
cur_grad
=
1
/
fmaxf
(
float
(
total_pts
),
1.0
);
for
(
int
k
=
1
;
k
<=
total_pts
;
k
++
){
atomicAdd
(
grad_in
+
pts_idx_of_voxels
[
k
]
*
channels
+
channel_idx
,
grad_out
[
0
]
*
cur_grad
);
}
}
void
roiaware_pool3d_backward_launcher
(
int
boxes_num
,
int
out_x
,
int
out_y
,
int
out_z
,
int
channels
,
int
max_pts_each_voxel
,
const
int
*
pts_idx_of_voxels
,
const
int
*
argmax
,
const
float
*
grad_out
,
float
*
grad_in
,
int
pool_method
){
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool, 1: avg_pool
dim3
blocks
(
DIVUP
(
out_x
*
out_y
*
out_z
,
THREADS_PER_BLOCK
),
channels
,
boxes_num
);
dim3
threads
(
THREADS_PER_BLOCK
);
if
(
pool_method
==
0
){
roiaware_maxpool3d_backward
<<<
blocks
,
threads
>>>
(
boxes_num
,
channels
,
out_x
,
out_y
,
out_z
,
argmax
,
grad_out
,
grad_in
);
}
else
if
(
pool_method
==
1
){
roiaware_avgpool3d_backward
<<<
blocks
,
threads
>>>
(
boxes_num
,
channels
,
out_x
,
out_y
,
out_z
,
max_pts_each_voxel
,
pts_idx_of_voxels
,
grad_out
,
grad_in
);
}
}
setup.py
View file @
dec83ff9
...
@@ -258,6 +258,17 @@ if __name__ == '__main__':
...
@@ -258,6 +258,17 @@ if __name__ == '__main__':
'src/voxelization_cpu.cpp'
,
'src/voxelization_cpu.cpp'
,
'src/voxelization_cuda.cu'
,
'src/voxelization_cuda.cu'
,
]),
]),
make_cuda_ext
(
name
=
'roiaware_pool3d_ext'
,
module
=
'mmdet3d.ops.roiaware_pool3d'
,
sources
=
[
'src/roiaware_pool3d.cpp'
,
'src/points_in_boxes_cpu.cpp'
,
],
sources_cuda
=
[
'src/roiaware_pool3d_kernel.cu'
,
'src/points_in_boxes_cuda.cu'
,
]),
],
],
cmdclass
=
{
'build_ext'
:
BuildExtension
},
cmdclass
=
{
'build_ext'
:
BuildExtension
},
zip_safe
=
False
)
zip_safe
=
False
)
tests/test_roiaware_pool3d.py
0 → 100644
View file @
dec83ff9
import
pytest
import
torch
from
mmdet3d.ops.roiaware_pool3d
import
(
RoIAwarePool3d
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
def
test_RoIAwarePool3d
():
if
not
torch
.
cuda
.
is_available
(
):
# RoIAwarePool3d only support gpu version currently.
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
roiaware_pool3d_max
=
RoIAwarePool3d
(
out_size
=
4
,
max_pts_per_voxel
=
128
,
mode
=
'max'
)
roiaware_pool3d_avg
=
RoIAwarePool3d
(
out_size
=
4
,
max_pts_per_voxel
=
128
,
mode
=
'avg'
)
rois
=
torch
.
tensor
(
[[
1.0
,
2.0
,
3.0
,
4.0
,
5.0
,
6.0
,
0.3
],
[
-
10.0
,
23.0
,
16.0
,
10
,
20
,
20
,
0.5
]],
dtype
=
torch
.
float32
).
cuda
(
)
# boxes (m, 7) with bottom center in lidar coordinate
pts
=
torch
.
tensor
(
[
[
1
,
2
,
3.3
],
[
1.2
,
2.5
,
3.0
],
[
0.8
,
2.1
,
3.5
],
[
1.6
,
2.6
,
3.6
],
[
0.8
,
1.2
,
3.9
],
[
-
9.2
,
21.0
,
18.2
],
[
3.8
,
7.9
,
6.3
],
[
4.7
,
3.5
,
-
12.2
],
[
3.8
,
7.6
,
-
2
],
[
-
10.6
,
-
12.9
,
-
20
],
[
-
16
,
-
18
,
9
],
[
-
21.3
,
-
52
,
-
5
],
[
0
,
0
,
0
],
[
6
,
7
,
8
],
[
-
2
,
-
3
,
-
4
],
],
dtype
=
torch
.
float32
).
cuda
()
# points (n, 3) in lidar coordinate
pts_feature
=
pts
.
clone
()
pooled_features_max
=
roiaware_pool3d_max
(
rois
=
rois
,
pts
=
pts
,
pts_feature
=
pts_feature
)
assert
pooled_features_max
.
shape
==
torch
.
Size
([
2
,
4
,
4
,
4
,
3
])
assert
torch
.
allclose
(
pooled_features_max
.
sum
(),
torch
.
tensor
(
51.100
).
cuda
(),
1e-3
)
pooled_features_avg
=
roiaware_pool3d_avg
(
rois
=
rois
,
pts
=
pts
,
pts_feature
=
pts_feature
)
assert
pooled_features_avg
.
shape
==
torch
.
Size
([
2
,
4
,
4
,
4
,
3
])
assert
torch
.
allclose
(
pooled_features_avg
.
sum
(),
torch
.
tensor
(
49.750
).
cuda
(),
1e-3
)
def
test_points_in_boxes_gpu
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
boxes
=
torch
.
tensor
(
[[[
1.0
,
2.0
,
3.0
,
4.0
,
5.0
,
6.0
,
0.3
]],
[[
-
10.0
,
23.0
,
16.0
,
10
,
20
,
20
,
0.5
]]],
dtype
=
torch
.
float32
).
cuda
(
)
# boxes (b, t, 7) with bottom center in lidar coordinate
pts
=
torch
.
tensor
(
[[[
1
,
2
,
3.3
],
[
1.2
,
2.5
,
3.0
],
[
0.8
,
2.1
,
3.5
],
[
1.6
,
2.6
,
3.6
],
[
0.8
,
1.2
,
3.9
],
[
-
9.2
,
21.0
,
18.2
],
[
3.8
,
7.9
,
6.3
],
[
4.7
,
3.5
,
-
12.2
]],
[[
3.8
,
7.6
,
-
2
],
[
-
10.6
,
-
12.9
,
-
20
],
[
-
16
,
-
18
,
9
],
[
-
21.3
,
-
52
,
-
5
],
[
0
,
0
,
0
],
[
6
,
7
,
8
],
[
-
2
,
-
3
,
-
4
],
[
6
,
4
,
9
]]],
dtype
=
torch
.
float32
).
cuda
()
# points (b, m, 3) in lidar coordinate
point_indices
=
points_in_boxes_gpu
(
points
=
pts
,
boxes
=
boxes
)
expected_point_indices
=
torch
.
tensor
(
[[
0
,
0
,
0
,
0
,
0
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]],
dtype
=
torch
.
int32
).
cuda
()
assert
point_indices
.
shape
==
torch
.
Size
([
2
,
8
])
assert
(
point_indices
==
expected_point_indices
).
all
()
def
test_points_in_boxes_cpu
():
boxes
=
torch
.
tensor
(
[[
1.0
,
2.0
,
3.0
,
4.0
,
5.0
,
6.0
,
0.3
],
[
-
10.0
,
23.0
,
16.0
,
10
,
20
,
20
,
0.5
]],
dtype
=
torch
.
float32
)
# boxes (m, 7) with bottom center in lidar coordinate
pts
=
torch
.
tensor
(
[
[
1
,
2
,
3.3
],
[
1.2
,
2.5
,
3.0
],
[
0.8
,
2.1
,
3.5
],
[
1.6
,
2.6
,
3.6
],
[
0.8
,
1.2
,
3.9
],
[
-
9.2
,
21.0
,
18.2
],
[
3.8
,
7.9
,
6.3
],
[
4.7
,
3.5
,
-
12.2
],
[
3.8
,
7.6
,
-
2
],
[
-
10.6
,
-
12.9
,
-
20
],
[
-
16
,
-
18
,
9
],
[
-
21.3
,
-
52
,
-
5
],
[
0
,
0
,
0
],
[
6
,
7
,
8
],
[
-
2
,
-
3
,
-
4
],
],
dtype
=
torch
.
float32
)
# points (n, 3) in lidar coordinate
point_indices
=
points_in_boxes_cpu
(
points
=
pts
,
boxes
=
boxes
)
expected_point_indices
=
torch
.
tensor
(
[[
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]],
dtype
=
torch
.
int32
)
assert
point_indices
.
shape
==
torch
.
Size
([
2
,
15
])
assert
(
point_indices
==
expected_point_indices
).
all
()
if
__name__
==
'__main__'
:
test_points_in_boxes_cpu
()
test_points_in_boxes_gpu
()
test_RoIAwarePool3d
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment