Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
8922371e
Commit
8922371e
authored
Dec 26, 2021
by
Shaoshuai Shi
Browse files
Support StackFarthestPointSampling
parent
df299e7c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
234 additions
and
0 deletions
+234
-0
pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py
pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py
+37
-0
pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp
pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp
+1
-0
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
+23
-0
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
+165
-0
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
+8
-0
No files found.
pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py
View file @
8922371e
...
...
@@ -184,6 +184,43 @@ class FarthestPointSampling(Function):
farthest_point_sample
=
furthest_point_sample
=
FarthestPointSampling
.
apply
class StackFarthestPointSampling(Function):
    """Farthest point sampling over a stacked (concatenated) batch of point clouds."""

    @staticmethod
    def forward(ctx, xyz, xyz_batch_cnt, npoint):
        """
        Args:
            ctx:
            xyz: (N1 + N2 + ..., 3) contiguous tensor with the stacked points of all samples
            xyz_batch_cnt: [N1, N2, ...] number of points of each sample
            npoint: int, list of int, or tensor [M1, M2, ...]: number of points to pick
                from each sample (a single int is reused for every sample)
        Returns:
            output: (M1 + M2 + ...) int32 tensor with the picked point indices; each
                index refers to a row of the stacked xyz tensor
        """
        assert xyz.is_contiguous() and xyz.shape[1] == 3

        batch_size = len(xyz_batch_cnt)
        if not isinstance(npoint, torch.Tensor):
            if not isinstance(npoint, list):
                npoint = [npoint for i in range(batch_size)]
            npoint = torch.tensor(npoint, device=xyz.device).int()
        else:
            # The extension reads this buffer as int32 on the GPU.
            npoint = npoint.to(device=xyz.device, dtype=torch.int32).contiguous()
        # The kernel reads one entry of npoint per sample.
        assert npoint.numel() == batch_size, \
            'npoint must provide exactly one entry per sample in xyz_batch_cnt'

        N, _ = xyz.size()
        # Allocate the scratch/output buffers on the device that holds xyz rather than
        # on whatever CUDA device happens to be current.
        temp = torch.full((N,), 1e10, dtype=torch.float32, device=xyz.device)
        output = torch.empty(int(npoint.sum().item()), dtype=torch.int32, device=xyz.device)

        pointnet2.stack_farthest_point_sampling_wrapper(xyz, temp, xyz_batch_cnt, output, npoint)
        return output

    @staticmethod
    def backward(xyz, a=None):
        # The first positional argument is the autograd context. The sampled indices
        # carry no gradient; autograd expects one entry per forward input
        # (xyz, xyz_batch_cnt, npoint).
        return None, None, None


stack_farthest_point_sample = StackFarthestPointSampling.apply
class
ThreeNN
(
Function
):
@
staticmethod
def
forward
(
ctx
,
unknown
,
unknown_batch_cnt
,
known
,
known_batch_cnt
):
...
...
pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp
View file @
8922371e
...
...
@@ -13,6 +13,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"voxel_query_wrapper"
,
&
voxel_query_wrapper_stack
,
"voxel_query_wrapper_stack"
);
m
.
def
(
"farthest_point_sampling_wrapper"
,
&
farthest_point_sampling_wrapper
,
"farthest_point_sampling_wrapper"
);
m
.
def
(
"stack_farthest_point_sampling_wrapper"
,
&
stack_farthest_point_sampling_wrapper
,
"stack_farthest_point_sampling_wrapper"
);
m
.
def
(
"group_points_wrapper"
,
&
group_points_wrapper_stack
,
"group_points_wrapper_stack"
);
m
.
def
(
"group_points_grad_wrapper"
,
&
group_points_grad_wrapper_stack
,
"group_points_grad_wrapper_stack"
);
...
...
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
View file @
8922371e
...
...
@@ -35,3 +35,26 @@ int farthest_point_sampling_wrapper(int b, int n, int m,
farthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
);
return
1
;
}
int stack_farthest_point_sampling_wrapper(at::Tensor points_tensor, at::Tensor temp_tensor,
    at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor, at::Tensor num_sampled_points_tensor) {
    // Tensor-level entry point for stacked farthest point sampling: validates the
    // tensors with CHECK_INPUT (macro defined outside this block), extracts the raw
    // buffers and forwards them to the CUDA launcher. Always returns 1.
    //
    // points_tensor:             float buffer, first dimension = total number of points
    // temp_tensor:               float scratch buffer handed to the kernel
    // xyz_batch_cnt_tensor:      int buffer, first dimension = number of samples
    // idx_tensor:                int output buffer receiving the sampled indices
    // num_sampled_points_tensor: int buffer with the per-sample number of points to pick
    CHECK_INPUT(points_tensor);
    CHECK_INPUT(temp_tensor);
    CHECK_INPUT(idx_tensor);
    CHECK_INPUT(xyz_batch_cnt_tensor);
    CHECK_INPUT(num_sampled_points_tensor);

    const int num_batches = xyz_batch_cnt_tensor.size(0);
    const int total_points = points_tensor.size(0);

    const float *points_ptr = points_tensor.data<float>();
    float *temp_ptr = temp_tensor.data<float>();
    int *batch_cnt_ptr = xyz_batch_cnt_tensor.data<int>();
    int *idx_ptr = idx_tensor.data<int>();
    int *num_sampled_ptr = num_sampled_points_tensor.data<int>();

    stack_farthest_point_sampling_kernel_launcher(
        total_points, num_batches, points_ptr, temp_ptr, batch_cnt_ptr, idx_ptr, num_sampled_ptr
    );
    return 1;
}
\ No newline at end of file
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
View file @
8922371e
...
...
@@ -182,3 +182,168 @@ void farthest_point_sampling_kernel_launcher(int b, int n, int m,
exit
(
-
1
);
}
}
template <unsigned int block_size>
__global__ void stack_farthest_point_sampling_kernel(int batch_size, int N,
    const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
    // Farthest point sampling on a stacked batch: block blockIdx.x processes sample blockIdx.x.
    //
    // Launch layout: one block per sample; blockDim.x must equal block_size because the
    // shared arrays below are sized by block_size and indexed by threadIdx.x.
    // NOTE(review): the tree reduction assumes block_size is a power of two.
    //
    // Args:
    //     batch_size, N: not read by the kernel body
    //     dataset: (N1 + N2 + ..., 3) stacked xyz coordinates
    //     temp: (N1 + N2 + ...) per-point minimum squared distance to the points picked
    //         so far; read and updated in place (expected to be pre-filled with a large
    //         value by the caller)
    //     xyz_batch_cnt: [N1, N2, ...] number of points of each sample
    //     num_sampled_points: [M1, M2, ...] number of points to pick from each sample
    // Returns:
    //     idxs: (M1 + M2 + ...) picked indices, expressed as rows of the stacked dataset

    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int bs_idx = blockIdx.x;

    // Prefix sums give the offsets of this sample inside the stacked input / output.
    int xyz_batch_start_idx = 0, idxs_start_idx = 0;
    for (int k = 0; k < bs_idx; k++){
        xyz_batch_start_idx += xyz_batch_cnt[k];
        idxs_start_idx += num_sampled_points[k];
    }

    dataset += xyz_batch_start_idx * 3;
    temp += xyz_batch_start_idx;
    idxs += idxs_start_idx;

    int n = xyz_batch_cnt[bs_idx];
    int m = num_sampled_points[bs_idx];

    // Nothing requested for this sample: leave before touching idxs, otherwise idxs[0]
    // would alias the next sample's first slot (or lie past the end of the output).
    // The condition is uniform across the block, so no thread skips a barrier alone.
    if (m <= 0) return;

    int tid = threadIdx.x;
    const int stride = block_size;

    // The first pick is always the first point of the sample.
    int old = 0;
    if (threadIdx.x == 0) idxs[0] = xyz_batch_start_idx;

    __syncthreads();
    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];

        // Each thread scans points tid, tid + block_size, ...: refresh their distance to
        // the picked set and remember the farthest one seen by this thread.
        for (int k = tid; k < n; k += stride) {
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];
            // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
            // if (mag <= 1e-3)
            // continue;

            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Block-wide tree reduction over the per-thread candidates. __update is defined
        // outside this block (presumably keeps the larger distance and its index in the
        // first slot). The trip count is a compile-time constant, so the loop can unroll.
        #pragma unroll
        for (int s = block_size / 2; s > 0; s >>= 1) {
            if (tid < s) {
                __update(dists, dists_i, tid, tid + s);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old + xyz_batch_start_idx;

        // Every thread must have read dists_i[0] before any thread overwrites the shared
        // arrays in the next iteration.
        __syncthreads();
    }
}
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
    const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
    // Launches stack_farthest_point_sampling_kernel with one block per sample.
    // Args:
    //     N: total number of stacked points (forwarded to the kernel)
    //     batch_size: number of samples, used as the grid size
    //     dataset: (N1 + N2 + ..., 3) stacked xyz coordinates
    //     temp: (N1 + N2 + ...) scratch buffer updated by the kernel
    //     xyz_batch_cnt: [N1, N2, ...] number of points of each sample
    //     num_sampled_points: [M1, M2, ...] number of points to pick from each sample
    // Returns:
    //     idxs: (M1 + M2 + ...) picked indices, written by the kernel
    // Prints the CUDA error and terminates the process if the launch fails.

    // A grid with zero blocks is an invalid launch configuration; an empty batch simply
    // has nothing to sample.
    if (batch_size <= 0) return;

    // The kernel's template argument and the launch block size must agree.
    const unsigned int kBlockSize = 1024;

    stack_farthest_point_sampling_kernel<kBlockSize><<<batch_size, kBlockSize>>>(
        batch_size, N, dataset, temp, xyz_batch_cnt, idxs, num_sampled_points
    );

    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
\ No newline at end of file
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
View file @
8922371e
...
...
@@ -12,4 +12,12 @@ int farthest_point_sampling_wrapper(int b, int n, int m,
void
farthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
);
// Tensor-level entry point for stacked farthest point sampling.
// Takes the stacked points, a float scratch tensor, the per-sample point counts,
// the int output index tensor and the per-sample number of points to pick.
int
stack_farthest_point_sampling_wrapper
(
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
xyz_batch_cnt_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
num_sampled_points_tensor
);
// Raw-pointer launcher for the stacked farthest point sampling CUDA kernel.
// N is the total number of stacked points and batch_size the number of samples;
// the remaining arguments are the buffers behind the tensors of the wrapper above.
void
stack_farthest_point_sampling_kernel_launcher
(
int
N
,
int
batch_size
,
const
float
*
dataset
,
float
*
temp
,
int
*
xyz_batch_cnt
,
int
*
idxs
,
int
*
num_sampled_points
);
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment