OpenDAS / mmdetection3d — Commits

Commit f27d308f, authored Jun 07, 2020 by yinchimaoliang

    merge master

Parents: c66ae813, 27ebcfac
Changes: 80
Showing 20 changed files with 1156 additions and 379 deletions (+1156 -379):

    mmdet3d/datasets/pipelines/loading.py                                  +34    -4
    mmdet3d/datasets/pipelines/point_seg_class_mapping.py                  +36    -0
    mmdet3d/datasets/scannet_dataset.py                                     +2    -1
    mmdet3d/datasets/sunrgbd_dataset.py                                     +2    -1
    mmdet3d/models/__init__.py                                              +1    -0
    mmdet3d/models/dense_heads/__init__.py                                  +2    -1
    mmdet3d/models/dense_heads/vote_head.py                               +518    -0
    mmdet3d/models/detectors/__init__.py                                    +2    -1
    mmdet3d/models/detectors/votenet.py                                   +110    -0
    mmdet3d/models/losses/__init__.py                                       +5    -1
    mmdet3d/models/losses/chamfer_distance.py                             +119    -0
    mmdet3d/models/model_utils/__init__.py                                  +3    -0
    mmdet3d/models/model_utils/vote_module.py                              +11   -50
    mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py                +14   -74
    mmdet3d/ops/__init__.py                                                 +3    -4
    mmdet3d/ops/ball_query/src/ball_query.cpp                              +27   -20
    mmdet3d/ops/ball_query/src/ball_query_cuda.cu                          +55   -50
    mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp        +20   -15
    mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cuda.cu   +155  -130
    mmdet3d/ops/gather_points/src/gather_points.cpp                        +37   -27

mmdet3d/datasets/pipelines/loading.py  (+34 -4)

@@ -42,9 +42,40 @@ class LoadMultiViewImageFromFiles(object):

 @PIPELINES.register_module()
 class LoadPointsFromMultiSweeps(object):
     """Load points from multiple sweeps

-    def __init__(self, sweeps_num=10):
+    This is usually used for nuScenes dataset to utilize previous sweeps.
+
+    Args:
+        sweeps_num (int): number of sweeps
+        load_dim (int): dimension number of the loaded points
+        file_client_args (dict): Config dict of file clients, refer to
+            https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
+            for more details.
+    """
+
+    def __init__(self,
+                 sweeps_num=10,
+                 load_dim=5,
+                 file_client_args=dict(backend='disk')):
+        self.load_dim = load_dim
         self.sweeps_num = sweeps_num
+        self.file_client_args = file_client_args.copy()
+        self.file_client = None
+
+    def _load_points(self, pts_filename):
+        if self.file_client is None:
+            self.file_client = mmcv.FileClient(**self.file_client_args)
+        try:
+            pts_bytes = self.file_client.get(pts_filename)
+            points = np.frombuffer(pts_bytes, dtype=np.float32)
+        except ConnectionError:
+            mmcv.check_file_exist(pts_filename)
+            if pts_filename.endswith('.npy'):
+                points = np.load(pts_filename)
+            else:
+                points = np.fromfile(pts_filename, dtype=np.float32)
+        return points

     def __call__(self, results):
         points = results['points']
@@ -56,9 +87,8 @@ class LoadPointsFromMultiSweeps(object):
         for idx, sweep in enumerate(results['sweeps']):
             if idx >= self.sweeps_num:
                 break
-            points_sweep = np.fromfile(
-                sweep['data_path'], dtype=np.float32,
-                count=-1).reshape([-1, 5])
+            points_sweep = self._load_points(sweep['data_path'])
+            points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)
             sweep_ts = sweep['timestamp'] / 1e6
             points_sweep[:, 3] /= 255
             points_sweep[:, :3] = points_sweep[:, :3] @ sweep[
...
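The expanded loading step above is driven entirely from the pipeline config. A minimal sketch (concrete values are assumptions for illustration, not part of this commit) of how the new arguments could be set for a nuScenes-style pipeline:

```python
# Hypothetical nuScenes pipeline fragment; argument names follow the docstring
# above, the values are illustrative assumptions.
train_pipeline = [
    dict(type='LoadPointsFromFile', load_dim=5, use_dim=5),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=10,                           # merge up to 10 previous sweeps
        load_dim=5,                              # x, y, z, intensity, timestamp
        file_client_args=dict(backend='disk')),  # plain filesystem reads
]
```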

mmdet3d/datasets/pipelines/point_seg_class_mapping.py  (new file, +36 -0)

from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module()
class PointSegClassMapping(object):
    """Map original semantic class to valid category ids.

    Map valid classes as 0~len(valid_cat_ids)-1 and
    others as len(valid_cat_ids).

    Args:
        valid_cat_ids (tuple[int]): A tuple of valid categories.
    """

    def __init__(self, valid_cat_ids):
        self.valid_cat_ids = valid_cat_ids

    def __call__(self, results):
        assert 'pts_semantic_mask' in results
        pts_semantic_mask = results['pts_semantic_mask']
        neg_cls = len(self.valid_cat_ids)

        for i in range(pts_semantic_mask.shape[0]):
            if pts_semantic_mask[i] in self.valid_cat_ids:
                converted_id = self.valid_cat_ids.index(pts_semantic_mask[i])
                pts_semantic_mask[i] = converted_id
            else:
                pts_semantic_mask[i] = neg_cls

        results['pts_semantic_mask'] = pts_semantic_mask
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += '(valid_cat_ids={})'.format(self.valid_cat_ids)
        return repr_str
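As an illustration (not part of the commit), this is the mapping behaviour the loop above produces on a small hand-made mask, with the class above in scope:

```python
import numpy as np

# Labels listed in valid_cat_ids map to 0..len(valid_cat_ids)-1 in order;
# every other label collapses to the "ignore" id len(valid_cat_ids).
mapping = PointSegClassMapping(valid_cat_ids=(3, 4, 5))
results = dict(pts_semantic_mask=np.array([3, 7, 5, 4, 0]))
results = mapping(results)
print(results['pts_semantic_mask'])  # -> [0 3 2 1 3]
```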

mmdet3d/datasets/scannet_dataset.py  (+2 -1)

@@ -20,9 +20,10 @@ class ScanNetDataset(Custom3DDataset):
                  pipeline=None,
                  classes=None,
                  modality=None,
+                 filter_empty_gt=True,
                  test_mode=False):
         super().__init__(data_root, ann_file, pipeline, classes, modality,
-                         test_mode)
+                         filter_empty_gt, test_mode)

     def get_ann_info(self, index):
         # Use index to get the annos, thus the evalhook could also use this api
...

mmdet3d/datasets/sunrgbd_dataset.py  (+2 -1)

@@ -16,9 +16,10 @@ class SUNRGBDDataset(Custom3DDataset):
                  pipeline=None,
                  classes=None,
                  modality=None,
+                 filter_empty_gt=True,
                  test_mode=False):
         super().__init__(data_root, ann_file, pipeline, classes, modality,
-                         test_mode)
+                         filter_empty_gt, test_mode)

     def get_ann_info(self, index):
         # Use index to get the annos, thus the evalhook could also use this api
...
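Both dataset classes now accept a filter_empty_gt flag and forward it to Custom3DDataset. A hypothetical config fragment (paths and file names are assumed, not taken from this commit) showing where the flag would surface:

```python
# Assumed paths/filenames; only the filter_empty_gt key is the point here.
data = dict(
    train=dict(
        type='ScanNetDataset',
        data_root='data/scannet/',
        ann_file='data/scannet/scannet_infos_train.pkl',
        pipeline=train_pipeline,      # defined elsewhere in the config
        filter_empty_gt=True,         # drop training samples without GT boxes
        test_mode=False))
```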

mmdet3d/models/__init__.py  (+1 -0)

@@ -8,6 +8,7 @@ from .detectors import *  # noqa: F401,F403
 from .fusion_layers import *  # noqa: F401,F403
 from .losses import *  # noqa: F401,F403
 from .middle_encoders import *  # noqa: F401,F403
+from .model_utils import *  # noqa: F401,F403
 from .necks import *  # noqa: F401,F403
 from .registry import FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS
 from .roi_heads import *  # noqa: F401,F403
...

mmdet3d/models/dense_heads/__init__.py  (+2 -1)

 from .anchor3d_head import Anchor3DHead
 from .parta2_rpn_head import PartA2RPNHead
+from .vote_head import VoteHead

-__all__ = ['Anchor3DHead', 'PartA2RPNHead']
+__all__ = ['Anchor3DHead', 'PartA2RPNHead', 'VoteHead']

mmdet3d/models/dense_heads/vote_head.py  (new file, +518 -0)

This diff is collapsed.

mmdet3d/models/detectors/__init__.py  (+2 -1)

@@ -4,10 +4,11 @@ from .mvx_faster_rcnn import (DynamicMVXFasterRCNN, DynamicMVXFasterRCNNV2,
 from .mvx_single_stage import MVXSingleStageDetector
 from .mvx_two_stage import MVXTwoStageDetector
 from .parta2 import PartA2
+from .votenet import VoteNet
 from .voxelnet import DynamicVoxelNet, VoxelNet

 __all__ = [
     'BaseDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXSingleStageDetector',
     'MVXTwoStageDetector', 'DynamicMVXFasterRCNN', 'DynamicMVXFasterRCNNV2',
-    'DynamicMVXFasterRCNNV3', 'PartA2'
+    'DynamicMVXFasterRCNNV3', 'PartA2', 'VoteNet'
 ]

mmdet3d/models/detectors/votenet.py  (new file, +110 -0)

import torch

from mmdet3d.core import bbox3d2result
from mmdet.models import DETECTORS, SingleStageDetector


@DETECTORS.register_module()
class VoteNet(SingleStageDetector):
    """VoteNet model.

    https://arxiv.org/pdf/1904.09664.pdf
    """

    def __init__(self,
                 backbone,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(VoteNet, self).__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

    def extract_feat(self, points):
        x = self.backbone(points)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_train(self,
                      points,
                      img_meta,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask=None,
                      pts_instance_mask=None,
                      gt_bboxes_ignore=None):
        """Forward of training.

        Args:
            points (list[Tensor]): Points of each batch.
            img_meta (list): Image metas.
            gt_bboxes_3d (list[Tensor]): gt bboxes of each batch.
            gt_labels_3d (list[Tensor]): gt class labels of each batch.
            pts_semantic_mask (None | list[Tensor]): point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[Tensor]): point-wise instance
                label of each batch.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding.

        Returns:
            dict: Losses.
        """
        points_cat = torch.stack(points)  # tmp

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
        loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
                       pts_instance_mask, img_meta)
        losses = self.bbox_head.loss(
            bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def forward_test(self, **kwargs):
        return self.simple_test(**kwargs)

    def forward(self, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def simple_test(self,
                    points,
                    img_meta,
                    gt_bboxes_3d=None,
                    gt_labels_3d=None,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    rescale=False):
        """Forward of testing.

        Args:
            points (list[Tensor]): Points of each sample.
            img_meta (list): Image metas.
            gt_bboxes_3d (list[Tensor]): gt bboxes of each sample.
            gt_labels_3d (list[Tensor]): gt class labels of each sample.
            pts_semantic_mask (None | list[Tensor]): point-wise semantic
                label of each sample.
            pts_instance_mask (None | list[Tensor]): point-wise instance
                label of each sample.
            rescale (bool): Whether to rescale results.

        Returns:
            list: Predicted 3d boxes.
        """
        points_cat = torch.stack(points)  # tmp

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
        bbox_list = self.bbox_head.get_bboxes(
            points_cat, bbox_preds, img_meta, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results[0]
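For orientation, a hypothetical inference call against this detector (the built model object and the input shapes are assumptions, not part of the commit): points are passed as a list of per-sample tensors, stacked inside simple_test, and decoded by the bbox head into boxes.

```python
import torch

# `model` is assumed to be an already-built and loaded VoteNet instance
# (e.g. built from a config and a checkpoint); shapes are illustrative.
points = [torch.rand(20000, 4)]   # one sample: N points with (x, y, z, feature)
img_meta = [dict(sample_idx=0)]   # minimal per-sample meta information

model.eval()
with torch.no_grad():
    # return_loss=False routes forward() -> forward_test() -> simple_test()
    result = model(return_loss=False, points=points, img_meta=img_meta,
                   rescale=False)
```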

mmdet3d/models/losses/__init__.py  (+5 -1)

 from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy
+from .chamfer_distance import ChamferDistance, chamfer_distance

-__all__ = ['FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy']
+__all__ = [
+    'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance',
+    'chamfer_distance'
+]

mmdet3d/models/losses/chamfer_distance.py  (new file, +119 -0)

import torch
import torch.nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss

from mmdet.models.builder import LOSSES


def chamfer_distance(src,
                     dst,
                     src_weight=1.0,
                     dst_weight=1.0,
                     criterion_mode='l2',
                     reduction='mean'):
    """Calculate Chamfer Distance of two sets.

    Args:
        src (tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (tensor or float): Weight of source loss.
        dst_weight (tensor or float): Weight of destination loss.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction method are none, sum or mean.

    Returns:
        tuple: Source and Destination loss with indices.

            - loss_src (Tensor): The min distance from source to destination.
            - loss_dst (Tensor): The min distance from destination to source.
            - indices1 (Tensor): Index the min distance point for each point
                in source to destination.
            - indices2 (Tensor): Index the min distance point for each point
                in destination to source.
    """
    if criterion_mode == 'smooth_l1':
        criterion = smooth_l1_loss
    elif criterion_mode == 'l1':
        criterion = l1_loss
    elif criterion_mode == 'l2':
        criterion = mse_loss
    else:
        raise NotImplementedError

    src_expand = src.unsqueeze(2).repeat(1, 1, dst.shape[1], 1)
    dst_expand = dst.unsqueeze(1).repeat(1, src.shape[1], 1, 1)

    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    src2dst_distance, indices1 = torch.min(distance, dim=2)  # (B,N)
    dst2src_distance, indices2 = torch.min(distance, dim=1)  # (B,M)

    loss_src = (src2dst_distance * src_weight)
    loss_dst = (dst2src_distance * dst_weight)

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    elif reduction == 'none':
        pass
    else:
        raise NotImplementedError

    return loss_src, loss_dst, indices1, indices2


@LOSSES.register_module()
class ChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are smooth_l1, l1 or l2.
        reduction (str): Method to reduce losses.
            The valid reduction method are none, sum or mean.
        loss_src_weight (float): Weight of loss_source.
        loss_dst_weight (float): Weight of loss_target.
    """

    def __init__(self,
                 mode='l2',
                 reduction='mean',
                 loss_src_weight=1.0,
                 loss_dst_weight=1.0):
        super(ChamferDistance, self).__init__()

        assert mode in ['smooth_l1', 'l1', 'l2']
        assert reduction in ['none', 'sum', 'mean']
        self.mode = mode
        self.reduction = reduction
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(self,
                source,
                target,
                src_weight=1.0,
                dst_weight=1.0,
                reduction_override=None,
                return_indices=False,
                **kwargs):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode, reduction)

        loss_source *= self.loss_src_weight
        loss_target *= self.loss_dst_weight

        if return_indices:
            return loss_source, loss_target, indices1, indices2
        else:
            return loss_source, loss_target
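A small worked example (illustration only, not from the commit) of the functional form above with the default 'l2' criterion, which sums squared per-coordinate differences before taking the two sets of minima:

```python
import torch
from mmdet3d.models.losses import chamfer_distance  # exported by the updated __init__ above

src = torch.tensor([[[0., 0., 0.], [1., 0., 0.]]])  # (B=1, N=2, C=3)
dst = torch.tensor([[[0., 0., 1.], [3., 0., 0.]]])  # (B=1, M=2, C=3)

loss_src, loss_dst, idx1, idx2 = chamfer_distance(
    src, dst, criterion_mode='l2', reduction='none')

print(loss_src)  # tensor([[1., 2.]])  squared distance of each src point to its nearest dst point
print(loss_dst)  # tensor([[1., 4.]])  squared distance of each dst point to its nearest src point
print(idx1)      # tensor([[0, 0]])    nearest-dst index for each src point
print(idx2)      # tensor([[0, 1]])    nearest-src index for each dst point
```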

mmdet3d/models/model_utils/__init__.py  (new file, +3 -0)

from .vote_module import VoteModule

__all__ = ['VoteModule']

mmdet3d/ops/vote_module.py → mmdet3d/models/model_utils/vote_module.py  (renamed, +11 -50)

 import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule
-from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
+
+from mmdet3d.models.builder import build_loss


 class VoteModule(nn.Module):
...
@@ -22,7 +23,7 @@ class VoteModule(nn.Module):
             Default: dict(type='BN1d').
         norm_feats (bool): Whether to normalize features.
             Default: True.
-        loss_weight (float): Weight of voting loss.
+        vote_loss (dict): config of vote loss.
     """

     def __init__(self,
...
@@ -33,13 +34,13 @@ class VoteModule(nn.Module):
                  conv_cfg=dict(type='Conv1d'),
                  norm_cfg=dict(type='BN1d'),
                  norm_feats=True,
-                 loss_weight=1.0):
+                 vote_loss=None):
         super().__init__()
         self.in_channels = in_channels
         self.vote_per_seed = vote_per_seed
         self.gt_per_seed = gt_per_seed
         self.norm_feats = norm_feats
-        self.loss_weight = loss_weight
+        self.vote_loss = build_loss(vote_loss)

         prev_channels = in_channels
         vote_conv_list = list()
...
@@ -118,57 +119,17 @@ class VoteModule(nn.Module):
         seed_gt_votes_mask = torch.gather(vote_targets_mask, 1,
                                           seed_indices).float()
         pos_num = torch.sum(seed_gt_votes_mask)
         seed_indices_expand = seed_indices.unsqueeze(-1).repeat(
             1, 1, 3 * self.gt_per_seed)
         seed_gt_votes = torch.gather(vote_targets, 1, seed_indices_expand)
         seed_gt_votes += seed_points.repeat(1, 1, 3)

-        distance = self.nn_distance(
-            vote_points.view(batch_size * num_seed, -1, 3),
-            seed_gt_votes.view(batch_size * num_seed, -1, 3),
-            mode='l1')[2]
-        votes_distance = torch.min(distance, dim=1)[0]
-        votes_dist = votes_distance.view(batch_size, num_seed)
-        vote_loss = torch.sum(votes_dist * seed_gt_votes_mask) / (pos_num + 1e-6)
-        return self.loss_weight * vote_loss
+        weight = seed_gt_votes_mask / (torch.sum(seed_gt_votes_mask) + 1e-6)
+        distance = self.vote_loss(
+            vote_points.view(batch_size * num_seed, -1, 3),
+            seed_gt_votes.view(batch_size * num_seed, -1, 3),
+            dst_weight=weight.view(batch_size * num_seed, 1))[1]
+        vote_loss = torch.sum(torch.min(distance, dim=1)[0])

-    def nn_distance(self, points1, points2, mode='smooth_l1'):
-        """Find the nearest neighbor from point1 to point2
-
-        Args:
-            points1 (Tensor): points to find the Nearest neighbor.
-            points2 (Tensor): points to find the Nearest neighbor.
-            mode (str): Specify the function (smooth_l1, l1 or l2)
-                to calculate distance.
-
-        Returns:
-            tuple[Tensor]:
-
-                - distance1: the nearest distance from points1 to points2.
-                - index1: the index of the nearest neighbor for points1.
-                - distance2: the nearest distance from points2 to points1.
-                - index2: the index of the nearest neighbor for points2.
-        """
-        assert mode in ['smooth_l1', 'l1', 'l2']
-        N = points1.shape[1]
-        M = points2.shape[1]
-        pc1_expand_tile = points1.unsqueeze(2).repeat(1, 1, M, 1)
-        pc2_expand_tile = points2.unsqueeze(1).repeat(1, N, 1, 1)
-
-        if mode == 'smooth_l1':
-            pc_dist = torch.sum(
-                smooth_l1_loss(pc1_expand_tile, pc2_expand_tile), dim=-1)
-        elif mode == 'l1':
-            pc_dist = torch.sum(
-                l1_loss(pc1_expand_tile, pc2_expand_tile), dim=-1)  # (B,N,M)
-        elif mode == 'l2':
-            pc_dist = torch.sum(
-                mse_loss(pc1_expand_tile, pc2_expand_tile), dim=-1)  # (B,N,M)
-        else:
-            raise NotImplementedError
-
-        distance1, index1 = torch.min(pc_dist, dim=2)  # (B,N)
-        distance2, index2 = torch.min(pc_dist, dim=1)  # (B,M)
-        return distance1, index1, distance2, index2
+        return vote_loss
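With the loss now built from a config dict via build_loss, the voting loss is chosen at config time rather than hard-coded in nn_distance. A sketch (values assumed for illustration; the ChamferDistance loss registered earlier in this commit is one natural choice) of how vote_loss might be configured:

```python
# Hypothetical VoteModule construction; only the vote_loss dict is the point.
vote_module_cfg = dict(
    in_channels=256,
    vote_per_seed=1,
    gt_per_seed=3,
    vote_loss=dict(
        type='ChamferDistance',   # registered in LOSSES by this commit
        mode='l1',
        reduction='none',         # keep per-point distances; reduced inside get_loss
        loss_dst_weight=10.0))
```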

mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py  (+14 -74)

 import numpy as np
 import torch
 import torch.nn as nn
-from mmcv.cnn import ConvModule, build_norm_layer, normal_init, xavier_init
+from mmcv.cnn import ConvModule, normal_init, xavier_init

 import mmdet3d.ops.spconv as spconv
 from mmdet3d.core import build_bbox_coder, multi_apply
 from mmdet3d.core.bbox import box_torch_ops
 from mmdet3d.models.builder import build_loss
 from mmdet3d.ops import make_sparse_convmodule
 from mmdet3d.ops.iou3d.iou3d_utils import (boxes3d_to_bev_torch_lidar, nms_gpu,
                                            nms_normal_gpu)
 from mmdet.models import HEADS
...
@@ -78,19 +79,18 @@ class PartA2BboxHead(nn.Module):
         assert down_conv_channels[-1] == shared_fc_channels[0]

         # init layers
-        block = self.post_act_block
         part_channel_last = part_in_channels
         part_conv = []
         for i, channel in enumerate(part_conv_channels):
             part_conv.append(
-                block(part_channel_last, channel, 3, padding=1,
-                      norm_cfg=norm_cfg, indice_key=f'rcnn_part{i}'))
+                make_sparse_convmodule(
+                    part_channel_last, channel, 3, padding=1, norm_cfg=norm_cfg,
+                    indice_key=f'rcnn_part{i}', conv_type='SubMConv3d'))
             part_channel_last = channel
         self.part_conv = spconv.SparseSequential(*part_conv)
...
@@ -98,13 +98,14 @@ class PartA2BboxHead(nn.Module):
         seg_conv = []
         for i, channel in enumerate(seg_conv_channels):
             seg_conv.append(
-                block(seg_channel_last, channel, 3, padding=1,
-                      norm_cfg=norm_cfg, indice_key=f'rcnn_seg{i}'))
+                make_sparse_convmodule(
+                    seg_channel_last, channel, 3, padding=1, norm_cfg=norm_cfg,
+                    indice_key=f'rcnn_seg{i}', conv_type='SubMConv3d'))
             seg_channel_last = channel
         self.seg_conv = spconv.SparseSequential(*seg_conv)
...
@@ -114,26 +115,28 @@ class PartA2BboxHead(nn.Module):
         merge_conv = []
         for i, channel in enumerate(merge_conv_channels):
             merge_conv.append(
-                block(merge_conv_channel_last, channel, 3, padding=1,
-                      norm_cfg=norm_cfg, indice_key=f'rcnn_down0'))
+                make_sparse_convmodule(
+                    merge_conv_channel_last, channel, 3, padding=1,
+                    norm_cfg=norm_cfg, indice_key=f'rcnn_down0',
+                    conv_type='SubMConv3d'))
             merge_conv_channel_last = channel

         down_conv_channel_last = merge_conv_channel_last
         conv_down = []
         for i, channel in enumerate(down_conv_channels):
             conv_down.append(
-                block(down_conv_channel_last, channel, 3, padding=1,
-                      norm_cfg=norm_cfg, indice_key=f'rcnn_down1'))
+                make_sparse_convmodule(
+                    down_conv_channel_last, channel, 3, padding=1,
+                    norm_cfg=norm_cfg, indice_key=f'rcnn_down1',
+                    conv_type='SubMConv3d'))
             down_conv_channel_last = channel

         self.conv_down.add_module('merge_conv',
...
@@ -228,69 +231,6 @@ class PartA2BboxHead(nn.Module):
         normal_init(self.conv_reg[-1].conv, mean=0, std=0.001)

-    def post_act_block(self, in_channels, out_channels, kernel_size,
-                       indice_key, stride=1, padding=0,
-                       conv_type='subm', norm_cfg=None):
-        """Make post activate sparse convolution block.
-
-        Args:
-            in_channels (int): the number of input channels
-            out_channels (int): the number of out channels
-            kernel_size (int): kernel size of convolution
-            indice_key (str): the indice key used for sparse tensor
-            stride (int): the stride of convolution
-            padding (int or list[int]): the padding number of input
-            conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
-            norm_cfg (dict[str]): config of normalization layer
-
-        Returns:
-            spconv.SparseSequential: post activate sparse convolution block.
-        """
-        # TODO: clean post_act_block by existing bottlnecks.
-        assert conv_type in ['subm', 'spconv', 'inverseconv']
-        if conv_type == 'subm':
-            m = spconv.SparseSequential(
-                spconv.SubMConv3d(in_channels, out_channels, kernel_size,
-                                  bias=False, indice_key=indice_key),
-                build_norm_layer(norm_cfg, out_channels)[1],
-                nn.ReLU(inplace=True))
-        elif conv_type == 'spconv':
-            m = spconv.SparseSequential(
-                spconv.SparseConv3d(in_channels, out_channels, kernel_size,
-                                    stride=stride, padding=padding,
-                                    bias=False, indice_key=indice_key),
-                build_norm_layer(norm_cfg, out_channels)[1],
-                nn.ReLU(inplace=True))
-        elif conv_type == 'inverseconv':
-            m = spconv.SparseSequential(
-                spconv.SparseInverseConv3d(in_channels, out_channels,
-                                           kernel_size, bias=False,
-                                           indice_key=indice_key),
-                build_norm_layer(norm_cfg, out_channels)[1],
-                nn.ReLU(inplace=True))
-        else:
-            raise NotImplementedError
-        return m

     def forward(self, seg_feats, part_feats):
         # (B * N, out_x, out_y, out_z, 4)
         rcnn_batch_size = part_feats.shape[0]
...

mmdet3d/ops/__init__.py  (+3 -4)

@@ -9,11 +9,10 @@ from .group_points import (GroupAll, QueryAndGroup, group_points,
 from .interpolate import three_interpolate, three_nn
 from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
 from .pointnet_modules import PointFPModule, PointSAModule, PointSAModuleMSG
-from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu,
-                              points_in_boxes_gpu)
+from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch,
+                              points_in_boxes_cpu, points_in_boxes_gpu)
 from .sparse_block import (SparseBasicBlock, SparseBottleneck,
                            make_sparse_convmodule)
-from .vote_module import VoteModule
 from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization

 __all__ = [
...
@@ -26,5 +25,5 @@ __all__ = [
     'make_sparse_convmodule', 'ball_query', 'furthest_point_sample',
     'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation',
     'group_points', 'GroupAll', 'QueryAndGroup', 'PointSAModule',
-    'PointSAModuleMSG', 'PointFPModule', 'VoteModule'
+    'PointSAModuleMSG', 'PointFPModule', 'points_in_boxes_batch'
 ]

mmdet3d/ops/ball_query/src/ball_query.cpp  (+27 -20)

 #include <torch/serialize/tensor.h>
 #include <vector>
 #include <THC/THC.h>
 #include <cuda.h>
 #include <cuda_runtime_api.h>
+#include <torch/extension.h>

 extern THCState *state;

-#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
-#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
-#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+#define CHECK_CUDA(x) \
+  TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)

 int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
                        at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
                        at::Tensor idx_tensor);

 void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
                                 const float *xyz, const float *new_xyz,
                                 int *idx, cudaStream_t stream);

 int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
                        at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
                        at::Tensor idx_tensor) {
   CHECK_INPUT(new_xyz_tensor);
   CHECK_INPUT(xyz_tensor);
-  const float *new_xyz = new_xyz_tensor.data<float>();
-  const float *xyz = xyz_tensor.data<float>();
-  int *idx = idx_tensor.data<int>();
+  const float *new_xyz = new_xyz_tensor.data_ptr<float>();
+  const float *xyz = xyz_tensor.data_ptr<float>();
+  int *idx = idx_tensor.data_ptr<int>();

   cudaStream_t stream = THCState_getCurrentStream(state);
   ball_query_kernel_launcher(b, n, m, radius, nsample, new_xyz, xyz, idx,
                              stream);
   return 1;
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("ball_query_wrapper", &ball_query_wrapper, "ball_query_wrapper");
 }

mmdet3d/ops/ball_query/src/ball_query_cuda.cu  (+55 -50)

@@ -3,65 +3,70 @@  (the old and new versions of this hunk differ only in code style and whitespace; the reformatted version follows)

#include <stdlib.h>

#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void ball_query_kernel(int b, int n, int m, float radius,
                                  int nsample,
                                  const float *__restrict__ new_xyz,
                                  const float *__restrict__ xyz,
                                  int *__restrict__ idx) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)
  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= m) return;

  new_xyz += bs_idx * m * 3 + pt_idx * 3;
  xyz += bs_idx * n * 3;
  idx += bs_idx * m * nsample + pt_idx * nsample;

  float radius2 = radius * radius;
  float new_x = new_xyz[0];
  float new_y = new_xyz[1];
  float new_z = new_xyz[2];

  int cnt = 0;
  for (int k = 0; k < n; ++k) {
    float x = xyz[k * 3 + 0];
    float y = xyz[k * 3 + 1];
    float z = xyz[k * 3 + 2];
    float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
               (new_z - z) * (new_z - z);
    if (d2 < radius2) {
      if (cnt == 0) {
        for (int l = 0; l < nsample; ++l) {
          idx[l] = k;
        }
      }
      idx[cnt] = k;
      ++cnt;
      if (cnt >= nsample) break;
    }
  }
}

void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
                                const float *new_xyz, const float *xyz,
                                int *idx, cudaStream_t stream) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)
  cudaError_t err;

  dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample,
                                                    new_xyz, xyz, idx);
  // cudaDeviceSynchronize();  // for using printf in kernel function
  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
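For readers who do not want to trace the CUDA indexing, a NumPy reference sketch (illustration only, not part of the commit) of what ball_query_kernel computes for a single batch: for each query point, collect up to nsample indices of points within radius, padding unused slots with the first hit exactly as the loop above does.

```python
import numpy as np

def ball_query_np(new_xyz, xyz, radius, nsample):
    """Reference semantics of the CUDA kernel for one batch.

    new_xyz: (M, 3) query centres; xyz: (N, 3) candidate points.
    Returns idx of shape (M, nsample).
    """
    idx = np.zeros((new_xyz.shape[0], nsample), dtype=np.int32)
    for i, q in enumerate(new_xyz):
        cnt = 0
        for k, p in enumerate(xyz):
            if np.sum((q - p) ** 2) < radius ** 2:
                if cnt == 0:
                    idx[i, :] = k  # pre-fill every slot with the first neighbour
                idx[i, cnt] = k
                cnt += 1
                if cnt >= nsample:
                    break
    return idx
```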

mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp  (+20 -15)

 #include <torch/serialize/tensor.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <vector>
 #include <THC/THC.h>
+#include <torch/extension.h>

 extern THCState *state;

 int furthest_point_sampling_wrapper(int b, int n, int m,
                                     at::Tensor points_tensor,
                                     at::Tensor temp_tensor,
                                     at::Tensor idx_tensor);

 void furthest_point_sampling_kernel_launcher(int b, int n, int m,
                                              const float *dataset, float *temp,
                                              int *idxs, cudaStream_t stream);

 int furthest_point_sampling_wrapper(int b, int n, int m,
                                     at::Tensor points_tensor,
                                     at::Tensor temp_tensor,
                                     at::Tensor idx_tensor) {
-  const float *points = points_tensor.data<float>();
-  float *temp = temp_tensor.data<float>();
-  int *idx = idx_tensor.data<int>();
+  const float *points = points_tensor.data_ptr<float>();
+  float *temp = temp_tensor.data_ptr<float>();
+  int *idx = idx_tensor.data_ptr<int>();

   cudaStream_t stream = THCState_getCurrentStream(state);
   furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
   return 1;
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper,
         "furthest_point_sampling_wrapper");
 }

mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cuda.cu  (+155 -130)

@@ -3,179 +3,204 @@  (the old and new versions of this hunk differ only in code style and whitespace; the reformatted version follows)

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

inline int opt_n_threads(int work_size) {
  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

  return max(min(1 << pow_2, TOTAL_THREADS), 1);
}

__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
                         int idx1, int idx2) {
  const float v1 = dists[idx1], v2 = dists[idx2];
  const int i1 = dists_i[idx1], i2 = dists_i[idx2];
  dists[idx1] = max(v1, v2);
  dists_i[idx1] = v2 > v1 ? i2 : i1;
}

template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  // dataset: (B, N, 3)
  // tmp: (B, N)
  // output:
  //      idx: (B, M)
  if (m <= 0) return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  int batch_index = blockIdx.x;
  dataset += batch_index * n * 3;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  int old = 0;
  if (threadIdx.x == 0) idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    float x1 = dataset[old * 3 + 0];
    float y1 = dataset[old * 3 + 1];
    float z1 = dataset[old * 3 + 2];
    for (int k = tid; k < n; k += stride) {
      float x2, y2, z2;
      x2 = dataset[k * 3 + 0];
      y2 = dataset[k * 3 + 1];
      z2 = dataset[k * 3 + 2];
      // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
      // if (mag <= 1e-3)
      // continue;

      float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
                (z2 - z1) * (z2 - z1);

      float d2 = min(d, temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    if (block_size >= 1024) {
      if (tid < 512) {
        __update(dists, dists_i, tid, tid + 512);
      }
      __syncthreads();
    }

    if (block_size >= 512) {
      if (tid < 256) {
        __update(dists, dists_i, tid, tid + 256);
      }
      __syncthreads();
    }
    if (block_size >= 256) {
      if (tid < 128) {
        __update(dists, dists_i, tid, tid + 128);
      }
      __syncthreads();
    }
    if (block_size >= 128) {
      if (tid < 64) {
        __update(dists, dists_i, tid, tid + 64);
      }
      __syncthreads();
    }
    if (block_size >= 64) {
      if (tid < 32) {
        __update(dists, dists_i, tid, tid + 32);
      }
      __syncthreads();
    }
    if (block_size >= 32) {
      if (tid < 16) {
        __update(dists, dists_i, tid, tid + 16);
      }
      __syncthreads();
    }
    if (block_size >= 16) {
      if (tid < 8) {
        __update(dists, dists_i, tid, tid + 8);
      }
      __syncthreads();
    }
    if (block_size >= 8) {
      if (tid < 4) {
        __update(dists, dists_i, tid, tid + 4);
      }
      __syncthreads();
    }
    if (block_size >= 4) {
      if (tid < 2) {
        __update(dists, dists_i, tid, tid + 2);
      }
      __syncthreads();
    }
    if (block_size >= 2) {
      if (tid < 1) {
        __update(dists, dists_i, tid, tid + 1);
      }
      __syncthreads();
    }

    old = dists_i[0];
    if (tid == 0) idxs[j] = old;
  }
}

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
                                             const float *dataset, float *temp,
                                             int *idxs, cudaStream_t stream) {
  // dataset: (B, N, 3)
  // tmp: (B, N)
  // output:
  //      idx: (B, M)
  cudaError_t err;
  unsigned int n_threads = opt_n_threads(n);

  switch (n_threads) {
    case 1024:
      furthest_point_sampling_kernel<1024>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 512:
      furthest_point_sampling_kernel<512>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 256:
      furthest_point_sampling_kernel<256>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 128:
      furthest_point_sampling_kernel<128>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 64:
      furthest_point_sampling_kernel<64>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 32:
      furthest_point_sampling_kernel<32>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 16:
      furthest_point_sampling_kernel<16>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 8:
      furthest_point_sampling_kernel<8>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 4:
      furthest_point_sampling_kernel<4>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 2:
      furthest_point_sampling_kernel<2>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    case 1:
      furthest_point_sampling_kernel<1>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
      break;
    default:
      furthest_point_sampling_kernel<512>
          <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
  }

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
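The kernel above parallelizes a simple greedy rule across one block per batch element. A NumPy reference sketch (illustration only, not part of the commit) of that rule for a single point cloud: repeatedly pick the point whose distance to the already-selected set is largest, starting from index 0.

```python
import numpy as np

def furthest_point_sample_np(points, m):
    """Greedy farthest point sampling over an (N, 3) array; returns m indices."""
    n = points.shape[0]
    temp = np.full(n, np.inf)          # running distance of each point to the selected set
    idxs = np.zeros(m, dtype=np.int32)
    old = 0                            # kernel also seeds the sample with index 0
    for j in range(1, m):
        d = np.sum((points - points[old]) ** 2, axis=1)
        temp = np.minimum(temp, d)     # update distance-to-set with the newest sample
        old = int(np.argmax(temp))     # next sample: the point farthest from the set
        idxs[j] = old
    return idxs
```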

mmdet3d/ops/gather_points/src/gather_points.cpp  (+37 -27)

 #include <torch/serialize/tensor.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <vector>
 #include <THC/THC.h>
+#include <torch/extension.h>

 extern THCState *state;

 int gather_points_wrapper(int b, int c, int n, int npoints,
                           at::Tensor points_tensor, at::Tensor idx_tensor,
                           at::Tensor out_tensor);

 void gather_points_kernel_launcher(int b, int c, int n, int npoints,
                                    const float *points, const int *idx,
                                    float *out, cudaStream_t stream);

 int gather_points_grad_wrapper(int b, int c, int n, int npoints,
                                at::Tensor grad_out_tensor,
                                at::Tensor idx_tensor,
                                at::Tensor grad_points_tensor);

 void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
                                         const float *grad_out, const int *idx,
                                         float *grad_points,
                                         cudaStream_t stream);

 int gather_points_wrapper(int b, int c, int n, int npoints,
                           at::Tensor points_tensor, at::Tensor idx_tensor,
                           at::Tensor out_tensor) {
-  const float *points = points_tensor.data<float>();
-  const int *idx = idx_tensor.data<int>();
-  float *out = out_tensor.data<float>();
+  const float *points = points_tensor.data_ptr<float>();
+  const int *idx = idx_tensor.data_ptr<int>();
+  float *out = out_tensor.data_ptr<float>();

   cudaStream_t stream = THCState_getCurrentStream(state);
   gather_points_kernel_launcher(b, c, n, npoints, points, idx, out, stream);
   return 1;
 }

 int gather_points_grad_wrapper(int b, int c, int n, int npoints,
                                at::Tensor grad_out_tensor,
                                at::Tensor idx_tensor,
                                at::Tensor grad_points_tensor) {
-  const float *grad_out = grad_out_tensor.data<float>();
-  const int *idx = idx_tensor.data<int>();
-  float *grad_points = grad_points_tensor.data<float>();
+  const float *grad_out = grad_out_tensor.data_ptr<float>();
+  const int *idx = idx_tensor.data_ptr<int>();
+  float *grad_points = grad_points_tensor.data_ptr<float>();

   cudaStream_t stream = THCState_getCurrentStream(state);
   gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out, idx,
                                      grad_points, stream);
   return 1;
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gather_points_wrapper", &gather_points_wrapper,
         "gather_points_wrapper");
   m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper,
         "gather_points_grad_wrapper");
 }
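The wrappers above only shuffle pointers to the CUDA launchers; the forward gather they dispatch is, in the standard PointNet++ formulation this op follows (assumed here, since the kernel body is not shown in this page), equivalent to the following NumPy indexing, out[b, c, i] = points[b, c, idx[b, i]]:

```python
import numpy as np

def gather_points_np(points, idx):
    """Reference semantics: points (B, C, N), idx (B, npoints) -> (B, C, npoints)."""
    b, c, _ = points.shape
    npoints = idx.shape[1]
    out = np.empty((b, c, npoints), dtype=points.dtype)
    for bi in range(b):
        out[bi] = points[bi][:, idx[bi]]  # pick the sampled columns per batch
    return out
```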