OpenDAS / mmdetection3d — Commit 333536f6 (unverified)

Authored Apr 06, 2022 by Wenwei Zhang; committed via GitHub on Apr 06, 2022.

Release v1.0.0rc1

Parents: 9c7270d0, f747daab
Changes: 219 files in total; this page shows 20 changed files with 54 additions and 968 deletions (+54, -968).
mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py               +16   -17
mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py            +2    -1
mmdet3d/models/roi_heads/mask_heads/primitive_head.py                  +2    -1
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py   +1    -1
mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py   +1    -1
mmdet3d/models/voxel_encoders/__init__.py                              +3    -3
mmdet3d/models/voxel_encoders/pillar_encoder.py                        +7    -3
mmdet3d/models/voxel_encoders/voxel_encoder.py                         +1    -1
mmdet3d/ops/__init__.py                                               +20   -16
mmdet3d/ops/ball_query/__init__.py                                     +0    -4
mmdet3d/ops/ball_query/ball_query.py                                   +0   -48
mmdet3d/ops/ball_query/src/ball_query.cpp                              +0   -47
mmdet3d/ops/ball_query/src/ball_query_cuda.cu                          +0   -78
mmdet3d/ops/dgcnn_modules/dgcnn_gf_module.py                           +1    -2
mmdet3d/ops/furthest_point_sample/__init__.py                          +0    -9
mmdet3d/ops/furthest_point_sample/furthest_point_sample.py             +0   -79
mmdet3d/ops/furthest_point_sample/points_sampler.py                    +0  -160
mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp        +0   -65
mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cuda.cu    +0  -400
mmdet3d/ops/furthest_point_sample/utils.py                             +0   -32
mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py

@@ -2,6 +2,9 @@
 import numpy as np
 import torch
 from mmcv.cnn import ConvModule, normal_init
+from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential
+from mmcv.ops import nms_bev as nms_gpu
+from mmcv.ops import nms_normal_bev as nms_normal_gpu
 from mmcv.runner import BaseModule
 from torch import nn as nn

@@ -9,8 +12,6 @@ from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
                                           rotation_3d_in_axis, xywhr2xyxyr)
 from mmdet3d.models.builder import build_loss
 from mmdet3d.ops import make_sparse_convmodule
-from mmdet3d.ops import spconv as spconv
-from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
 from mmdet.core import build_bbox_coder, multi_apply
 from mmdet.models import HEADS

@@ -95,7 +96,7 @@ class PartA2BboxHead(BaseModule):
                     indice_key=f'rcnn_part{i}',
                     conv_type='SubMConv3d'))
             part_channel_last = channel
-        self.part_conv = spconv.SparseSequential(*part_conv)
+        self.part_conv = SparseSequential(*part_conv)
         seg_channel_last = seg_in_channels
         seg_conv = []

@@ -110,9 +111,9 @@ class PartA2BboxHead(BaseModule):
                     indice_key=f'rcnn_seg{i}',
                     conv_type='SubMConv3d'))
             seg_channel_last = channel
-        self.seg_conv = spconv.SparseSequential(*seg_conv)
+        self.seg_conv = SparseSequential(*seg_conv)
-        self.conv_down = spconv.SparseSequential()
+        self.conv_down = SparseSequential()
         merge_conv_channel_last = part_channel_last + seg_channel_last
         merge_conv = []

@@ -140,12 +141,10 @@ class PartA2BboxHead(BaseModule):
                     indice_key='rcnn_down1'))
             down_conv_channel_last = channel
-        self.conv_down.add_module('merge_conv',
-                                  spconv.SparseSequential(*merge_conv))
-        self.conv_down.add_module(
-            'max_pool3d', spconv.SparseMaxPool3d(kernel_size=2, stride=2))
-        self.conv_down.add_module('down_conv',
-                                  spconv.SparseSequential(*conv_down))
+        self.conv_down.add_module('merge_conv', SparseSequential(*merge_conv))
+        self.conv_down.add_module('max_pool3d',
+                                  SparseMaxPool3d(kernel_size=2, stride=2))
+        self.conv_down.add_module('down_conv', SparseSequential(*conv_down))
         shared_fc_list = []
         pool_size = roi_feat_size // 2

@@ -256,10 +255,10 @@ class PartA2BboxHead(BaseModule):
         seg_features = seg_feats[sparse_idx[:, 0], sparse_idx[:, 1],
                                  sparse_idx[:, 2], sparse_idx[:, 3]]
         coords = sparse_idx.int()
-        part_features = spconv.SparseConvTensor(part_features, coords,
-                                                sparse_shape, rcnn_batch_size)
-        seg_features = spconv.SparseConvTensor(seg_features, coords,
-                                               sparse_shape, rcnn_batch_size)
+        part_features = SparseConvTensor(part_features, coords, sparse_shape,
+                                         rcnn_batch_size)
+        seg_features = SparseConvTensor(seg_features, coords, sparse_shape,
+                                        rcnn_batch_size)
         # forward rcnn network
         x_part = self.part_conv(part_features)

@@ -267,8 +266,8 @@ class PartA2BboxHead(BaseModule):
         merged_feature = torch.cat((x_rpn.features, x_part.features),
                                    dim=1)  # (N, C)
-        shared_feature = spconv.SparseConvTensor(merged_feature, coords,
-                                                 sparse_shape, rcnn_batch_size)
+        shared_feature = SparseConvTensor(merged_feature, coords, sparse_shape,
+                                          rcnn_batch_size)
         x = self.conv_down(shared_feature)
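Editor's note: the pattern in this file is mechanical — sparse-convolution primitives that used to be reached through mmdet3d.ops.spconv are now taken directly from mmcv.ops, and the BEV NMS helpers move from mmdet3d.ops.iou3d to mmcv.ops. A minimal sketch of the new import surface follows; it assumes a CUDA device and an mmcv-full build with sparse ops compiled, and the shapes/values are made up purely for illustration.

# Hedged sketch: building a sparse tensor with the mmcv.ops classes this commit
# switches to (previously accessed via `from mmdet3d.ops import spconv`).
import torch
from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential

features = torch.randn(100, 16).cuda()                 # (N, C) features of non-empty voxels
batch_idx = torch.randint(0, 2, (100, 1))               # batch index per voxel
zyx = torch.randint(0, 10, (100, 3))                    # voxel grid coordinates
coords = torch.cat([batch_idx, zyx], dim=1).int().cuda()  # (N, 4): batch_idx, z, y, x

x = SparseConvTensor(features, coords, spatial_shape=[10, 10, 10], batch_size=2)

# SparseSequential / SparseMaxPool3d compose the same way their spconv namesakes did.
pool = SparseSequential(SparseMaxPool3d(kernel_size=2, stride=2))
out = pool(x)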
mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py

@@ -3,6 +3,8 @@ import numpy as np
 import torch
 from mmcv.cnn import ConvModule, normal_init
 from mmcv.cnn.bricks import build_conv_layer
+from mmcv.ops import nms_bev as nms_gpu
+from mmcv.ops import nms_normal_bev as nms_normal_gpu
 from mmcv.runner import BaseModule
 from torch import nn as nn

@@ -10,7 +12,6 @@ from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
                                           rotation_3d_in_axis, xywhr2xyxyr)
 from mmdet3d.models.builder import build_loss
 from mmdet3d.ops import build_sa_module
-from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
 from mmdet.core import build_bbox_coder, multi_apply
 from mmdet.models import HEADS
mmdet3d/models/roi_heads/mask_heads/primitive_head.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 from mmcv.cnn import ConvModule
+from mmcv.ops import furthest_point_sample
 from mmcv.runner import BaseModule
 from torch import nn as nn
 from torch.nn import functional as F

 from mmdet3d.models.builder import build_loss
 from mmdet3d.models.model_utils import VoteModule
-from mmdet3d.ops import build_sa_module, furthest_point_sample
+from mmdet3d.ops import build_sa_module
 from mmdet.core import multi_apply
 from mmdet.models import HEADS
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmcv import ops
 from mmcv.runner import BaseModule

-from mmdet3d import ops
 from mmdet.models.builder import ROI_EXTRACTORS
mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmcv import ops
 from torch import nn as nn

-from mmdet3d import ops
 from mmdet3d.core.bbox.structures import rotation_3d_in_axis
 from mmdet.models.builder import ROI_EXTRACTORS
mmdet3d/models/voxel_encoders/__init__.py

 # Copyright (c) OpenMMLab. All rights reserved.
-from .pillar_encoder import PillarFeatureNet
+from .pillar_encoder import DynamicPillarFeatureNet, PillarFeatureNet
 from .voxel_encoder import DynamicSimpleVFE, DynamicVFE, HardSimpleVFE, HardVFE

 __all__ = [
-    'PillarFeatureNet', 'HardVFE', 'DynamicVFE', 'HardSimpleVFE',
-    'DynamicSimpleVFE'
+    'PillarFeatureNet', 'DynamicPillarFeatureNet', 'HardVFE', 'DynamicVFE',
+    'HardSimpleVFE', 'DynamicSimpleVFE'
 ]
mmdet3d/models/voxel_encoders/pillar_encoder.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 from mmcv.cnn import build_norm_layer
+from mmcv.ops import DynamicScatter
 from mmcv.runner import force_fp32
 from torch import nn

-from mmdet3d.ops import DynamicScatter
 from ..builder import VOXEL_ENCODERS
 from .utils import PFNLayer, get_paddings_indicator

@@ -15,6 +15,7 @@ class PillarFeatureNet(nn.Module):
     The network prepares the pillar features and performs forward pass
     through PFNLayers.
+
     Args:
         in_channels (int, optional): Number of input features,
             either x, y, z or x, y, z, r. Defaults to 4.

@@ -98,6 +99,7 @@ class PillarFeatureNet(nn.Module):
                 (N, M, C).
             num_points (torch.Tensor): Number of points in each pillar.
             coors (torch.Tensor): Coordinates of each voxel.
+
         Returns:
             torch.Tensor: Features of pillars.
         """

@@ -237,7 +239,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
         Args:
             pts_coors (torch.Tensor): The coordinates of each points, shape
                 (M, 3), where M is the number of points.
-            voxel_mean (torch.Tensor): The mean or aggreagated features of a
+            voxel_mean (torch.Tensor): The mean or aggregated features of a
                 voxel, shape (N, C), where N is the number of voxels.
             voxel_coors (torch.Tensor): The coordinates of each voxel.

@@ -294,11 +296,13 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
         # Find distance of x, y, and z from pillar center
         if self._with_voxel_center:
-            f_center = features.new_zeros(size=(features.size(0), 2))
+            f_center = features.new_zeros(size=(features.size(0), 3))
             f_center[:, 0] = features[:, 0] - (
                 coors[:, 3].type_as(features) * self.vx + self.x_offset)
             f_center[:, 1] = features[:, 1] - (
                 coors[:, 2].type_as(features) * self.vy + self.y_offset)
+            f_center[:, 2] = features[:, 2] - (
+                coors[:, 1].type_as(features) * self.vz + self.z_offset)
             features_ls.append(f_center)
         if self._with_distance:
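Editor's note: the last hunk widens the pillar-center offset from two channels (x, y) to three (x, y, z). A standalone sketch of that arithmetic is given below; the voxel sizes and offsets are hypothetical values chosen only to make the snippet runnable, not the configuration used by the repository.

# Hedged sketch of the pillar-center offset computed in the patched code path.
# vx, vy, vz are voxel sizes; x_offset = vx / 2 + point_cloud_range[0], etc.
import torch

vx, vy, vz = 0.16, 0.16, 4.0                      # illustrative voxel sizes
x_offset, y_offset, z_offset = 0.08, -39.6, -1.0  # illustrative center offsets

features = torch.randn(8, 4)                      # (N, 4): x, y, z, reflectance per point
coors = torch.randint(0, 10, (8, 4))              # (N, 4): batch_idx, z, y, x voxel indices

f_center = features.new_zeros(size=(features.size(0), 3))   # was 2 columns before this commit
f_center[:, 0] = features[:, 0] - (coors[:, 3].type_as(features) * vx + x_offset)
f_center[:, 1] = features[:, 1] - (coors[:, 2].type_as(features) * vy + y_offset)
f_center[:, 2] = features[:, 2] - (coors[:, 1].type_as(features) * vz + z_offset)  # new z term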
mmdet3d/models/voxel_encoders/voxel_encoder.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 from mmcv.cnn import build_norm_layer
+from mmcv.ops import DynamicScatter
 from mmcv.runner import force_fp32
 from torch import nn

-from mmdet3d.ops import DynamicScatter
 from .. import builder
 from ..builder import VOXEL_ENCODERS
 from .utils import VFELayer, get_paddings_indicator
mmdet3d/ops/__init__.py

@@ -2,28 +2,32 @@
 from mmcv.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
                       get_compiling_cuda_version, nms, roi_align,
                       sigmoid_focal_loss)
+from mmcv.ops.assign_score_withk import assign_score_withk
+from mmcv.ops.ball_query import ball_query
+from mmcv.ops.furthest_point_sample import (furthest_point_sample,
+                                            furthest_point_sample_with_dist)
+from mmcv.ops.gather_points import gather_points
+from mmcv.ops.group_points import GroupAll, QueryAndGroup, grouping_operation
+from mmcv.ops.knn import knn
+from mmcv.ops.points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+                                      points_in_boxes_part)
+from mmcv.ops.points_sampler import PointsSampler as Points_Sampler
+from mmcv.ops.roiaware_pool3d import RoIAwarePool3d
+from mmcv.ops.roipoint_pool3d import RoIPointPool3d
+from mmcv.ops.scatter_points import DynamicScatter, dynamic_scatter
+from mmcv.ops.three_interpolate import three_interpolate
+from mmcv.ops.three_nn import three_nn
+from mmcv.ops.voxelize import Voxelization, voxelization
-from .ball_query import ball_query
 from .dgcnn_modules import DGCNNFAModule, DGCNNFPModule, DGCNNGFModule
-from .furthest_point_sample import (Points_Sampler, furthest_point_sample,
-                                    furthest_point_sample_with_dist)
-from .gather_points import gather_points
-from .group_points import (GroupAll, QueryAndGroup, group_points,
-                           grouping_operation)
-from .interpolate import three_interpolate, three_nn
-from .knn import knn
 from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
-from .paconv import PAConv, PAConvCUDA, assign_score_withk
+from .paconv import PAConv, PAConvCUDA
 from .pointnet_modules import (PAConvCUDASAModule, PAConvCUDASAModuleMSG,
                                PAConvSAModule, PAConvSAModuleMSG, PointFPModule,
                                PointSAModule, PointSAModuleMSG, build_sa_module)
-from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_all,
-                              points_in_boxes_cpu, points_in_boxes_part)
-from .roipoint_pool3d import RoIPointPool3d
 from .sparse_block import (SparseBasicBlock, SparseBottleneck,
                            make_sparse_convmodule)
-from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization

 __all__ = [
     'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'get_compiler_version',

@@ -34,9 +38,9 @@ __all__ = [
     'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu',
     'make_sparse_convmodule', 'ball_query', 'knn', 'furthest_point_sample',
     'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn',
-    'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
-    'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
-    'DGCNNFPModule', 'DGCNNGFModule', 'DGCNNFAModule', 'points_in_boxes_all',
+    'gather_points', 'grouping_operation', 'GroupAll', 'QueryAndGroup',
+    'PointSAModule', 'PointSAModuleMSG', 'PointFPModule', 'DGCNNFPModule',
+    'DGCNNGFModule', 'DGCNNFAModule', 'points_in_boxes_all',
     'get_compiler_version', 'assign_score_withk', 'get_compiling_cuda_version',
     'Points_Sampler', 'build_sa_module', 'PAConv', 'PAConvCUDA',
     'PAConvSAModuleMSG', 'PAConvSAModule', 'PAConvCUDASAModule',
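Editor's note: the net effect of this __init__.py change is that the public names remain importable from mmdet3d.ops, but most of them now resolve to the CUDA ops compiled into mmcv-full instead of the local extensions deleted below. A hedged sketch of downstream usage that should be unaffected (CUDA assumed, shapes illustrative):

# These names are still re-exported by mmdet3d.ops after this commit; they now
# come from mmcv.ops rather than the removed local extension modules.
import torch
from mmdet3d.ops import furthest_point_sample, knn  # re-exports of mmcv.ops

xyz = torch.randn(2, 1024, 3).cuda()        # (B, N, 3) point coordinates
centers_idx = furthest_point_sample(xyz, 128)  # (2, 128) sampled point indices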
mmdet3d/ops/ball_query/__init__.py (deleted, 100644 → 0)

# Copyright (c) OpenMMLab. All rights reserved.
from .ball_query import ball_query

__all__ = ['ball_query']
mmdet3d/ops/ball_query/ball_query.py (deleted, 100644 → 0)

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function

from . import ball_query_ext


class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space.
    """

    @staticmethod
    def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
                xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
        """forward.

        Args:
            min_radius (float): minimum radius of the balls.
            max_radius (float): maximum radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) centers of the ball query.

        Returns:
            Tensor: (B, npoint, nsample) tensor with the indices of
                the features that form the query balls.
        """
        assert center_xyz.is_contiguous()
        assert xyz.is_contiguous()
        assert min_radius < max_radius

        B, N, _ = xyz.size()
        npoint = center_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, sample_num).zero_()

        ball_query_ext.ball_query_wrapper(B, N, npoint, min_radius, max_radius,
                                          sample_num, center_xyz, xyz, idx)
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


ball_query = BallQuery.apply
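Editor's note: the deleted BallQuery Function is superseded by mmcv.ops.ball_query, which the new mmdet3d/ops/__init__.py imports. Assuming the mmcv version keeps the argument order documented in the removed forward() above, a minimal sketch (CUDA assumed, shapes illustrative):

# Hedged sketch: calling the mmcv replacement for the deleted local op.
# Argument order follows the removed docstring:
# (min_radius, max_radius, sample_num, xyz, center_xyz).
import torch
from mmcv.ops import ball_query

xyz = torch.randn(2, 1024, 3).cuda()         # (B, N, 3) all points
centers = torch.randn(2, 128, 3).cuda()      # (B, npoint, 3) query centers
idx = ball_query(0.0, 0.8, 16, xyz, centers)  # (B, 128, 16) indices into xyz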
mmdet3d/ops/ball_query/src/ball_query.cpp (deleted, 100644 → 0)

// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp

#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

extern THCState *state;

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

int ball_query_wrapper(int b, int n, int m, float min_radius, float max_radius,
                       int nsample, at::Tensor new_xyz_tensor,
                       at::Tensor xyz_tensor, at::Tensor idx_tensor);

void ball_query_kernel_launcher(int b, int n, int m, float min_radius,
                                float max_radius, int nsample,
                                const float *xyz, const float *new_xyz,
                                int *idx, cudaStream_t stream);

int ball_query_wrapper(int b, int n, int m, float min_radius, float max_radius,
                       int nsample, at::Tensor new_xyz_tensor,
                       at::Tensor xyz_tensor, at::Tensor idx_tensor) {
  CHECK_INPUT(new_xyz_tensor);
  CHECK_INPUT(xyz_tensor);

  const float *new_xyz = new_xyz_tensor.data_ptr<float>();
  const float *xyz = xyz_tensor.data_ptr<float>();
  int *idx = idx_tensor.data_ptr<int>();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  ball_query_kernel_launcher(b, n, m, min_radius, max_radius, nsample, new_xyz,
                             xyz, idx, stream);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ball_query_wrapper", &ball_query_wrapper, "ball_query_wrapper");
}
mmdet3d/ops/ball_query/src/ball_query_cuda.cu (deleted, 100644 → 0)

// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__global__ void ball_query_kernel(int b, int n, int m, float min_radius,
                                  float max_radius, int nsample,
                                  const float *__restrict__ new_xyz,
                                  const float *__restrict__ xyz,
                                  int *__restrict__ idx) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)
  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= m) return;

  new_xyz += bs_idx * m * 3 + pt_idx * 3;
  xyz += bs_idx * n * 3;
  idx += bs_idx * m * nsample + pt_idx * nsample;

  float max_radius2 = max_radius * max_radius;
  float min_radius2 = min_radius * min_radius;
  float new_x = new_xyz[0];
  float new_y = new_xyz[1];
  float new_z = new_xyz[2];

  int cnt = 0;
  for (int k = 0; k < n; ++k) {
    float x = xyz[k * 3 + 0];
    float y = xyz[k * 3 + 1];
    float z = xyz[k * 3 + 2];
    float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
               (new_z - z) * (new_z - z);
    if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) {
      if (cnt == 0) {
        for (int l = 0; l < nsample; ++l) {
          idx[l] = k;
        }
      }
      idx[cnt] = k;
      ++cnt;
      if (cnt >= nsample) break;
    }
  }
}

void ball_query_kernel_launcher(int b, int n, int m, float min_radius,
                                float max_radius, int nsample,
                                const float *new_xyz, const float *xyz,
                                int *idx, cudaStream_t stream) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)
  cudaError_t err;

  dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  ball_query_kernel<<<blocks, threads, 0, stream>>>(
      b, n, m, min_radius, max_radius, nsample, new_xyz, xyz, idx);
  // cudaDeviceSynchronize();  // for using printf in kernel function
  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
mmdet3d/ops/dgcnn_modules/dgcnn_gf_module.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 from mmcv.cnn import ConvModule
+from mmcv.ops.group_points import GroupAll, QueryAndGroup, grouping_operation
 from torch import nn as nn
 from torch.nn import functional as F

-from ..group_points import GroupAll, QueryAndGroup, grouping_operation

 class BaseDGCNNGFModule(nn.Module):
     """Base module for point graph feature module used in DGCNN.
mmdet3d/ops/furthest_point_sample/__init__.py (deleted, 100644 → 0)

# Copyright (c) OpenMMLab. All rights reserved.
from .furthest_point_sample import (furthest_point_sample,
                                    furthest_point_sample_with_dist)
from .points_sampler import Points_Sampler

__all__ = [
    'furthest_point_sample', 'furthest_point_sample_with_dist',
    'Points_Sampler'
]
mmdet3d/ops/furthest_point_sample/furthest_point_sample.py (deleted, 100644 → 0)

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function

from . import furthest_point_sample_ext


class FurthestPointSampling(Function):
    """Furthest Point Sampling.

    Uses iterative furthest point sampling to select a set of features whose
    corresponding points have the furthest distance.
    """

    @staticmethod
    def forward(ctx, points_xyz: torch.Tensor,
                num_points: int) -> torch.Tensor:
        """forward.

        Args:
            points_xyz (Tensor): (B, N, 3) where N > num_points.
            num_points (int): Number of points in the sampled set.

        Returns:
            Tensor: (B, num_points) indices of the sampled points.
        """
        assert points_xyz.is_contiguous()

        B, N = points_xyz.size()[:2]
        output = torch.cuda.IntTensor(B, num_points)
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        furthest_point_sample_ext.furthest_point_sampling_wrapper(
            B, N, num_points, points_xyz, temp, output)
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        return None, None


class FurthestPointSamplingWithDist(Function):
    """Furthest Point Sampling With Distance.

    Uses iterative furthest point sampling to select a set of features whose
    corresponding points have the furthest distance.
    """

    @staticmethod
    def forward(ctx, points_dist: torch.Tensor,
                num_points: int) -> torch.Tensor:
        """forward.

        Args:
            points_dist (Tensor): (B, N, N) Distance between each point pair.
            num_points (int): Number of points in the sampled set.

        Returns:
            Tensor: (B, num_points) indices of the sampled points.
        """
        assert points_dist.is_contiguous()

        B, N, _ = points_dist.size()
        output = points_dist.new_zeros([B, num_points], dtype=torch.int32)
        temp = points_dist.new_zeros([B, N]).fill_(1e10)

        furthest_point_sample_ext.furthest_point_sampling_with_dist_wrapper(
            B, N, num_points, points_dist, temp, output)
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        return None, None


furthest_point_sample = FurthestPointSampling.apply
furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply
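Editor's note: both Functions removed here exist under the same names in mmcv (furthest_point_sample, furthest_point_sample_with_dist), which is where the new mmdet3d/ops/__init__.py imports them from. A short sketch, assuming the mmcv versions keep the (input, num_points) calling convention documented above (CUDA assumed, shapes illustrative):

# Hedged sketch of the mmcv replacements for the deleted sampling Functions.
import torch
from mmcv.ops import furthest_point_sample, furthest_point_sample_with_dist

points = torch.randn(2, 1024, 3).cuda()             # (B, N, 3) coordinates
idx = furthest_point_sample(points, 256)            # (B, 256) indices, coordinate-based FPS

dist = torch.cdist(points, points)                  # (B, N, N) pairwise distances
idx_f = furthest_point_sample_with_dist(dist, 256)  # (B, 256) indices, distance-matrix FPS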
mmdet3d/ops/furthest_point_sample/points_sampler.py (deleted, 100644 → 0)

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from mmcv.runner import force_fp32
from torch import nn as nn

from .furthest_point_sample import (furthest_point_sample,
                                    furthest_point_sample_with_dist)
from .utils import calc_square_dist


def get_sampler_type(sampler_type):
    """Get the type and mode of points sampler.

    Args:
        sampler_type (str): The type of points sampler.
            The valid value are "D-FPS", "F-FPS", or "FS".

    Returns:
        class: Points sampler type.
    """
    if sampler_type == 'D-FPS':
        sampler = DFPS_Sampler
    elif sampler_type == 'F-FPS':
        sampler = FFPS_Sampler
    elif sampler_type == 'FS':
        sampler = FS_Sampler
    else:
        raise ValueError('Only "sampler_type" of "D-FPS", "F-FPS", or "FS"'
                         f' are supported, got {sampler_type}')

    return sampler


class Points_Sampler(nn.Module):
    """Points sampling.

    Args:
        num_point (list[int]): Number of sample points.
        fps_mod_list (list[str], optional): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int], optional):
            Range of points to apply FPS. Default: [-1].
    """

    def __init__(self,
                 num_point: List[int],
                 fps_mod_list: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1]):
        super(Points_Sampler, self).__init__()
        # FPS would be applied to different fps_mod in the list,
        # so the length of the num_point should be equal to
        # fps_mod_list and fps_sample_range_list.
        assert len(num_point) == len(fps_mod_list) == len(
            fps_sample_range_list)
        self.num_point = num_point
        self.fps_sample_range_list = fps_sample_range_list
        self.samplers = nn.ModuleList()
        for fps_mod in fps_mod_list:
            self.samplers.append(get_sampler_type(fps_mod)())
        self.fp16_enabled = False

    @force_fp32()
    def forward(self, points_xyz, features):
        """forward.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor): (B, C, N) Descriptors of the features.

        Return:
            Tensor: (B, npoint, sample_num) Indices of sampled points.
        """
        indices = []
        last_fps_end_index = 0

        for fps_sample_range, sampler, npoint in zip(
                self.fps_sample_range_list, self.samplers, self.num_point):
            assert fps_sample_range < points_xyz.shape[1]

            if fps_sample_range == -1:
                sample_points_xyz = points_xyz[:, last_fps_end_index:]
                sample_features = features[:, :, last_fps_end_index:] if \
                    features is not None else None
            else:
                sample_points_xyz = \
                    points_xyz[:, last_fps_end_index:fps_sample_range]
                sample_features = \
                    features[:, :, last_fps_end_index:fps_sample_range] if \
                    features is not None else None

            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
                              npoint)

            indices.append(fps_idx + last_fps_end_index)
            last_fps_end_index += fps_sample_range
        indices = torch.cat(indices, dim=1)

        return indices


class DFPS_Sampler(nn.Module):
    """DFPS_Sampling.

    Using Euclidean distances of points for FPS.
    """

    def __init__(self):
        super(DFPS_Sampler, self).__init__()

    def forward(self, points, features, npoint):
        """Sampling points with D-FPS."""
        fps_idx = furthest_point_sample(points.contiguous(), npoint)
        return fps_idx


class FFPS_Sampler(nn.Module):
    """FFPS_Sampler.

    Using feature distances for FPS.
    """

    def __init__(self):
        super(FFPS_Sampler, self).__init__()

    def forward(self, points, features, npoint):
        """Sampling points with F-FPS."""
        assert features is not None, \
            'feature input to FFPS_Sampler should not be None'
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        features_dist = calc_square_dist(
            features_for_fps, features_for_fps, norm=False)
        fps_idx = furthest_point_sample_with_dist(features_dist, npoint)
        return fps_idx


class FS_Sampler(nn.Module):
    """FS_Sampling.

    Using F-FPS and D-FPS simultaneously.
    """

    def __init__(self):
        super(FS_Sampler, self).__init__()

    def forward(self, points, features, npoint):
        """Sampling points with FS_Sampling."""
        assert features is not None, \
            'feature input to FS_Sampler should not be None'
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        features_dist = calc_square_dist(
            features_for_fps, features_for_fps, norm=False)
        fps_idx_ffps = furthest_point_sample_with_dist(features_dist, npoint)
        fps_idx_dfps = furthest_point_sample(points, npoint)
        fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)
        return fps_idx
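Editor's note: the sampler classes above now live in mmcv as PointsSampler, which the new mmdet3d/ops/__init__.py re-exports under the old name Points_Sampler. A sketch of equivalent usage, assuming mmcv keeps the constructor arguments shown in the removed docstring (CUDA assumed, shapes illustrative):

# Hedged sketch: the mmcv PointsSampler that replaces the deleted Points_Sampler.
import torch
from mmcv.ops.points_sampler import PointsSampler

sampler = PointsSampler(num_point=[256],
                        fps_mod_list=['D-FPS'],
                        fps_sample_range_list=[-1]).cuda()

points_xyz = torch.randn(2, 1024, 3).cuda()   # (B, N, 3) coordinates
features = torch.randn(2, 16, 1024).cuda()    # (B, C, N) per-point features
indices = sampler(points_xyz, features)       # (B, 256) sampled indices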
mmdet3d/ops/furthest_point_sample/src/furthest_point_sample.cpp (deleted, 100644 → 0)

// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling.cpp

#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>

extern THCState *state;

int furthest_point_sampling_wrapper(int b, int n, int m,
                                    at::Tensor points_tensor,
                                    at::Tensor temp_tensor,
                                    at::Tensor idx_tensor);

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
                                             const float *dataset, float *temp,
                                             int *idxs, cudaStream_t stream);

int furthest_point_sampling_with_dist_wrapper(int b, int n, int m,
                                              at::Tensor points_tensor,
                                              at::Tensor temp_tensor,
                                              at::Tensor idx_tensor);

void furthest_point_sampling_with_dist_kernel_launcher(int b, int n, int m,
                                                       const float *dataset,
                                                       float *temp, int *idxs,
                                                       cudaStream_t stream);

int furthest_point_sampling_wrapper(int b, int n, int m,
                                    at::Tensor points_tensor,
                                    at::Tensor temp_tensor,
                                    at::Tensor idx_tensor) {
  const float *points = points_tensor.data_ptr<float>();
  float *temp = temp_tensor.data_ptr<float>();
  int *idx = idx_tensor.data_ptr<int>();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
  return 1;
}

int furthest_point_sampling_with_dist_wrapper(int b, int n, int m,
                                              at::Tensor points_tensor,
                                              at::Tensor temp_tensor,
                                              at::Tensor idx_tensor) {
  const float *points = points_tensor.data<float>();
  float *temp = temp_tensor.data<float>();
  int *idx = idx_tensor.data<int>();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  furthest_point_sampling_with_dist_kernel_launcher(b, n, m, points, temp, idx,
                                                    stream);
  return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper,
        "furthest_point_sampling_wrapper");
  m.def("furthest_point_sampling_with_dist_wrapper",
        &furthest_point_sampling_with_dist_wrapper,
        "furthest_point_sampling_with_dist_wrapper");
}
mmdet3d/ops/furthest_point_sample/src/furthest_point_sample_cuda.cu (deleted, 100644 → 0)

// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling_gpu.cu

#include <stdio.h>
#include <stdlib.h>

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

inline int opt_n_threads(int work_size) {
  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
  return max(min(1 << pow_2, TOTAL_THREADS), 1);
}

__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
                         int idx1, int idx2) {
  const float v1 = dists[idx1], v2 = dists[idx2];
  const int i1 = dists_i[idx1], i2 = dists_i[idx2];
  dists[idx1] = max(v1, v2);
  dists_i[idx1] = v2 > v1 ? i2 : i1;
}

template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  // dataset: (B, N, 3)
  // tmp: (B, N)
  // output:
  //      idx: (B, M)
  if (m <= 0) return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  int batch_index = blockIdx.x;
  dataset += batch_index * n * 3;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  int old = 0;
  if (threadIdx.x == 0) idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    float x1 = dataset[old * 3 + 0];
    float y1 = dataset[old * 3 + 1];
    float z1 = dataset[old * 3 + 2];
    for (int k = tid; k < n; k += stride) {
      float x2, y2, z2;
      x2 = dataset[k * 3 + 0];
      y2 = dataset[k * 3 + 1];
      z2 = dataset[k * 3 + 2];
      // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
      // if (mag <= 1e-3)
      //   continue;

      float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
                (z2 - z1) * (z2 - z1);
      float d2 = min(d, temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    if (block_size >= 1024) { if (tid < 512) { __update(dists, dists_i, tid, tid + 512); } __syncthreads(); }
    if (block_size >= 512) { if (tid < 256) { __update(dists, dists_i, tid, tid + 256); } __syncthreads(); }
    if (block_size >= 256) { if (tid < 128) { __update(dists, dists_i, tid, tid + 128); } __syncthreads(); }
    if (block_size >= 128) { if (tid < 64) { __update(dists, dists_i, tid, tid + 64); } __syncthreads(); }
    if (block_size >= 64) { if (tid < 32) { __update(dists, dists_i, tid, tid + 32); } __syncthreads(); }
    if (block_size >= 32) { if (tid < 16) { __update(dists, dists_i, tid, tid + 16); } __syncthreads(); }
    if (block_size >= 16) { if (tid < 8) { __update(dists, dists_i, tid, tid + 8); } __syncthreads(); }
    if (block_size >= 8) { if (tid < 4) { __update(dists, dists_i, tid, tid + 4); } __syncthreads(); }
    if (block_size >= 4) { if (tid < 2) { __update(dists, dists_i, tid, tid + 2); } __syncthreads(); }
    if (block_size >= 2) { if (tid < 1) { __update(dists, dists_i, tid, tid + 1); } __syncthreads(); }

    old = dists_i[0];
    if (tid == 0) idxs[j] = old;
  }
}

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
                                             const float *dataset, float *temp,
                                             int *idxs, cudaStream_t stream) {
  // dataset: (B, N, 3)
  // tmp: (B, N)
  // output:
  //      idx: (B, M)
  cudaError_t err;
  unsigned int n_threads = opt_n_threads(n);

  switch (n_threads) {
    case 1024: furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 512: furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 256: furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 128: furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 64: furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 32: furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 16: furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 8: furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 4: furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 2: furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 1: furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    default: furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
  }

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}

// Modified from
// https://github.com/qiqihaer/3DSSD-pytorch/blob/master/lib/pointnet2/src/sampling_gpu.cu
template <unsigned int block_size>
__global__ void furthest_point_sampling_with_dist_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  // dataset: (B, N, N)
  // tmp: (B, N)
  // output:
  //      idx: (B, M)
  if (m <= 0) return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  int batch_index = blockIdx.x;
  dataset += batch_index * n * n;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  int old = 0;
  if (threadIdx.x == 0) idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    // float x1 = dataset[old * 3 + 0];
    // float y1 = dataset[old * 3 + 1];
    // float z1 = dataset[old * 3 + 2];
    for (int k = tid; k < n; k += stride) {
      // float x2, y2, z2;
      // x2 = dataset[k * 3 + 0];
      // y2 = dataset[k * 3 + 1];
      // z2 = dataset[k * 3 + 2];
      // float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) *
      // (z2 - z1);
      float d = dataset[old * n + k];
      float d2 = min(d, temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    if (block_size >= 1024) { if (tid < 512) { __update(dists, dists_i, tid, tid + 512); } __syncthreads(); }
    if (block_size >= 512) { if (tid < 256) { __update(dists, dists_i, tid, tid + 256); } __syncthreads(); }
    if (block_size >= 256) { if (tid < 128) { __update(dists, dists_i, tid, tid + 128); } __syncthreads(); }
    if (block_size >= 128) { if (tid < 64) { __update(dists, dists_i, tid, tid + 64); } __syncthreads(); }
    if (block_size >= 64) { if (tid < 32) { __update(dists, dists_i, tid, tid + 32); } __syncthreads(); }
    if (block_size >= 32) { if (tid < 16) { __update(dists, dists_i, tid, tid + 16); } __syncthreads(); }
    if (block_size >= 16) { if (tid < 8) { __update(dists, dists_i, tid, tid + 8); } __syncthreads(); }
    if (block_size >= 8) { if (tid < 4) { __update(dists, dists_i, tid, tid + 4); } __syncthreads(); }
    if (block_size >= 4) { if (tid < 2) { __update(dists, dists_i, tid, tid + 2); } __syncthreads(); }
    if (block_size >= 2) { if (tid < 1) { __update(dists, dists_i, tid, tid + 1); } __syncthreads(); }

    old = dists_i[0];
    if (tid == 0) idxs[j] = old;
  }
}

void furthest_point_sampling_with_dist_kernel_launcher(int b, int n, int m,
                                                       const float *dataset,
                                                       float *temp, int *idxs,
                                                       cudaStream_t stream) {
  // dataset: (B, N, N)
  // temp: (B, N)
  // output:
  //      idx: (B, M)
  cudaError_t err;
  unsigned int n_threads = opt_n_threads(n);

  switch (n_threads) {
    case 1024: furthest_point_sampling_with_dist_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 512: furthest_point_sampling_with_dist_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 256: furthest_point_sampling_with_dist_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 128: furthest_point_sampling_with_dist_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 64: furthest_point_sampling_with_dist_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 32: furthest_point_sampling_with_dist_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 16: furthest_point_sampling_with_dist_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 8: furthest_point_sampling_with_dist_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 4: furthest_point_sampling_with_dist_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 2: furthest_point_sampling_with_dist_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 1: furthest_point_sampling_with_dist_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    default: furthest_point_sampling_with_dist_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
  }

  err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
mmdet3d/ops/furthest_point_sample/utils.py (deleted, 100644 → 0)

# Copyright (c) OpenMMLab. All rights reserved.
import torch


def calc_square_dist(point_feat_a, point_feat_b, norm=True):
    """Calculating square distance between a and b.

    Args:
        point_feat_a (Tensor): (B, N, C) Feature vector of each point.
        point_feat_b (Tensor): (B, M, C) Feature vector of each point.
        norm (Bool, optional): Whether to normalize the distance.
            Default: True.

    Returns:
        Tensor: (B, N, M) Distance between each pair points.
    """
    length_a = point_feat_a.shape[1]
    length_b = point_feat_b.shape[1]
    num_channel = point_feat_a.shape[-1]
    # [bs, n, 1]
    a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1)
    # [bs, 1, m]
    b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1)
    a_square = a_square.repeat((1, 1, length_b))  # [bs, n, m]
    b_square = b_square.repeat((1, length_a, 1))  # [bs, n, m]

    coor = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))

    dist = a_square + b_square - 2 * coor
    if norm:
        dist = torch.sqrt(dist) / num_channel
    return dist