Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
fa39f1c5
Commit
fa39f1c5
authored
Dec 26, 2021
by
Shaoshuai Shi
Browse files
Support sectorized-proposal-centric (SPC) keypoint sampling
parent
8922371e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
222 additions
and
44 deletions
+222
-44
pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py
pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py
+222
-44
No files found.
pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py
View file @
fa39f1c5
import
math
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y):
...
@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y):
return
ans
return
ans
def sample_points_with_roi(rois, points, sample_radius_with_roi, num_max_points_of_part=200000):
    """
    Keep only the points that lie near some ROI box: a point survives when its
    distance to the nearest ROI center is below that ROI's half-diagonal plus
    ``sample_radius_with_roi``. Large clouds are processed in chunks to bound
    the (N, M) distance-matrix memory.

    Args:
        rois: (M, 7 + C) boxes, [x, y, z, dx, dy, dz, heading, ...]
        points: (N, 3)
        sample_radius_with_roi: extra radius added around each ROI
        num_max_points_of_part: chunk size for the pairwise-distance computation

    Returns:
        sampled_points: (N_out, 3) — falls back to the first point when nothing matches
        point_mask: (N,) bool mask over the input points
    """

    def _mask_for_chunk(chunk):
        # (n_chunk, M) center distances; pick each point's nearest ROI
        dist = (chunk[:, None, :] - rois[None, :, 0:3]).norm(dim=-1)
        nearest_dist, nearest_roi = dist.min(dim=-1)
        # half of the box diagonal of the nearest ROI
        half_diag = (rois[nearest_roi, 3:6] / 2).norm(dim=-1)
        return nearest_dist < half_diag + sample_radius_with_roi

    num_points = points.shape[0]
    if num_points < num_max_points_of_part:
        point_mask = _mask_for_chunk(points)
    else:
        chunk_masks = [
            _mask_for_chunk(points[offset:offset + num_max_points_of_part])
            for offset in range(0, num_points, num_max_points_of_part)
        ]
        point_mask = torch.cat(chunk_masks, dim=0)

    if point_mask.sum() == 0:
        # degenerate case: keep one point so downstream code never sees an empty tensor
        sampled_points = points[:1]
    else:
        sampled_points = points[point_mask, :]

    return sampled_points, point_mask
def sector_fps(points, num_sampled_points, num_sectors):
    """
    Farthest point sampling performed independently inside angular (BEV) sectors,
    so that the sampling budget is spread proportionally over all directions.

    Args:
        points: (N, 3)
        num_sampled_points: int, total sampling budget
        num_sectors: int, number of angular sectors around the z-axis

    Returns:
        sampled_points: (N_out, 3)
    """
    sector_size = np.pi * 2 / num_sectors
    # atan2 returns (-pi, pi]; shifting by pi maps angles into (0, 2*pi]
    point_angles = torch.atan2(points[:, 1], points[:, 0]) + np.pi
    # BUGFIX: clamp to num_sectors - 1 (not num_sectors). A point at angle exactly
    # 2*pi floors to index num_sectors, which the loop below never visits, so it
    # would be silently dropped from the sampling pool.
    sector_idx = (point_angles / sector_size).floor().clamp(min=0, max=num_sectors - 1)
    xyz_points_list = []
    xyz_batch_cnt = []
    num_sampled_points_list = []
    for k in range(num_sectors):
        mask = (sector_idx == k)
        cur_num_points = mask.sum().item()
        if cur_num_points > 0:
            xyz_points_list.append(points[mask])
            xyz_batch_cnt.append(cur_num_points)
            # each sector's budget is proportional to its share of the points,
            # but never more than the points it actually contains
            ratio = cur_num_points / points.shape[0]
            num_sampled_points_list.append(
                min(cur_num_points, math.ceil(ratio * num_sampled_points))
            )

    if len(xyz_batch_cnt) == 0:
        # no sector got any points (e.g. empty input slice): fall back to plain FPS
        # over everything rather than crashing downstream
        xyz_points_list.append(points)
        xyz_batch_cnt.append(len(points))
        num_sampled_points_list.append(num_sampled_points)
        print(f'Warning: empty sector points detected in SectorFPS: points.shape={points.shape}')

    xyz = torch.cat(xyz_points_list, dim=0)
    xyz_batch_cnt = torch.tensor(xyz_batch_cnt, device=points.device).int()
    sampled_points_batch_cnt = torch.tensor(num_sampled_points_list, device=points.device).int()

    # stack-style FPS treats each sector as one "batch" entry, so FPS runs
    # independently per sector with its own budget
    sampled_pt_idxs = pointnet2_stack_utils.stack_farthest_point_sample(
        xyz.contiguous(), xyz_batch_cnt, sampled_points_batch_cnt
    ).long()

    sampled_points = xyz[sampled_pt_idxs]

    return sampled_points
class
VoxelSetAbstraction
(
nn
.
Module
):
class
VoxelSetAbstraction
(
nn
.
Module
):
def
__init__
(
self
,
model_cfg
,
voxel_size
,
point_cloud_range
,
num_bev_features
=
None
,
def
__init__
(
self
,
model_cfg
,
voxel_size
,
point_cloud_range
,
num_bev_features
=
None
,
num_rawpoint_features
=
None
,
**
kwargs
):
num_rawpoint_features
=
None
,
**
kwargs
):
...
@@ -100,23 +181,64 @@ class VoxelSetAbstraction(nn.Module):
...
@@ -100,23 +181,64 @@ class VoxelSetAbstraction(nn.Module):
self
.
num_point_features_before_fusion
=
c_in
self
.
num_point_features_before_fusion
=
c_in
def
interpolate_from_bev_features
(
self
,
keypoints
,
bev_features
,
batch_size
,
bev_stride
):
def
interpolate_from_bev_features
(
self
,
keypoints
,
bev_features
,
batch_size
,
bev_stride
):
x_idxs
=
(
keypoints
[:,
:,
0
]
-
self
.
point_cloud_range
[
0
])
/
self
.
voxel_size
[
0
]
"""
y_idxs
=
(
keypoints
[:,
:,
1
]
-
self
.
point_cloud_range
[
1
])
/
self
.
voxel_size
[
1
]
Args:
keypoints: (N1 + N2 + ..., 4)
bev_features: (B, C, H, W)
batch_size:
bev_stride:
Returns:
point_bev_features: (N1 + N2 + ..., C)
"""
x_idxs
=
(
keypoints
[:,
1
]
-
self
.
point_cloud_range
[
0
])
/
self
.
voxel_size
[
0
]
y_idxs
=
(
keypoints
[:,
2
]
-
self
.
point_cloud_range
[
1
])
/
self
.
voxel_size
[
1
]
x_idxs
=
x_idxs
/
bev_stride
x_idxs
=
x_idxs
/
bev_stride
y_idxs
=
y_idxs
/
bev_stride
y_idxs
=
y_idxs
/
bev_stride
point_bev_features_list
=
[]
point_bev_features_list
=
[]
for
k
in
range
(
batch_size
):
for
k
in
range
(
batch_size
):
cur_x_idxs
=
x_idxs
[
k
]
bs_mask
=
(
keypoints
[:,
0
]
==
k
)
cur_y_idxs
=
y_idxs
[
k
]
cur_x_idxs
=
x_idxs
[
bs_mask
]
cur_y_idxs
=
y_idxs
[
bs_mask
]
cur_bev_features
=
bev_features
[
k
].
permute
(
1
,
2
,
0
)
# (H, W, C)
cur_bev_features
=
bev_features
[
k
].
permute
(
1
,
2
,
0
)
# (H, W, C)
point_bev_features
=
bilinear_interpolate_torch
(
cur_bev_features
,
cur_x_idxs
,
cur_y_idxs
)
point_bev_features
=
bilinear_interpolate_torch
(
cur_bev_features
,
cur_x_idxs
,
cur_y_idxs
)
point_bev_features_list
.
append
(
point_bev_features
.
unsqueeze
(
dim
=
0
)
)
point_bev_features_list
.
append
(
point_bev_features
)
point_bev_features
=
torch
.
cat
(
point_bev_features_list
,
dim
=
0
)
# (
B, N
, C
0
)
point_bev_features
=
torch
.
cat
(
point_bev_features_list
,
dim
=
0
)
# (
N1 + N2 + ...
, C)
return
point_bev_features
return
point_bev_features
def
sectorized_proposal_centric_sampling
(
self
,
roi_boxes
,
points
):
"""
Args:
roi_boxes: (M, 7 + C)
points: (N, 3)
Returns:
sampled_points: (N_out, 3)
"""
sampled_points
,
_
=
sample_points_with_roi
(
rois
=
roi_boxes
,
points
=
points
,
sample_radius_with_roi
=
self
.
model_cfg
.
SPC
.
SAMPLE_RADIUS_WITH_ROI
,
num_max_points_of_part
=
self
.
model_cfg
.
SPC
.
get
(
'NUM_POINTS_OF_EACH_SAMPLE_PART'
,
200000
)
)
sampled_points
=
sector_fps
(
points
=
sampled_points
,
num_sampled_points
=
self
.
model_cfg
.
NUM_KEYPOINTS
,
num_sectors
=
self
.
model_cfg
.
SPC
.
NUM_SECTORS
)
return
sampled_points
def
get_sampled_points
(
self
,
batch_dict
):
def
get_sampled_points
(
self
,
batch_dict
):
"""
Args:
batch_dict:
Returns:
keypoints: (N1 + N2 + ..., 4), where 4 indicates [bs_idx, x, y, z]
"""
batch_size
=
batch_dict
[
'batch_size'
]
batch_size
=
batch_dict
[
'batch_size'
]
if
self
.
model_cfg
.
POINT_SOURCE
==
'raw_points'
:
if
self
.
model_cfg
.
POINT_SOURCE
==
'raw_points'
:
src_points
=
batch_dict
[
'points'
][:,
1
:
4
]
src_points
=
batch_dict
[
'points'
][:,
1
:
4
]
...
@@ -147,16 +269,75 @@ class VoxelSetAbstraction(nn.Module):
...
@@ -147,16 +269,75 @@ class VoxelSetAbstraction(nn.Module):
keypoints
=
sampled_points
[
0
][
cur_pt_idxs
[
0
]].
unsqueeze
(
dim
=
0
)
keypoints
=
sampled_points
[
0
][
cur_pt_idxs
[
0
]].
unsqueeze
(
dim
=
0
)
elif
self
.
model_cfg
.
SAMPLE_METHOD
==
'FastFPS'
:
elif
self
.
model_cfg
.
SAMPLE_METHOD
==
'SPC'
:
raise
NotImplementedError
cur_keypoints
=
self
.
sectorized_proposal_centric_sampling
(
roi_boxes
=
batch_dict
[
'rois'
][
bs_idx
],
points
=
sampled_points
)
bs_idxs
=
cur_keypoints
.
new_ones
(
cur_keypoints
.
shape
[
0
])
*
bs_idx
keypoints
=
torch
.
cat
((
bs_idxs
[:,
None
],
cur_keypoints
),
dim
=
1
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
keypoints_list
.
append
(
keypoints
)
keypoints_list
.
append
(
keypoints
)
keypoints
=
torch
.
cat
(
keypoints_list
,
dim
=
0
)
# (B, M, 3)
keypoints
=
torch
.
cat
(
keypoints_list
,
dim
=
0
)
# (B, M, 3) or (N1 + N2 + ..., 4)
if
len
(
keypoints
.
shape
)
==
3
:
batch_idx
=
torch
.
arange
(
batch_size
,
device
=
keypoints
.
device
).
view
(
-
1
,
1
).
repeat
(
1
,
keypoints
.
shape
[
1
]).
view
(
-
1
,
1
)
keypoints
=
torch
.
cat
((
batch_idx
.
float
(),
keypoints
.
view
(
-
1
,
3
)),
dim
=
1
)
return
keypoints
return
keypoints
@
staticmethod
def
aggregate_keypoint_features_from_one_source
(
batch_size
,
aggregate_func
,
xyz
,
xyz_features
,
xyz_bs_idxs
,
new_xyz
,
new_xyz_batch_cnt
,
filter_neighbors_with_roi
=
False
,
radius_of_neighbor
=
None
,
num_max_points_of_part
=
None
,
rois
=
None
):
"""
Args:
aggregate_func:
xyz: (N, 3)
xyz_features: (N, C)
xyz_bs_idxs: (N)
new_xyz: (M, 3)
new_xyz_batch_cnt: (batch_size), [N1, N2, ...]
filter_neighbors_with_roi: True/False
radius_of_neighbor: float
num_max_points_of_part: int
rois: (batch_size, num_rois, 7 + C)
Returns:
"""
xyz_batch_cnt
=
xyz
.
new_zeros
(
batch_size
).
int
()
if
filter_neighbors_with_roi
:
point_features
=
torch
.
cat
((
xyz
,
xyz_features
),
dim
=
0
)
if
xyz_features
is
not
None
else
xyz
point_features_list
=
[]
for
bs_idx
in
range
(
batch_size
):
bs_mask
=
(
xyz_bs_idxs
==
bs_idx
)
_
,
valid_mask
=
sample_points_with_roi
(
rois
=
rois
[
bs_idx
],
points
=
xyz
[
bs_mask
],
sample_radius_with_roi
=
radius_of_neighbor
,
num_max_points_of_part
=
num_max_points_of_part
,
)
point_features_list
.
append
(
point_features
[
bs_mask
][
valid_mask
])
xyz_batch_cnt
[
bs_idx
]
=
valid_mask
.
sum
()
valid_point_features
=
torch
.
cat
(
point_features_list
,
dim
=
0
)
xyz
=
valid_point_features
[:,
0
:
3
]
xyz_features
=
valid_point_features
[:,
3
:]
if
xyz_features
is
not
None
else
None
else
:
for
bs_idx
in
range
(
batch_size
):
xyz_batch_cnt
[
bs_idx
]
=
(
xyz_bs_idxs
==
bs_idx
).
sum
()
pooled_points
,
pooled_features
=
aggregate_func
(
xyz
=
xyz
.
contiguous
(),
xyz_batch_cnt
=
xyz_batch_cnt
,
new_xyz
=
new_xyz
,
new_xyz_batch_cnt
=
new_xyz_batch_cnt
,
features
=
xyz_features
,
)
return
pooled_features
def
forward
(
self
,
batch_dict
):
def
forward
(
self
,
batch_dict
):
"""
"""
Args:
Args:
...
@@ -185,56 +366,53 @@ class VoxelSetAbstraction(nn.Module):
...
@@ -185,56 +366,53 @@ class VoxelSetAbstraction(nn.Module):
)
)
point_features_list
.
append
(
point_bev_features
)
point_features_list
.
append
(
point_bev_features
)
batch_size
,
num_keypoints
,
_
=
keypoints
.
shape
batch_size
=
batch_dict
[
'batch_size'
]
new_xyz
=
keypoints
.
view
(
-
1
,
3
)
new_xyz_batch_cnt
=
new_xyz
.
new_zeros
(
batch_size
).
int
().
fill_
(
num_keypoints
)
new_xyz
=
keypoints
[:,
1
:
4
].
contiguous
()
new_xyz_batch_cnt
=
new_xyz
.
new_zeros
(
batch_size
).
int
()
for
k
in
range
(
batch_size
):
new_xyz_batch_cnt
[
k
]
=
(
keypoints
[:,
0
]
==
k
).
sum
()
if
'raw_points'
in
self
.
model_cfg
.
FEATURES_SOURCE
:
if
'raw_points'
in
self
.
model_cfg
.
FEATURES_SOURCE
:
raw_points
=
batch_dict
[
'points'
]
raw_points
=
batch_dict
[
'points'
]
xyz
=
raw_points
[:,
1
:
4
]
xyz_batch_cnt
=
xyz
.
new_zeros
(
batch_size
).
int
()
pooled_features
=
self
.
aggregate_keypoint_features_from_one_source
(
for
bs_idx
in
range
(
batch_size
):
batch_size
=
batch_size
,
aggregate_func
=
self
.
SA_rawpoints
,
xyz_batch_cnt
[
bs_idx
]
=
(
raw_points
[:,
0
]
==
bs_idx
).
sum
()
xyz
=
raw_points
[:,
1
:
4
],
point_features
=
raw_points
[:,
4
:].
contiguous
()
if
raw_points
.
shape
[
1
]
>
4
else
None
xyz_features
=
raw_points
[:,
4
:].
contiguous
()
if
raw_points
.
shape
[
1
]
>
4
else
None
,
xyz_bs_idxs
=
raw_points
[:,
0
],
pooled_points
,
pooled_features
=
self
.
SA_rawpoints
(
new_xyz
=
new_xyz
,
new_xyz_batch_cnt
=
new_xyz_batch_cnt
,
xyz
=
xyz
.
contiguous
(),
filter_neighbors_with_roi
=
self
.
model_cfg
.
SA_LAYER
[
'raw_points'
].
get
(
'FILTER_NEIGHBOR_WITH_ROI'
,
False
),
xyz_batch_cnt
=
xyz_batch_cnt
,
radius_of_neighbor
=
self
.
model_cfg
.
SA_LAYER
[
'raw_points'
].
get
(
'RADIUS_OF_NEIGHBOR_WITH_ROI'
,
None
),
new_xyz
=
new_xyz
,
rois
=
batch_dict
.
get
(
'rois'
,
None
)
new_xyz_batch_cnt
=
new_xyz_batch_cnt
,
features
=
point_features
,
)
)
point_features_list
.
append
(
pooled_features
.
view
(
batch_size
,
num_keypoints
,
-
1
)
)
point_features_list
.
append
(
pooled_features
)
for
k
,
src_name
in
enumerate
(
self
.
SA_layer_names
):
for
k
,
src_name
in
enumerate
(
self
.
SA_layer_names
):
cur_coords
=
batch_dict
[
'multi_scale_3d_features'
][
src_name
].
indices
cur_coords
=
batch_dict
[
'multi_scale_3d_features'
][
src_name
].
indices
cur_features
=
batch_dict
[
'multi_scale_3d_features'
][
src_name
].
features
.
contiguous
()
xyz
=
common_utils
.
get_voxel_centers
(
xyz
=
common_utils
.
get_voxel_centers
(
cur_coords
[:,
1
:
4
],
cur_coords
[:,
1
:
4
],
downsample_times
=
self
.
downsample_times_map
[
src_name
],
downsample_times
=
self
.
downsample_times_map
[
src_name
],
voxel_size
=
self
.
voxel_size
,
point_cloud_range
=
self
.
point_cloud_range
voxel_size
=
self
.
voxel_size
,
point_cloud_range
=
self
.
point_cloud_range
)
)
xyz_batch_cnt
=
xyz
.
new_zeros
(
batch_size
).
int
()
for
bs_idx
in
range
(
batch_size
):
pooled_features
=
self
.
aggregate_keypoint_features_from_one_source
(
xyz_batch_cnt
[
bs_idx
]
=
(
cur_coords
[:,
0
]
==
bs_idx
).
sum
()
batch_size
=
batch_size
,
aggregate_func
=
self
.
SA_layers
[
k
],
xyz
=
xyz
.
contiguous
(),
xyz_features
=
cur_features
,
xyz_bs_idxs
=
cur_coords
[:,
0
],
pooled_points
,
pooled_features
=
self
.
SA_layers
[
k
](
new_xyz
=
new_xyz
,
new_xyz_batch_cnt
=
new_xyz_batch_cnt
,
xyz
=
xyz
.
contiguous
(),
filter_neighbors_with_roi
=
self
.
model_cfg
.
SA_LAYER
[
src_name
].
get
(
'FILTER_NEIGHBOR_WITH_ROI'
,
False
),
xyz_batch_cnt
=
xyz_batch_cnt
,
radius_of_neighbor
=
self
.
model_cfg
.
SA_LAYER
[
src_name
].
get
(
'RADIUS_OF_NEIGHBOR_WITH_ROI'
,
None
),
new_xyz
=
new_xyz
,
rois
=
batch_dict
.
get
(
'rois'
,
None
)
new_xyz_batch_cnt
=
new_xyz_batch_cnt
,
features
=
batch_dict
[
'multi_scale_3d_features'
][
src_name
].
features
.
contiguous
(),
)
)
point_features_list
.
append
(
pooled_features
.
view
(
batch_size
,
num_keypoints
,
-
1
))
point_features
=
torch
.
cat
(
point_features_list
,
dim
=
2
)
point_features
_list
.
append
(
pooled_features
)
batch_idx
=
torch
.
arange
(
batch_size
,
device
=
keypoints
.
device
).
view
(
-
1
,
1
).
repeat
(
1
,
keypoints
.
shape
[
1
]).
view
(
-
1
)
point_features
=
torch
.
cat
(
point_features_list
,
dim
=-
1
)
point_coords
=
torch
.
cat
((
batch_idx
.
view
(
-
1
,
1
).
float
(),
keypoints
.
view
(
-
1
,
3
)),
dim
=
1
)
batch_dict
[
'point_features_before_fusion'
]
=
point_features
.
view
(
-
1
,
point_features
.
shape
[
-
1
])
batch_dict
[
'point_features_before_fusion'
]
=
point_features
.
view
(
-
1
,
point_features
.
shape
[
-
1
])
point_features
=
self
.
vsa_point_feature_fusion
(
point_features
.
view
(
-
1
,
point_features
.
shape
[
-
1
]))
point_features
=
self
.
vsa_point_feature_fusion
(
point_features
.
view
(
-
1
,
point_features
.
shape
[
-
1
]))
batch_dict
[
'point_features'
]
=
point_features
# (BxN, C)
batch_dict
[
'point_features'
]
=
point_features
# (BxN, C)
batch_dict
[
'point_coords'
]
=
point
_coord
s
# (BxN, 4)
batch_dict
[
'point_coords'
]
=
key
points
# (BxN, 4)
return
batch_dict
return
batch_dict
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment