Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
df299e7c
Commit
df299e7c
authored
Dec 26, 2021
by
Shaoshuai Shi
Browse files
fix typos: FPS should be farthest point sampling instead of furthest point sampling
parent
ec982888
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
48 additions
and
48 deletions
+48
-48
pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py
pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py
+1
-1
pcdet/models/backbones_3d/pointnet2_backbone.py
pcdet/models/backbones_3d/pointnet2_backbone.py
+1
-1
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py
+1
-1
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_utils.py
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_utils.py
+4
-4
pcdet/ops/pointnet2/pointnet2_batch/src/pointnet2_api.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/pointnet2_api.cpp
+1
-1
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
+2
-2
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
+14
-14
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
+2
-2
pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py
pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py
+3
-3
pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp
pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp
+1
-1
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
+2
-2
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
+14
-14
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
+2
-2
No files found.
pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py
View file @
df299e7c
...
@@ -136,7 +136,7 @@ class VoxelSetAbstraction(nn.Module):
...
@@ -136,7 +136,7 @@ class VoxelSetAbstraction(nn.Module):
bs_mask
=
(
batch_indices
==
bs_idx
)
bs_mask
=
(
batch_indices
==
bs_idx
)
sampled_points
=
src_points
[
bs_mask
].
unsqueeze
(
dim
=
0
)
# (1, N, 3)
sampled_points
=
src_points
[
bs_mask
].
unsqueeze
(
dim
=
0
)
# (1, N, 3)
if
self
.
model_cfg
.
SAMPLE_METHOD
==
'FPS'
:
if
self
.
model_cfg
.
SAMPLE_METHOD
==
'FPS'
:
cur_pt_idxs
=
pointnet2_stack_utils
.
f
u
rthest_point_sample
(
cur_pt_idxs
=
pointnet2_stack_utils
.
f
a
rthest_point_sample
(
sampled_points
[:,
:,
0
:
3
].
contiguous
(),
self
.
model_cfg
.
NUM_KEYPOINTS
sampled_points
[:,
:,
0
:
3
].
contiguous
(),
self
.
model_cfg
.
NUM_KEYPOINTS
).
long
()
).
long
()
...
...
pcdet/models/backbones_3d/pointnet2_backbone.py
View file @
df299e7c
...
@@ -174,7 +174,7 @@ class PointNet2Backbone(nn.Module):
...
@@ -174,7 +174,7 @@ class PointNet2Backbone(nn.Module):
else
:
else
:
last_num_points
=
self
.
num_points_each_layer
[
i
-
1
]
last_num_points
=
self
.
num_points_each_layer
[
i
-
1
]
cur_xyz
=
l_xyz
[
-
1
][
k
*
last_num_points
:
(
k
+
1
)
*
last_num_points
]
cur_xyz
=
l_xyz
[
-
1
][
k
*
last_num_points
:
(
k
+
1
)
*
last_num_points
]
cur_pt_idxs
=
pointnet2_utils_stack
.
f
u
rthest_point_sample
(
cur_pt_idxs
=
pointnet2_utils_stack
.
f
a
rthest_point_sample
(
cur_xyz
[
None
,
:,
:].
contiguous
(),
self
.
num_points_each_layer
[
i
]
cur_xyz
[
None
,
:,
:].
contiguous
(),
self
.
num_points_each_layer
[
i
]
).
long
()[
0
]
).
long
()[
0
]
if
cur_xyz
.
shape
[
0
]
<
self
.
num_points_each_layer
[
i
]:
if
cur_xyz
.
shape
[
0
]
<
self
.
num_points_each_layer
[
i
]:
...
...
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py
View file @
df299e7c
...
@@ -31,7 +31,7 @@ class _PointnetSAModuleBase(nn.Module):
...
@@ -31,7 +31,7 @@ class _PointnetSAModuleBase(nn.Module):
if
new_xyz
is
None
:
if
new_xyz
is
None
:
new_xyz
=
pointnet2_utils
.
gather_operation
(
new_xyz
=
pointnet2_utils
.
gather_operation
(
xyz_flipped
,
xyz_flipped
,
pointnet2_utils
.
f
u
rthest_point_sample
(
xyz
,
self
.
npoint
)
pointnet2_utils
.
f
a
rthest_point_sample
(
xyz
,
self
.
npoint
)
).
transpose
(
1
,
2
).
contiguous
()
if
self
.
npoint
is
not
None
else
None
).
transpose
(
1
,
2
).
contiguous
()
if
self
.
npoint
is
not
None
else
None
for
i
in
range
(
len
(
self
.
groupers
)):
for
i
in
range
(
len
(
self
.
groupers
)):
...
...
pcdet/ops/pointnet2/pointnet2_batch/pointnet2_utils.py
View file @
df299e7c
...
@@ -7,11 +7,11 @@ from torch.autograd import Function, Variable
...
@@ -7,11 +7,11 @@ from torch.autograd import Function, Variable
from
.
import
pointnet2_batch_cuda
as
pointnet2
from
.
import
pointnet2_batch_cuda
as
pointnet2
class
F
u
rthestPointSampling
(
Function
):
class
F
a
rthestPointSampling
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
xyz
:
torch
.
Tensor
,
npoint
:
int
)
->
torch
.
Tensor
:
def
forward
(
ctx
,
xyz
:
torch
.
Tensor
,
npoint
:
int
)
->
torch
.
Tensor
:
"""
"""
Uses iterative f
u
rthest point sampling to select a set of npoint features that have the largest
Uses iterative f
a
rthest point sampling to select a set of npoint features that have the largest
minimum distance
minimum distance
:param ctx:
:param ctx:
:param xyz: (B, N, 3) where N > npoint
:param xyz: (B, N, 3) where N > npoint
...
@@ -25,7 +25,7 @@ class FurthestPointSampling(Function):
...
@@ -25,7 +25,7 @@ class FurthestPointSampling(Function):
output
=
torch
.
cuda
.
IntTensor
(
B
,
npoint
)
output
=
torch
.
cuda
.
IntTensor
(
B
,
npoint
)
temp
=
torch
.
cuda
.
FloatTensor
(
B
,
N
).
fill_
(
1e10
)
temp
=
torch
.
cuda
.
FloatTensor
(
B
,
N
).
fill_
(
1e10
)
pointnet2
.
f
u
rthest_point_sampling_wrapper
(
B
,
N
,
npoint
,
xyz
,
temp
,
output
)
pointnet2
.
f
a
rthest_point_sampling_wrapper
(
B
,
N
,
npoint
,
xyz
,
temp
,
output
)
return
output
return
output
@
staticmethod
@
staticmethod
...
@@ -33,7 +33,7 @@ class FurthestPointSampling(Function):
...
@@ -33,7 +33,7 @@ class FurthestPointSampling(Function):
return
None
,
None
return
None
,
None
furthest_point_sample
=
F
u
rthestPointSampling
.
apply
farthest_point_sample
=
furthest_point_sample
=
F
a
rthestPointSampling
.
apply
class
GatherOperation
(
Function
):
class
GatherOperation
(
Function
):
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/pointnet2_api.cpp
View file @
df299e7c
...
@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"gather_points_wrapper"
,
&
gather_points_wrapper_fast
,
"gather_points_wrapper_fast"
);
m
.
def
(
"gather_points_wrapper"
,
&
gather_points_wrapper_fast
,
"gather_points_wrapper_fast"
);
m
.
def
(
"gather_points_grad_wrapper"
,
&
gather_points_grad_wrapper_fast
,
"gather_points_grad_wrapper_fast"
);
m
.
def
(
"gather_points_grad_wrapper"
,
&
gather_points_grad_wrapper_fast
,
"gather_points_grad_wrapper_fast"
);
m
.
def
(
"f
u
rthest_point_sampling_wrapper"
,
&
f
u
rthest_point_sampling_wrapper
,
"f
u
rthest_point_sampling_wrapper"
);
m
.
def
(
"f
a
rthest_point_sampling_wrapper"
,
&
f
a
rthest_point_sampling_wrapper
,
"f
a
rthest_point_sampling_wrapper"
);
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper_fast
,
"three_nn_wrapper_fast"
);
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper_fast
,
"three_nn_wrapper_fast"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper_fast
,
"three_interpolate_wrapper_fast"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper_fast
,
"three_interpolate_wrapper_fast"
);
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/sampling.cpp
View file @
df299e7c
...
@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
...
@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
}
}
int
f
u
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
int
f
a
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
)
{
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
float
*
points
=
points_tensor
.
data
<
float
>
();
float
*
temp
=
temp_tensor
.
data
<
float
>
();
float
*
temp
=
temp_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
f
u
rthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
);
f
a
rthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
);
return
1
;
return
1
;
}
}
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.cu
View file @
df299e7c
...
@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
...
@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
}
}
template
<
unsigned
int
block_size
>
template
<
unsigned
int
block_size
>
__global__
void
f
u
rthest_point_sampling_kernel
(
int
b
,
int
n
,
int
m
,
__global__
void
f
a
rthest_point_sampling_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
dataset
,
float
*
__restrict__
temp
,
int
*
__restrict__
idxs
)
{
const
float
*
__restrict__
dataset
,
float
*
__restrict__
temp
,
int
*
__restrict__
idxs
)
{
// dataset: (B, N, 3)
// dataset: (B, N, 3)
// tmp: (B, N)
// tmp: (B, N)
...
@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
...
@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
}
}
}
void
f
u
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
void
f
a
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
)
{
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
)
{
// dataset: (B, N, 3)
// dataset: (B, N, 3)
// tmp: (B, N)
// tmp: (B, N)
...
@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
...
@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch
(
n_threads
)
{
switch
(
n_threads
)
{
case
1024
:
case
1024
:
f
u
rthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
512
:
case
512
:
f
u
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
256
:
case
256
:
f
u
rthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
128
:
case
128
:
f
u
rthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
64
:
case
64
:
f
u
rthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
32
:
case
32
:
f
u
rthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
16
:
case
16
:
f
u
rthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
8
:
case
8
:
f
u
rthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
4
:
case
4
:
f
u
rthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
2
:
case
2
:
f
u
rthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
1
:
case
1
:
f
u
rthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
default:
default:
f
u
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
f
a
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
}
}
err
=
cudaGetLastError
();
err
=
cudaGetLastError
();
...
...
pcdet/ops/pointnet2/pointnet2_batch/src/sampling_gpu.h
View file @
df299e7c
...
@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
...
@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
);
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
);
int
f
u
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
int
f
a
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
);
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
);
void
f
u
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
void
f
a
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
);
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
);
#endif
#endif
pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py
View file @
df299e7c
...
@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module):
...
@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module):
return
new_features
,
idx
return
new_features
,
idx
class
F
u
rthestPointSampling
(
Function
):
class
F
a
rthestPointSampling
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
xyz
:
torch
.
Tensor
,
npoint
:
int
):
def
forward
(
ctx
,
xyz
:
torch
.
Tensor
,
npoint
:
int
):
"""
"""
...
@@ -173,7 +173,7 @@ class FurthestPointSampling(Function):
...
@@ -173,7 +173,7 @@ class FurthestPointSampling(Function):
output
=
torch
.
cuda
.
IntTensor
(
B
,
npoint
)
output
=
torch
.
cuda
.
IntTensor
(
B
,
npoint
)
temp
=
torch
.
cuda
.
FloatTensor
(
B
,
N
).
fill_
(
1e10
)
temp
=
torch
.
cuda
.
FloatTensor
(
B
,
N
).
fill_
(
1e10
)
pointnet2
.
f
u
rthest_point_sampling_wrapper
(
B
,
N
,
npoint
,
xyz
,
temp
,
output
)
pointnet2
.
f
a
rthest_point_sampling_wrapper
(
B
,
N
,
npoint
,
xyz
,
temp
,
output
)
return
output
return
output
@
staticmethod
@
staticmethod
...
@@ -181,7 +181,7 @@ class FurthestPointSampling(Function):
...
@@ -181,7 +181,7 @@ class FurthestPointSampling(Function):
return
None
,
None
return
None
,
None
furthest_point_sample
=
F
u
rthestPointSampling
.
apply
farthest_point_sample
=
furthest_point_sample
=
F
a
rthestPointSampling
.
apply
class
ThreeNN
(
Function
):
class
ThreeNN
(
Function
):
...
...
pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp
View file @
df299e7c
...
@@ -12,7 +12,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -12,7 +12,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"ball_query_wrapper"
,
&
ball_query_wrapper_stack
,
"ball_query_wrapper_stack"
);
m
.
def
(
"ball_query_wrapper"
,
&
ball_query_wrapper_stack
,
"ball_query_wrapper_stack"
);
m
.
def
(
"voxel_query_wrapper"
,
&
voxel_query_wrapper_stack
,
"voxel_query_wrapper_stack"
);
m
.
def
(
"voxel_query_wrapper"
,
&
voxel_query_wrapper_stack
,
"voxel_query_wrapper_stack"
);
m
.
def
(
"f
u
rthest_point_sampling_wrapper"
,
&
f
u
rthest_point_sampling_wrapper
,
"f
u
rthest_point_sampling_wrapper"
);
m
.
def
(
"f
a
rthest_point_sampling_wrapper"
,
&
f
a
rthest_point_sampling_wrapper
,
"f
a
rthest_point_sampling_wrapper"
);
m
.
def
(
"group_points_wrapper"
,
&
group_points_wrapper_stack
,
"group_points_wrapper_stack"
);
m
.
def
(
"group_points_wrapper"
,
&
group_points_wrapper_stack
,
"group_points_wrapper_stack"
);
m
.
def
(
"group_points_grad_wrapper"
,
&
group_points_grad_wrapper_stack
,
"group_points_grad_wrapper_stack"
);
m
.
def
(
"group_points_grad_wrapper"
,
&
group_points_grad_wrapper_stack
,
"group_points_grad_wrapper_stack"
);
...
...
pcdet/ops/pointnet2/pointnet2_stack/src/sampling.cpp
View file @
df299e7c
...
@@ -21,7 +21,7 @@ extern THCState *state;
...
@@ -21,7 +21,7 @@ extern THCState *state;
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int
f
u
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
int
f
a
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
)
{
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
)
{
CHECK_INPUT
(
points_tensor
);
CHECK_INPUT
(
points_tensor
);
...
@@ -32,6 +32,6 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
...
@@ -32,6 +32,6 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float
*
temp
=
temp_tensor
.
data
<
float
>
();
float
*
temp
=
temp_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
f
u
rthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
);
f
a
rthest_point_sampling_kernel_launcher
(
b
,
n
,
m
,
points
,
temp
,
idx
);
return
1
;
return
1
;
}
}
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.cu
View file @
df299e7c
...
@@ -22,7 +22,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
...
@@ -22,7 +22,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
template
<
unsigned
int
block_size
>
template
<
unsigned
int
block_size
>
__global__
void
f
u
rthest_point_sampling_kernel
(
int
b
,
int
n
,
int
m
,
__global__
void
f
a
rthest_point_sampling_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
dataset
,
float
*
__restrict__
temp
,
int
*
__restrict__
idxs
)
{
const
float
*
__restrict__
dataset
,
float
*
__restrict__
temp
,
int
*
__restrict__
idxs
)
{
// dataset: (B, N, 3)
// dataset: (B, N, 3)
// tmp: (B, N)
// tmp: (B, N)
...
@@ -139,7 +139,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
...
@@ -139,7 +139,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
}
}
}
void
f
u
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
void
f
a
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
)
{
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
)
{
// dataset: (B, N, 3)
// dataset: (B, N, 3)
// tmp: (B, N)
// tmp: (B, N)
...
@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
...
@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch
(
n_threads
)
{
switch
(
n_threads
)
{
case
1024
:
case
1024
:
f
u
rthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
1024
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
512
:
case
512
:
f
u
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
256
:
case
256
:
f
u
rthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
256
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
128
:
case
128
:
f
u
rthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
128
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
64
:
case
64
:
f
u
rthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
64
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
32
:
case
32
:
f
u
rthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
32
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
16
:
case
16
:
f
u
rthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
16
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
8
:
case
8
:
f
u
rthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
8
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
4
:
case
4
:
f
u
rthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
4
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
2
:
case
2
:
f
u
rthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
2
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
case
1
:
case
1
:
f
u
rthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
f
a
rthest_point_sampling_kernel
<
1
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
break
;
default:
default:
f
u
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
f
a
rthest_point_sampling_kernel
<
512
><<<
b
,
n_threads
>>>
(
b
,
n
,
m
,
dataset
,
temp
,
idxs
);
}
}
err
=
cudaGetLastError
();
err
=
cudaGetLastError
();
...
...
pcdet/ops/pointnet2/pointnet2_stack/src/sampling_gpu.h
View file @
df299e7c
...
@@ -6,10 +6,10 @@
...
@@ -6,10 +6,10 @@
#include<vector>
#include<vector>
int
f
u
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
int
f
a
rthest_point_sampling_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
);
at
::
Tensor
points_tensor
,
at
::
Tensor
temp_tensor
,
at
::
Tensor
idx_tensor
);
void
f
u
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
void
f
a
rthest_point_sampling_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
);
const
float
*
dataset
,
float
*
temp
,
int
*
idxs
);
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment