Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
adbb322f
Commit
adbb322f
authored
Jul 23, 2020
by
Shaoshuai Shi
Browse files
add PointNet2 batch version from
https://github.com/sshaoshuai/Pointnet2.PyTorch
parent
30df59ed
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1396 additions
and
0 deletions
+1396
-0
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py
+173
-0
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_utils.py
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_utils.py
+290
-0
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query.cpp
+32
-0
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu
+73
-0
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.h
+15
-0
pcdet/ops/pointnet2/pointnet2_batch/src/cuda_utils.h
pcdet/ops/pointnet2/pointnet2_batch/src/cuda_utils.h
+15
-0
pcdet/ops/pointnet2/pointnet2_batch/src/group_points.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/group_points.cpp
+43
-0
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.cu
+92
-0
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.h
+22
-0
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate.cpp
+61
-0
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.cu
+168
-0
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.h
+30
-0
pcdet/ops/pointnet2/pointnet2_batch/src/pointnet2_api.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/pointnet2_api.cpp
+24
-0
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
+53
-0
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
+260
-0
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
+29
-0
setup.py
setup.py
+16
-0
No files found.
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py
0 → 100644
View file @
adbb322f
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.
import
pointnet2_utils
from
typing
import
List
class
_PointnetSAModuleBase
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
npoint
=
None
self
.
groupers
=
None
self
.
mlps
=
None
self
.
pool_method
=
'max_pool'
def
forward
(
self
,
xyz
:
torch
.
Tensor
,
features
:
torch
.
Tensor
=
None
,
new_xyz
=
None
)
->
(
torch
.
Tensor
,
torch
.
Tensor
):
"""
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
:param features: (B, N, C) tensor of the descriptors of the the features
:param new_xyz:
:return:
new_xyz: (B, npoint, 3) tensor of the new features' xyz
new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
new_features_list
=
[]
xyz_flipped
=
xyz
.
transpose
(
1
,
2
).
contiguous
()
if
new_xyz
is
None
:
new_xyz
=
pointnet2_utils
.
gather_operation
(
xyz_flipped
,
pointnet2_utils
.
furthest_point_sample
(
xyz
,
self
.
npoint
)
).
transpose
(
1
,
2
).
contiguous
()
if
self
.
npoint
is
not
None
else
None
for
i
in
range
(
len
(
self
.
groupers
)):
new_features
=
self
.
groupers
[
i
](
xyz
,
new_xyz
,
features
)
# (B, C, npoint, nsample)
new_features
=
self
.
mlps
[
i
](
new_features
)
# (B, mlp[-1], npoint, nsample)
if
self
.
pool_method
==
'max_pool'
:
new_features
=
F
.
max_pool2d
(
new_features
,
kernel_size
=
[
1
,
new_features
.
size
(
3
)]
)
# (B, mlp[-1], npoint, 1)
elif
self
.
pool_method
==
'avg_pool'
:
new_features
=
F
.
avg_pool2d
(
new_features
,
kernel_size
=
[
1
,
new_features
.
size
(
3
)]
)
# (B, mlp[-1], npoint, 1)
else
:
raise
NotImplementedError
new_features
=
new_features
.
squeeze
(
-
1
)
# (B, mlp[-1], npoint)
new_features_list
.
append
(
new_features
)
return
new_xyz
,
torch
.
cat
(
new_features_list
,
dim
=
1
)
class
PointnetSAModuleMSG
(
_PointnetSAModuleBase
):
"""Pointnet set abstraction layer with multiscale grouping"""
def
__init__
(
self
,
*
,
npoint
:
int
,
radii
:
List
[
float
],
nsamples
:
List
[
int
],
mlps
:
List
[
List
[
int
]],
bn
:
bool
=
True
,
use_xyz
:
bool
=
True
,
pool_method
=
'max_pool'
):
"""
:param npoint: int
:param radii: list of float, list of radii to group with
:param nsamples: list of int, number of samples in each ball query
:param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
"""
super
().
__init__
()
assert
len
(
radii
)
==
len
(
nsamples
)
==
len
(
mlps
)
self
.
npoint
=
npoint
self
.
groupers
=
nn
.
ModuleList
()
self
.
mlps
=
nn
.
ModuleList
()
for
i
in
range
(
len
(
radii
)):
radius
=
radii
[
i
]
nsample
=
nsamples
[
i
]
self
.
groupers
.
append
(
pointnet2_utils
.
QueryAndGroup
(
radius
,
nsample
,
use_xyz
=
use_xyz
)
if
npoint
is
not
None
else
pointnet2_utils
.
GroupAll
(
use_xyz
)
)
mlp_spec
=
mlps
[
i
]
if
use_xyz
:
mlp_spec
[
0
]
+=
3
shared_mlps
=
[]
for
k
in
range
(
len
(
mlp_spec
)
-
1
):
shared_mlps
.
extend
([
nn
.
Conv2d
(
mlp_spec
[
k
],
mlp_spec
[
k
+
1
],
kernel_size
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
mlp_spec
[
k
+
1
]),
nn
.
ReLU
()
])
self
.
mlps
.
append
(
nn
.
Sequential
(
*
shared_mlps
))
self
.
pool_method
=
pool_method
class
PointnetSAModule
(
PointnetSAModuleMSG
):
"""Pointnet set abstraction layer"""
def
__init__
(
self
,
*
,
mlp
:
List
[
int
],
npoint
:
int
=
None
,
radius
:
float
=
None
,
nsample
:
int
=
None
,
bn
:
bool
=
True
,
use_xyz
:
bool
=
True
,
pool_method
=
'max_pool'
):
"""
:param mlp: list of int, spec of the pointnet before the global max_pool
:param npoint: int, number of features
:param radius: float, radius of ball
:param nsample: int, number of samples in the ball query
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
"""
super
().
__init__
(
mlps
=
[
mlp
],
npoint
=
npoint
,
radii
=
[
radius
],
nsamples
=
[
nsample
],
bn
=
bn
,
use_xyz
=
use_xyz
,
pool_method
=
pool_method
)
class
PointnetFPModule
(
nn
.
Module
):
r
"""Propigates the features of one set to another"""
def
__init__
(
self
,
*
,
mlp
:
List
[
int
],
bn
:
bool
=
True
):
"""
:param mlp: list of int
:param bn: whether to use batchnorm
"""
super
().
__init__
()
shared_mlps
=
[]
for
k
in
range
(
len
(
mlp
)
-
1
):
shared_mlps
.
extend
([
nn
.
Conv2d
(
mlp
[
k
],
mlp
[
k
+
1
],
kernel_size
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
mlp
[
k
+
1
]),
nn
.
ReLU
()
])
self
.
mlp
=
nn
.
Sequential
(
*
shared_mlps
)
def
forward
(
self
,
unknown
:
torch
.
Tensor
,
known
:
torch
.
Tensor
,
unknow_feats
:
torch
.
Tensor
,
known_feats
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
:param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
:param known: (B, m, 3) tensor of the xyz positions of the known features
:param unknow_feats: (B, C1, n) tensor of the features to be propigated to
:param known_feats: (B, C2, m) tensor of features to be propigated
:return:
new_features: (B, mlp[-1], n) tensor of the features of the unknown features
"""
if
known
is
not
None
:
dist
,
idx
=
pointnet2_utils
.
three_nn
(
unknown
,
known
)
dist_recip
=
1.0
/
(
dist
+
1e-8
)
norm
=
torch
.
sum
(
dist_recip
,
dim
=
2
,
keepdim
=
True
)
weight
=
dist_recip
/
norm
interpolated_feats
=
pointnet2_utils
.
three_interpolate
(
known_feats
,
idx
,
weight
)
else
:
interpolated_feats
=
known_feats
.
expand
(
*
known_feats
.
size
()[
0
:
2
],
unknown
.
size
(
1
))
if
unknow_feats
is
not
None
:
new_features
=
torch
.
cat
([
interpolated_feats
,
unknow_feats
],
dim
=
1
)
# (B, C2 + C1, n)
else
:
new_features
=
interpolated_feats
new_features
=
new_features
.
unsqueeze
(
-
1
)
new_features
=
self
.
mlp
(
new_features
)
return
new_features
.
squeeze
(
-
1
)
if
__name__
==
"__main__"
:
pass
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_utils.py
0 → 100644
View file @
adbb322f
import
torch
from
torch.autograd
import
Variable
from
torch.autograd
import
Function
import
torch.nn
as
nn
from
typing
import
Tuple
from
.
import
pointnet2_batch_cuda
as
pointnet2
class
FurthestPointSampling
(
Function
):
@
staticmethod
def
forward
(
ctx
,
xyz
:
torch
.
Tensor
,
npoint
:
int
)
->
torch
.
Tensor
:
"""
Uses iterative furthest point sampling to select a set of npoint features that have the largest
minimum distance
:param ctx:
:param xyz: (B, N, 3) where N > npoint
:param npoint: int, number of features in the sampled set
:return:
output: (B, npoint) tensor containing the set
"""
assert
xyz
.
is_contiguous
()
B
,
N
,
_
=
xyz
.
size
()
output
=
torch
.
cuda
.
IntTensor
(
B
,
npoint
)
temp
=
torch
.
cuda
.
FloatTensor
(
B
,
N
).
fill_
(
1e10
)
pointnet2
.
furthest_point_sampling_wrapper
(
B
,
N
,
npoint
,
xyz
,
temp
,
output
)
return
output
@
staticmethod
def
backward
(
xyz
,
a
=
None
):
return
None
,
None
furthest_point_sample
=
FurthestPointSampling
.
apply
class
GatherOperation
(
Function
):
@
staticmethod
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
idx
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
:param ctx:
:param features: (B, C, N)
:param idx: (B, npoint) index tensor of the features to gather
:return:
output: (B, C, npoint)
"""
assert
features
.
is_contiguous
()
assert
idx
.
is_contiguous
()
B
,
npoint
=
idx
.
size
()
_
,
C
,
N
=
features
.
size
()
output
=
torch
.
cuda
.
FloatTensor
(
B
,
C
,
npoint
)
pointnet2
.
gather_points_wrapper
(
B
,
C
,
N
,
npoint
,
features
,
idx
,
output
)
ctx
.
for_backwards
=
(
idx
,
C
,
N
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_out
):
idx
,
C
,
N
=
ctx
.
for_backwards
B
,
npoint
=
idx
.
size
()
grad_features
=
Variable
(
torch
.
cuda
.
FloatTensor
(
B
,
C
,
N
).
zero_
())
grad_out_data
=
grad_out
.
data
.
contiguous
()
pointnet2
.
gather_points_grad_wrapper
(
B
,
C
,
N
,
npoint
,
grad_out_data
,
idx
,
grad_features
.
data
)
return
grad_features
,
None
gather_operation
=
GatherOperation
.
apply
class
ThreeNN
(
Function
):
@
staticmethod
def
forward
(
ctx
,
unknown
:
torch
.
Tensor
,
known
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Find the three nearest neighbors of unknown in known
:param ctx:
:param unknown: (B, N, 3)
:param known: (B, M, 3)
:return:
dist: (B, N, 3) l2 distance to the three nearest neighbors
idx: (B, N, 3) index of 3 nearest neighbors
"""
assert
unknown
.
is_contiguous
()
assert
known
.
is_contiguous
()
B
,
N
,
_
=
unknown
.
size
()
m
=
known
.
size
(
1
)
dist2
=
torch
.
cuda
.
FloatTensor
(
B
,
N
,
3
)
idx
=
torch
.
cuda
.
IntTensor
(
B
,
N
,
3
)
pointnet2
.
three_nn_wrapper
(
B
,
N
,
m
,
unknown
,
known
,
dist2
,
idx
)
return
torch
.
sqrt
(
dist2
),
idx
@
staticmethod
def
backward
(
ctx
,
a
=
None
,
b
=
None
):
return
None
,
None
three_nn
=
ThreeNN
.
apply
class
ThreeInterpolate
(
Function
):
@
staticmethod
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
idx
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Performs weight linear interpolation on 3 features
:param ctx:
:param features: (B, C, M) Features descriptors to be interpolated from
:param idx: (B, n, 3) three nearest neighbors of the target features in features
:param weight: (B, n, 3) weights
:return:
output: (B, C, N) tensor of the interpolated features
"""
assert
features
.
is_contiguous
()
assert
idx
.
is_contiguous
()
assert
weight
.
is_contiguous
()
B
,
c
,
m
=
features
.
size
()
n
=
idx
.
size
(
1
)
ctx
.
three_interpolate_for_backward
=
(
idx
,
weight
,
m
)
output
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
n
)
pointnet2
.
three_interpolate_wrapper
(
B
,
c
,
m
,
n
,
features
,
idx
,
weight
,
output
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_out
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
:param ctx:
:param grad_out: (B, C, N) tensor with gradients of outputs
:return:
grad_features: (B, C, M) tensor with gradients of features
None:
None:
"""
idx
,
weight
,
m
=
ctx
.
three_interpolate_for_backward
B
,
c
,
n
=
grad_out
.
size
()
grad_features
=
Variable
(
torch
.
cuda
.
FloatTensor
(
B
,
c
,
m
).
zero_
())
grad_out_data
=
grad_out
.
data
.
contiguous
()
pointnet2
.
three_interpolate_grad_wrapper
(
B
,
c
,
n
,
m
,
grad_out_data
,
idx
,
weight
,
grad_features
.
data
)
return
grad_features
,
None
,
None
three_interpolate
=
ThreeInterpolate
.
apply
class
GroupingOperation
(
Function
):
@
staticmethod
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
idx
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
:param ctx:
:param features: (B, C, N) tensor of features to group
:param idx: (B, npoint, nsample) tensor containing the indicies of features to group with
:return:
output: (B, C, npoint, nsample) tensor
"""
assert
features
.
is_contiguous
()
assert
idx
.
is_contiguous
()
B
,
nfeatures
,
nsample
=
idx
.
size
()
_
,
C
,
N
=
features
.
size
()
output
=
torch
.
cuda
.
FloatTensor
(
B
,
C
,
nfeatures
,
nsample
)
pointnet2
.
group_points_wrapper
(
B
,
C
,
N
,
nfeatures
,
nsample
,
features
,
idx
,
output
)
ctx
.
for_backwards
=
(
idx
,
N
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_out
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
:param ctx:
:param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
:return:
grad_features: (B, C, N) gradient of the features
"""
idx
,
N
=
ctx
.
for_backwards
B
,
C
,
npoint
,
nsample
=
grad_out
.
size
()
grad_features
=
Variable
(
torch
.
cuda
.
FloatTensor
(
B
,
C
,
N
).
zero_
())
grad_out_data
=
grad_out
.
data
.
contiguous
()
pointnet2
.
group_points_grad_wrapper
(
B
,
C
,
N
,
npoint
,
nsample
,
grad_out_data
,
idx
,
grad_features
.
data
)
return
grad_features
,
None
grouping_operation
=
GroupingOperation
.
apply
class
BallQuery
(
Function
):
@
staticmethod
def
forward
(
ctx
,
radius
:
float
,
nsample
:
int
,
xyz
:
torch
.
Tensor
,
new_xyz
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
:param ctx:
:param radius: float, radius of the balls
:param nsample: int, maximum number of features in the balls
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centers of the ball query
:return:
idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
"""
assert
new_xyz
.
is_contiguous
()
assert
xyz
.
is_contiguous
()
B
,
N
,
_
=
xyz
.
size
()
npoint
=
new_xyz
.
size
(
1
)
idx
=
torch
.
cuda
.
IntTensor
(
B
,
npoint
,
nsample
).
zero_
()
pointnet2
.
ball_query_wrapper
(
B
,
N
,
npoint
,
radius
,
nsample
,
new_xyz
,
xyz
,
idx
)
return
idx
@
staticmethod
def
backward
(
ctx
,
a
=
None
):
return
None
,
None
,
None
,
None
ball_query
=
BallQuery
.
apply
class
QueryAndGroup
(
nn
.
Module
):
def
__init__
(
self
,
radius
:
float
,
nsample
:
int
,
use_xyz
:
bool
=
True
):
"""
:param radius: float, radius of ball
:param nsample: int, maximum number of features to gather in the ball
:param use_xyz:
"""
super
().
__init__
()
self
.
radius
,
self
.
nsample
,
self
.
use_xyz
=
radius
,
nsample
,
use_xyz
def
forward
(
self
,
xyz
:
torch
.
Tensor
,
new_xyz
:
torch
.
Tensor
,
features
:
torch
.
Tensor
=
None
)
->
Tuple
[
torch
.
Tensor
]:
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centroids
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, 3 + C, npoint, nsample)
"""
idx
=
ball_query
(
self
.
radius
,
self
.
nsample
,
xyz
,
new_xyz
)
xyz_trans
=
xyz
.
transpose
(
1
,
2
).
contiguous
()
grouped_xyz
=
grouping_operation
(
xyz_trans
,
idx
)
# (B, 3, npoint, nsample)
grouped_xyz
-=
new_xyz
.
transpose
(
1
,
2
).
unsqueeze
(
-
1
)
if
features
is
not
None
:
grouped_features
=
grouping_operation
(
features
,
idx
)
if
self
.
use_xyz
:
new_features
=
torch
.
cat
([
grouped_xyz
,
grouped_features
],
dim
=
1
)
# (B, C + 3, npoint, nsample)
else
:
new_features
=
grouped_features
else
:
assert
self
.
use_xyz
,
"Cannot have not features and not use xyz as a feature!"
new_features
=
grouped_xyz
return
new_features
class
GroupAll
(
nn
.
Module
):
def
__init__
(
self
,
use_xyz
:
bool
=
True
):
super
().
__init__
()
self
.
use_xyz
=
use_xyz
def
forward
(
self
,
xyz
:
torch
.
Tensor
,
new_xyz
:
torch
.
Tensor
,
features
:
torch
.
Tensor
=
None
):
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: ignored
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, C + 3, 1, N)
"""
grouped_xyz
=
xyz
.
transpose
(
1
,
2
).
unsqueeze
(
2
)
if
features
is
not
None
:
grouped_features
=
features
.
unsqueeze
(
2
)
if
self
.
use_xyz
:
new_features
=
torch
.
cat
([
grouped_xyz
,
grouped_features
],
dim
=
1
)
# (B, 3 + C, 1, N)
else
:
new_features
=
grouped_features
else
:
new_features
=
grouped_xyz
return
new_features
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query.cpp
0 → 100644
View file @
adbb322f
/*
batch version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"
extern
THCState
*
state
;
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int
ball_query_wrapper_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
at
::
Tensor
new_xyz_tensor
,
at
::
Tensor
xyz_tensor
,
at
::
Tensor
idx_tensor
)
{
CHECK_INPUT
(
new_xyz_tensor
);
CHECK_INPUT
(
xyz_tensor
);
const
float
*
new_xyz
=
new_xyz_tensor
.
data
<
float
>
();
const
float
*
xyz
=
xyz_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
ball_query_kernel_launcher_fast
(
b
,
n
,
m
,
radius
,
nsample
,
new_xyz
,
xyz
,
idx
,
stream
);
return
1
;
}
\ No newline at end of file
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu
0 → 100644
View file @
adbb322f
/*
batch version of ball query, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ball_query_gpu.h"
#include "cuda_utils.h"
__global__
void
ball_query_kernel_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
const
float
*
__restrict__
new_xyz
,
const
float
*
__restrict__
xyz
,
int
*
__restrict__
idx
)
{
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
m
)
return
;
new_xyz
+=
bs_idx
*
m
*
3
+
pt_idx
*
3
;
xyz
+=
bs_idx
*
n
*
3
;
idx
+=
bs_idx
*
m
*
nsample
+
pt_idx
*
nsample
;
float
radius2
=
radius
*
radius
;
float
new_x
=
new_xyz
[
0
];
float
new_y
=
new_xyz
[
1
];
float
new_z
=
new_xyz
[
2
];
int
cnt
=
0
;
for
(
int
k
=
0
;
k
<
n
;
++
k
)
{
float
x
=
xyz
[
k
*
3
+
0
];
float
y
=
xyz
[
k
*
3
+
1
];
float
z
=
xyz
[
k
*
3
+
2
];
float
d2
=
(
new_x
-
x
)
*
(
new_x
-
x
)
+
(
new_y
-
y
)
*
(
new_y
-
y
)
+
(
new_z
-
z
)
*
(
new_z
-
z
);
if
(
d2
<
radius2
){
if
(
cnt
==
0
){
for
(
int
l
=
0
;
l
<
nsample
;
++
l
)
{
idx
[
l
]
=
k
;
}
}
idx
[
cnt
]
=
k
;
++
cnt
;
if
(
cnt
>=
nsample
)
break
;
}
}
}
void
ball_query_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
\
const
float
*
new_xyz
,
const
float
*
xyz
,
int
*
idx
,
cudaStream_t
stream
)
{
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
m
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
ball_query_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
radius
,
nsample
,
new_xyz
,
xyz
,
idx
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
\ No newline at end of file
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.h
0 → 100644
View file @
adbb322f
#ifndef _BALL_QUERY_GPU_H
#define _BALL_QUERY_GPU_H
#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
int
ball_query_wrapper_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
at
::
Tensor
new_xyz_tensor
,
at
::
Tensor
xyz_tensor
,
at
::
Tensor
idx_tensor
);
void
ball_query_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
const
float
*
xyz
,
const
float
*
new_xyz
,
int
*
idx
,
cudaStream_t
stream
);
#endif
pcdet/ops/pointnet2/pointnet2_batch/src/cuda_utils.h
0 → 100644
View file @
adbb322f
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H
#include <cmath>
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
inline
int
opt_n_threads
(
int
work_size
)
{
const
int
pow_2
=
std
::
log
(
static_cast
<
double
>
(
work_size
))
/
std
::
log
(
2
.
0
);
return
max
(
min
(
1
<<
pow_2
,
TOTAL_THREADS
),
1
);
}
#endif
pcdet/ops/pointnet2/pointnet2_batch/src/group_points.cpp
0 → 100644
View file @
adbb322f
/*
batch version of point grouping, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h"
extern
THCState
*
state
;
int
group_points_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_grad_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
,
stream
);
return
1
;
}
int
group_points_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
stream
);
return
1
;
}
\ No newline at end of file
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.cu
0 → 100644
View file @
adbb322f
/*
batch version of point grouping, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "group_points_gpu.h"
__global__
void
group_points_grad_kernel_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
grad_out
+=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
atomicAdd
(
grad_points
+
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
]
,
grad_out
[
0
]);
}
void
group_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_grad_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
group_points_kernel_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
in_idx
=
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
];
int
out_idx
=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
out
[
out_idx
]
=
points
[
in_idx
];
}
void
group_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.h
0 → 100644
View file @
adbb322f
#ifndef _GROUP_POINTS_GPU_H
#define _GROUP_POINTS_GPU_H
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
int
group_points_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
void
group_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
int
group_points_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
group_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
#endif
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate.cpp
0 → 100644
View file @
adbb322f
/*
batch version of point interpolation, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"
extern
THCState
*
state
;
void
three_nn_wrapper_fast
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
)
{
const
float
*
unknown
=
unknown_tensor
.
data
<
float
>
();
const
float
*
known
=
known_tensor
.
data
<
float
>
();
float
*
dist2
=
dist2_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher_fast
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
}
void
three_interpolate_wrapper_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data
<
float
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_kernel_launcher_fast
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
stream
);
}
void
three_interpolate_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
)
{
const
float
*
grad_out
=
grad_out_tensor
.
data
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data
<
float
>
();
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_grad_kernel_launcher_fast
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
,
stream
);
}
\ No newline at end of file
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.cu
0 → 100644
View file @
adbb322f
/*
batch version of point interpolation, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "interpolate_gpu.h"
__global__
void
three_nn_kernel_fast
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
known
+=
bs_idx
*
m
*
3
;
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
float
ux
=
unknown
[
0
];
float
uy
=
unknown
[
1
];
float
uz
=
unknown
[
2
];
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
x
=
known
[
k
*
3
+
0
];
float
y
=
known
[
k
*
3
+
1
];
float
z
=
known
[
k
*
3
+
2
];
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
if
(
d
<
best1
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
best1
;
besti2
=
besti1
;
best1
=
d
;
besti1
=
k
;
}
else
if
(
d
<
best2
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
d
;
besti2
=
k
;
}
else
if
(
d
<
best3
)
{
best3
=
d
;
besti3
=
k
;
}
}
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
}
void
three_nn_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_nn_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
three_interpolate_kernel_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
out
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
}
void
three_interpolate_kernel_launcher_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
three_interpolate_grad_kernel_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
grad_out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
+
pt_idx
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
grad_points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
}
void
three_interpolate_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
\ No newline at end of file
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.h
0 → 100644
View file @
adbb322f
#ifndef _INTERPOLATE_GPU_H
#define _INTERPOLATE_GPU_H
#include <torch/serialize/tensor.h>
#include<vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
void
three_nn_wrapper_fast
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
void
three_nn_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
void
three_interpolate_wrapper_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
void
three_interpolate_kernel_launcher_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
void
three_interpolate_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
void
three_interpolate_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
);
#endif
pcdet/ops/pointnet2/pointnet2_batch/src/pointnet2_api.cpp
0 → 100644
View file @
adbb322f
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"ball_query_wrapper"
,
&
ball_query_wrapper_fast
,
"ball_query_wrapper_fast"
);
m
.
def
(
"group_points_wrapper"
,
&
group_points_wrapper_fast
,
"group_points_wrapper_fast"
);
m
.
def
(
"group_points_grad_wrapper"
,
&
group_points_grad_wrapper_fast
,
"group_points_grad_wrapper_fast"
);
m
.
def
(
"gather_points_wrapper"
,
&
gather_points_wrapper_fast
,
"gather_points_wrapper_fast"
);
m
.
def
(
"gather_points_grad_wrapper"
,
&
gather_points_grad_wrapper_fast
,
"gather_points_grad_wrapper_fast"
);
m
.
def
(
"furthest_point_sampling_wrapper"
,
&
furthest_point_sampling_wrapper
,
"furthest_point_sampling_wrapper"
);
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper_fast
,
"three_nn_wrapper_fast"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper_fast
,
"three_interpolate_wrapper_fast"
);
m
.
def
(
"three_interpolate_grad_wrapper"
,
&
three_interpolate_grad_wrapper_fast
,
"three_interpolate_grad_wrapper_fast"
);
}
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
0 → 100644
View file @
adbb322f
/*
batch version of point sampling and gathering, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>
#include "sampling_gpu.h"
extern
THCState
*
state
;
int
gather_points_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
){
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
gather_points_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
,
stream
);
return
1
;
}
int
gather_points_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
const
float
*
grad_out
=
grad_out_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
gather_points_grad_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
,
stream
);
return
1
;
}
int
furthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
float
*
temp
=
temp_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
furthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
,
stream
);
return
1
;
}
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
0 → 100644
View file @
adbb322f
/*
batch version of point sampling and gathering, modified from the original implementation of official PointNet++ codes.
Written by Shaoshuai Shi
All Rights Reserved 2018.
*/
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "sampling_gpu.h"
__global__
void
gather_points_kernel_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, M)
// output:
// out: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
0
]
=
points
[
idx
[
0
]];
}
void
gather_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
// out: (B, C, npoints)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
gather_points_grad_kernel_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, M)
// idx: (B, M)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
grad_out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
grad_points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]);
}
void
gather_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_grad_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__device__
void
__update
(
float
*
__restrict__
dists
,
int
*
__restrict__
dists_i
,
int
idx1
,
int
idx2
){
const
float
v1
=
dists
[
idx1
],
v2
=
dists
[
idx2
];
const
int
i1
=
dists_i
[
idx1
],
i2
=
dists_i
[
idx2
];
dists
[
idx1
]
=
max
(
v1
,
v2
);
dists_i
[
idx1
]
=
v2
>
v1
?
i2
:
i1
;
}
template
<
unsigned
int
block_size
>
__global__
void
furthest_point_sampling_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
dataset
,
float
*
__restrict__
temp
,
int
*
__restrict__
idxs
)
{
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
if
(
m
<=
0
)
return
;
__shared__
float
dists
[
block_size
];
__shared__
int
dists_i
[
block_size
];
int
batch_index
=
blockIdx
.
x
;
dataset
+=
batch_index
*
n
*
3
;
temp
+=
batch_index
*
n
;
idxs
+=
batch_index
*
m
;
int
tid
=
threadIdx
.
x
;
const
int
stride
=
block_size
;
int
old
=
0
;
if
(
threadIdx
.
x
==
0
)
idxs
[
0
]
=
old
;
__syncthreads
();
for
(
int
j
=
1
;
j
<
m
;
j
++
)
{
int
besti
=
0
;
float
best
=
-
1
;
float
x1
=
dataset
[
old
*
3
+
0
];
float
y1
=
dataset
[
old
*
3
+
1
];
float
z1
=
dataset
[
old
*
3
+
2
];
for
(
int
k
=
tid
;
k
<
n
;
k
+=
stride
)
{
float
x2
,
y2
,
z2
;
x2
=
dataset
[
k
*
3
+
0
];
y2
=
dataset
[
k
*
3
+
1
];
z2
=
dataset
[
k
*
3
+
2
];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;
float
d
=
(
x2
-
x1
)
*
(
x2
-
x1
)
+
(
y2
-
y1
)
*
(
y2
-
y1
)
+
(
z2
-
z1
)
*
(
z2
-
z1
);
float
d2
=
min
(
d
,
temp
[
k
]);
temp
[
k
]
=
d2
;
besti
=
d2
>
best
?
k
:
besti
;
best
=
d2
>
best
?
d2
:
best
;
}
dists
[
tid
]
=
best
;
dists_i
[
tid
]
=
besti
;
__syncthreads
();
if
(
block_size
>=
1024
)
{
if
(
tid
<
512
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
512
);
}
__syncthreads
();
}
if
(
block_size
>=
512
)
{
if
(
tid
<
256
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
256
);
}
__syncthreads
();
}
if
(
block_size
>=
256
)
{
if
(
tid
<
128
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
128
);
}
__syncthreads
();
}
if
(
block_size
>=
128
)
{
if
(
tid
<
64
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
64
);
}
__syncthreads
();
}
if
(
block_size
>=
64
)
{
if
(
tid
<
32
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
32
);
}
__syncthreads
();
}
if
(
block_size
>=
32
)
{
if
(
tid
<
16
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
16
);
}
__syncthreads
();
}
if
(
block_size
>=
16
)
{
if
(
tid
<
8
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
8
);
}
__syncthreads
();
}
if
(
block_size
>=
8
)
{
if
(
tid
<
4
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
4
);
}
__syncthreads
();
}
if
(
block_size
>=
4
)
{
if
(
tid
<
2
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
2
);
}
__syncthreads
();
}
if
(
block_size
>=
2
)
{
if
(
tid
<
1
)
{
__update
(
dists
,
dists_i
,
tid
,
tid
+
1
);
}
__syncthreads
();
}
old
=
dists_i
[
0
];
if
(
tid
==
0
)
idxs
[
j
]
=
old
;
}
}
void
furthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
,
cudaStream_t
stream
)
{
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)
cudaError_t
err
;
unsigned
int
n_threads
=
opt_n_threads
(
n
);
switch
(
n_threads
)
{
case
1024
:
furthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
512
:
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
256
:
furthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
128
:
furthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
64
:
furthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
32
:
furthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
16
:
furthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
8
:
furthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
4
:
furthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
2
:
furthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
1
:
furthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
default:
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
0 → 100644
View file @
adbb322f
#ifndef _SAMPLING_GPU_H
#define _SAMPLING_GPU_H
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include<vector>
int
gather_points_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
void
gather_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
int
gather_points_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
gather_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
int
furthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
);
void
furthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
,
cudaStream_t
stream
);
#endif
setup.py
View file @
adbb322f
...
@@ -82,6 +82,22 @@ if __name__ == '__main__':
...
@@ -82,6 +82,22 @@ if __name__ == '__main__':
'src/interpolate_gpu.cu'
,
'src/interpolate_gpu.cu'
,
],
],
),
),
make_cuda_ext
(
name
=
'pointnet2_batch_cuda'
,
module
=
'pcdet.ops.pointnet2.pointnet2_batch'
,
sources
=
[
'src/pointnet2_api.cpp'
,
'src/ball_query.cpp'
,
'src/ball_query_gpu.cu'
,
'src/group_points.cpp'
,
'src/group_points_gpu.cu'
,
'src/interpolate.cpp'
,
'src/interpolate_gpu.cu'
,
'src/sampling.cpp'
,
'src/sampling_gpu.cu'
,
],
),
],
],
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment