Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
4db49780
Commit
4db49780
authored
Jul 29, 2020
by
Shaoshuai Shi
Browse files
continue to update CUDA code to support PyTorch 1.1/1.5
parent
11a7e434
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
89 additions
and
90 deletions
+89
-90
README.md
README.md
+1
-1
pcdet/models/dense_heads/target_assigner/axis_aligned_target_assigner.py
...nse_heads/target_assigner/axis_aligned_target_assigner.py
+4
-4
pcdet/models/roi_heads/target_assigner/proposal_target_layer.py
...models/roi_heads/target_assigner/proposal_target_layer.py
+4
-4
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query.cpp
+15
-5
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu
+3
-3
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.h
+1
-1
pcdet/ops/pointnet2/pointnet2_batch/src/group_points.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/group_points.cpp
+3
-7
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.cu
+4
-4
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.h
+2
-2
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate.cpp
+4
-7
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.cu
+7
-7
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.h
+3
-3
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
+3
-6
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
+17
-17
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
+3
-3
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
+1
-2
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
+13
-13
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
+1
-1
No files found.
README.md
View file @
4db49780
...
...
@@ -85,7 +85,7 @@ All models are trained with 8 GTX 1080Ti GPUs and are available for download.
| | training time | Car | Pedestrian | Cyclist | download |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:---------:|
|
[
PointPillar
](
tools/cfgs/kitti_models/pointpillar.yaml
)
|~1
.5
hours| 77.28 | 52.29 | 62.68 |
[
model-18M
](
https://drive.google.com/file/d/1wMxWTpU1qUoY3DsCH31WJmvJxcjFXKlm/view?usp=sharing
)
|
|
[
PointPillar
](
tools/cfgs/kitti_models/pointpillar.yaml
)
| ~1 hour | 77.28 | 52.29 | 62.68 |
[
model-18M
](
https://drive.google.com/file/d/1wMxWTpU1qUoY3DsCH31WJmvJxcjFXKlm/view?usp=sharing
)
|
|
[
SECOND
](
tools/cfgs/kitti_models/second.yaml
)
| ~2 hours | 78.62 | 52.98 | 67.15 |
[
model-20M
](
https://drive.google.com/file/d/1-01zsPOsqanZQqIIyy7FpNXStL3y4jdR/view?usp=sharing
)
|
|
[
PointRCNN
](
tools/cfgs/kitti_models/pointrcnn.yaml
)
| ~3 hours | 78.70 | 54.41 | 72.11 |
[
model-16M
](
https://drive.google.com/file/d/1BCX9wMn-GYAfSOPpyxf6Iv6fc0qKLSiU/view?usp=sharing
)
|
|
[
PointRCNN-IoU
](
tools/cfgs/kitti_models/pointrcnn_iou.yaml
)
| ~3 hours | 78.75 | 58.32 | 71.34 |
[
model-16M
](
https://drive.google.com/file/d/1V0vNZ3lAHpEEt0MlT80eL2f41K2tHm_D/view?usp=sharing
)
|
...
...
pcdet/models/dense_heads/target_assigner/axis_aligned_target_assigner.py
View file @
4db49780
...
...
@@ -154,7 +154,7 @@ class AxisAlignedTargetAssigner(object):
empty_gt_mask
=
gt_to_anchor_max
==
0
gt_to_anchor_max
[
empty_gt_mask
]
=
-
1
anchors_with_max_overlap
=
torch
.
nonzero
(
anchor_by_gt_overlap
==
gt_to_anchor_max
)[:,
0
]
anchors_with_max_overlap
=
(
anchor_by_gt_overlap
==
gt_to_anchor_max
)
.
nonzero
()
[:,
0
]
gt_inds_force
=
anchor_to_gt_argmax
[
anchors_with_max_overlap
]
labels
[
anchors_with_max_overlap
]
=
gt_classes
[
gt_inds_force
]
gt_ids
[
anchors_with_max_overlap
]
=
gt_inds_force
.
int
()
...
...
@@ -163,11 +163,11 @@ class AxisAlignedTargetAssigner(object):
gt_inds_over_thresh
=
anchor_to_gt_argmax
[
pos_inds
]
labels
[
pos_inds
]
=
gt_classes
[
gt_inds_over_thresh
]
gt_ids
[
pos_inds
]
=
gt_inds_over_thresh
.
int
()
bg_inds
=
torch
.
nonzero
(
anchor_to_gt_max
<
unmatched_threshold
)[:,
0
]
bg_inds
=
(
anchor_to_gt_max
<
unmatched_threshold
)
.
nonzero
()
[:,
0
]
else
:
bg_inds
=
torch
.
arange
(
num_anchors
,
device
=
anchors
.
device
)
fg_inds
=
torch
.
nonzero
(
labels
>
0
)[:,
0
]
fg_inds
=
(
labels
>
0
)
.
nonzero
()
[:,
0
]
if
self
.
pos_fraction
is
not
None
:
num_fg
=
int
(
self
.
pos_fraction
*
self
.
sample_size
)
...
...
@@ -175,7 +175,7 @@ class AxisAlignedTargetAssigner(object):
num_disabled
=
len
(
fg_inds
)
-
num_fg
disable_inds
=
torch
.
randperm
(
len
(
fg_inds
))[:
num_disabled
]
labels
[
disable_inds
]
=
-
1
fg_inds
=
torch
.
nonzero
(
labels
>
0
)[:,
0
]
fg_inds
=
(
labels
>
0
)
.
nonzero
()
[:,
0
]
num_bg
=
self
.
sample_size
-
(
labels
>
0
).
sum
()
if
len
(
bg_inds
)
>
num_bg
:
...
...
pcdet/models/roi_heads/target_assigner/proposal_target_layer.py
View file @
4db49780
...
...
@@ -118,10 +118,10 @@ class ProposalTargetLayer(nn.Module):
fg_rois_per_image
=
int
(
np
.
round
(
self
.
roi_sampler_cfg
.
FG_RATIO
*
self
.
roi_sampler_cfg
.
ROI_PER_IMAGE
))
fg_thresh
=
min
(
self
.
roi_sampler_cfg
.
REG_FG_THRESH
,
self
.
roi_sampler_cfg
.
CLS_FG_THRESH
)
fg_inds
=
torch
.
nonzero
((
max_overlaps
>=
fg_thresh
)).
view
(
-
1
)
easy_bg_inds
=
torch
.
nonzero
((
max_overlaps
<
self
.
roi_sampler_cfg
.
CLS_BG_THRESH_LO
)).
view
(
-
1
)
hard_bg_inds
=
torch
.
nonzero
((
max_overlaps
<
self
.
roi_sampler_cfg
.
REG_FG_THRESH
)
&
(
max_overlaps
>=
self
.
roi_sampler_cfg
.
CLS_BG_THRESH_LO
)).
view
(
-
1
)
fg_inds
=
((
max_overlaps
>=
fg_thresh
)).
nonzero
().
view
(
-
1
)
easy_bg_inds
=
((
max_overlaps
<
self
.
roi_sampler_cfg
.
CLS_BG_THRESH_LO
)).
nonzero
().
view
(
-
1
)
hard_bg_inds
=
((
max_overlaps
<
self
.
roi_sampler_cfg
.
REG_FG_THRESH
)
&
(
max_overlaps
>=
self
.
roi_sampler_cfg
.
CLS_BG_THRESH_LO
)).
nonzero
().
view
(
-
1
)
fg_num_rois
=
fg_inds
.
numel
()
bg_num_rois
=
hard_bg_inds
.
numel
()
+
easy_bg_inds
.
numel
()
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query.cpp
View file @
4db49780
...
...
@@ -14,10 +14,21 @@ All Rights Reserved 2018.
extern
THCState
*
state
;
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// Input-validation macros for the tensor wrappers in this file.
//
// Each macro prints a diagnostic naming the offending expression and the
// source location to stderr, then terminates the whole process via exit(-1).
// NOTE(review): exit() gives the caller no chance to recover; an exception
// may be preferable — confirm against the supported PyTorch versions.
//
// The argument is parenthesized at every use so that an expression argument
// (e.g. `cond ? a : b`) is evaluated as a unit before `.type()` /
// `.is_contiguous()` is applied. Every macro expands to exactly one
// statement (do { ... } while (0)) so it is safe in an un-braced if/else.

// Abort unless `x.type().is_cuda()` is true.
#define CHECK_CUDA(x) do { \
    if (!(x).type().is_cuda()) { \
        fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
        exit(-1); \
    } \
} while (0)

// Abort unless `x.is_contiguous()` is true.
#define CHECK_CONTIGUOUS(x) do { \
    if (!(x).is_contiguous()) { \
        fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
        exit(-1); \
    } \
} while (0)

// Run both checks above as a single statement; `x` is evaluated once per check.
#define CHECK_INPUT(x) do { \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x); \
} while (0)
int
ball_query_wrapper_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
at
::
Tensor
new_xyz_tensor
,
at
::
Tensor
xyz_tensor
,
at
::
Tensor
idx_tensor
)
{
CHECK_INPUT
(
new_xyz_tensor
);
...
...
@@ -26,7 +37,6 @@ int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
const
float
*
xyz
=
xyz_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
ball_query_kernel_launcher_fast
(
b
,
n
,
m
,
radius
,
nsample
,
new_xyz
,
xyz
,
idx
,
stream
);
ball_query_kernel_launcher_fast
(
b
,
n
,
m
,
radius
,
nsample
,
new_xyz
,
xyz
,
idx
);
return
1
;
}
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.cu
View file @
4db49780
...
...
@@ -52,7 +52,7 @@ __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int ns
void
ball_query_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
\
const
float
*
new_xyz
,
const
float
*
xyz
,
int
*
idx
,
cudaStream_t
stream
)
{
const
float
*
new_xyz
,
const
float
*
xyz
,
int
*
idx
)
{
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
...
...
@@ -63,7 +63,7 @@ void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsam
dim3
blocks
(
DIVUP
(
m
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
ball_query_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
radius
,
nsample
,
new_xyz
,
xyz
,
idx
);
ball_query_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
n
,
m
,
radius
,
nsample
,
new_xyz
,
xyz
,
idx
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/ball_query_gpu.h
View file @
4db49780
...
...
@@ -10,6 +10,6 @@ int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
at
::
Tensor
new_xyz_tensor
,
at
::
Tensor
xyz_tensor
,
at
::
Tensor
idx_tensor
);
void
ball_query_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
float
radius
,
int
nsample
,
const
float
*
xyz
,
const
float
*
new_xyz
,
int
*
idx
,
cudaStream_t
stream
);
const
float
*
xyz
,
const
float
*
new_xyz
,
int
*
idx
);
#endif
pcdet/ops/pointnet2/pointnet2_batch/src/group_points.cpp
View file @
4db49780
...
...
@@ -22,9 +22,7 @@ int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_grad_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
,
stream
);
group_points_grad_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
return
1
;
}
...
...
@@ -36,8 +34,6 @@ int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
stream
);
group_points_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
return
1
;
}
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.cu
View file @
4db49780
...
...
@@ -31,7 +31,7 @@ __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints,
}
void
group_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
...
...
@@ -40,7 +40,7 @@ void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, in
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_grad_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
group_points_grad_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
@@ -73,7 +73,7 @@ __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int n
void
group_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
const
float
*
points
,
const
int
*
idx
,
float
*
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
...
...
@@ -82,7 +82,7 @@ void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsa
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
group_points_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/group_points_gpu.h
View file @
4db49780
...
...
@@ -11,12 +11,12 @@ int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
void
group_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
float
*
out
);
int
group_points_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
group_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
);
#endif
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate.cpp
View file @
4db49780
...
...
@@ -25,8 +25,7 @@ void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
float
*
dist2
=
dist2_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher_fast
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
three_nn_kernel_launcher_fast
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
}
...
...
@@ -41,8 +40,7 @@ void three_interpolate_wrapper_fast(int b, int c, int m, int n,
float
*
out
=
out_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_kernel_launcher_fast
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
stream
);
three_interpolate_kernel_launcher_fast
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
}
void
three_interpolate_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
...
...
@@ -56,6 +54,5 @@ void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_grad_kernel_launcher_fast
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
,
stream
);
three_interpolate_grad_kernel_launcher_fast
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
}
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.cu
View file @
4db49780
...
...
@@ -60,7 +60,7 @@ __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restric
void
three_nn_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
const
float
*
known
,
float
*
dist2
,
int
*
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
...
...
@@ -71,7 +71,7 @@ void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_nn_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
three_nn_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
@@ -104,7 +104,7 @@ __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const
}
void
three_interpolate_kernel_launcher_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
...
...
@@ -114,7 +114,7 @@ void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
three_interpolate_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
@@ -149,7 +149,7 @@ __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, c
}
void
three_interpolate_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
...
...
@@ -158,7 +158,7 @@ void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, con
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
three_interpolate_grad_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/interpolate_gpu.h
View file @
4db49780
...
...
@@ -11,20 +11,20 @@ void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
void
three_nn_kernel_launcher_fast
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
const
float
*
known
,
float
*
dist2
,
int
*
idx
);
void
three_interpolate_wrapper_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
void
three_interpolate_kernel_launcher_fast
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
);
void
three_interpolate_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
void
three_interpolate_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
);
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
);
#endif
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
View file @
4db49780
...
...
@@ -21,8 +21,7 @@ int gather_points_wrapper_fast(int b, int c, int n, int npoints,
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
gather_points_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
,
stream
);
gather_points_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
return
1
;
}
...
...
@@ -34,8 +33,7 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
gather_points_grad_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
,
stream
);
gather_points_grad_kernel_launcher_fast
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
return
1
;
}
...
...
@@ -47,7 +45,6 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float
*
temp
=
temp_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
furthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
,
stream
);
furthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
);
return
1
;
}
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
View file @
4db49780
...
...
@@ -31,7 +31,7 @@ __global__ void gather_points_kernel_fast(int b, int c, int n, int m,
}
void
gather_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
const
float
*
points
,
const
int
*
idx
,
float
*
out
)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
...
...
@@ -41,7 +41,7 @@ void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
gather_points_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
@@ -70,7 +70,7 @@ __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const
}
void
gather_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
...
...
@@ -80,7 +80,7 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_grad_kernel_fast
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
gather_points_grad_kernel_fast
<<<
blocks
,
threads
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
...
...
@@ -216,7 +216,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
void
furthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
,
cudaStream_t
stream
)
{
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
)
{
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
...
...
@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch
(
n_threads
)
{
case
1024
:
furthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
512
:
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
256
:
furthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
128
:
furthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
64
:
furthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
32
:
furthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
16
:
furthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
8
:
furthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
4
:
furthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
2
:
furthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
1
:
furthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
default:
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
}
err
=
cudaGetLastError
();
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
View file @
4db49780
...
...
@@ -10,20 +10,20 @@ int gather_points_wrapper_fast(int b, int c, int n, int npoints,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
void
gather_points_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
float
*
out
);
int
gather_points_grad_wrapper_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
gather_points_grad_kernel_launcher_fast
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
);
int
furthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
);
void
furthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
,
cudaStream_t
stream
);
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
);
#endif
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
View file @
4db49780
...
...
@@ -32,7 +32,6 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float
*
temp
=
temp_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
furthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
,
stream
);
furthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
);
return
1
;
}
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
View file @
4db49780
...
...
@@ -140,7 +140,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
void
furthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
,
cudaStream_t
stream
)
{
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
)
{
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
...
...
@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch
(
n_threads
)
{
case
1024
:
furthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
512
:
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
256
:
furthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
128
:
furthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
64
:
furthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
32
:
furthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
16
:
furthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
8
:
furthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
4
:
furthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
2
:
furthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
1
:
furthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
furthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
default:
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
furthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
}
err
=
cudaGetLastError
();
...
...
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
View file @
4db49780
...
...
@@ -10,6 +10,6 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
);
void
furthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
,
cudaStream_t
stream
);
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
);
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment