OpenDAS / torch-cluster / Commits

Commit f4ad453a
authored Nov 13, 2018 by rusty1s
parent ae3c20f5

compute tmp dists

Showing 3 changed files with 64 additions and 10 deletions:

  cuda/atomics.cuh    +15  -0
  cuda/fps_kernel.cu  +39  -8
  test/test_fps.py    +10  -2
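In short, the commit does three things: it adds cuda/atomics.cuh, a compatibility shim that provides atomicAdd for double on older GPUs; it fills in the previously stubbed fps_kernel so that it computes, for every point, the squared distance to the most recently selected sample; and it extends the farthest-point-sampling test with a second, distinct point cloud.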
cuda/atomics.cuh (new file, 0 → 100644)
#pragma once

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomicAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
}
#endif
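For context: CUDA offers a native atomicAdd overload for double only on devices of compute capability 6.0 and newer built with CUDA 8.0 or later, so this header emulates it with the standard compare-and-swap loop: a thread reads the 64-bit word, adds its contribution, and retries atomicCAS until no other thread has modified the word in between. A minimal usage sketch (illustrative only; sum_kernel is not part of the commit):

// Illustrative sketch: with atomics.cuh included, double-precision atomic
// accumulation compiles on pre-sm_60 devices (via the CAS loop above) and
// falls through to the native instruction on newer architectures.
#include "atomics.cuh"

__global__ void sum_kernel(const double *in, double *out, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(out, in[i]); // race-free concurrent accumulation into *out
  }
}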
cuda/fps_kernel.cu
 #include <ATen/ATen.h>
 
+#include "atomics.cuh"
 #include "utils.cuh"
 
 #define THREADS 1024
...
@@ -8,11 +9,39 @@ template <typename scalar_t>
 __global__ void
 fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
            int64_t *__restrict__ cum_k, int64_t *__restrict__ start,
-           scalar_t *__restrict__ tmp_dist, int64_t *__restrict__ out) {
-  // const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-  // const size_t stride = blockDim.x * gridDim.x;
-  // for (ptrdiff_t i = index; i < numel; i += stride) {
-  // }
+           scalar_t *__restrict__ dist, scalar_t *__restrict__ tmp_dist,
+           int64_t *__restrict__ out, size_t dim) {
+  const size_t batch_idx = blockIdx.x;
+  const size_t idx = threadIdx.x;
+  const size_t stride = blockDim.x; // == THREADS
+
+  const size_t start_idx = cum_deg[batch_idx];
+  const size_t end_idx = cum_deg[batch_idx + 1];
+
+  int64_t old = start_idx + start[batch_idx];
+  if (idx == 0) {
+    out[cum_k[batch_idx]] = old;
+  }
+
+  for (ptrdiff_t m = cum_k[batch_idx] + 1; m < cum_k[batch_idx + 1]; m++) {
+    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += stride) {
+      tmp_dist[n] = 0;
+    }
+    __syncthreads();
+
+    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim; i += stride) {
+      scalar_t d = x[(old * dim) + (i % dim)] - x[i];
+      atomicAdd(&tmp_dist[i / dim], d * d);
+    }
+    __syncthreads();
+
+    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += stride) {
+      dist[n] = min(dist[n], tmp_dist[n]);
+    }
+  }
 }
 
 at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
...
@@ -34,7 +63,8 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
     start = at::zeros(batch_size, k.options());
   }
 
-  auto tmp_dist = at::full(x.size(0), 1e38, x.options());
+  auto dist = at::full(x.size(0), 1e38, x.options());
+  auto tmp_dist = at::empty(x.size(0), x.options());
 
   auto k_sum = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(k_sum, cum_k[-1].data<int64_t>(), sizeof(int64_t),
...
@@ -44,10 +74,11 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
   AT_DISPATCH_FLOATING_TYPES(x.type(), "fps_kernel", [&] {
     fps_kernel<scalar_t><<<batch_size, THREADS>>>(
         x.data<scalar_t>(), cum_deg.data<int64_t>(), cum_k.data<int64_t>(),
-        start.data<int64_t>(), tmp_dist.data<scalar_t>(), out.data<int64_t>());
+        start.data<int64_t>(), dist.data<scalar_t>(), tmp_dist.data<scalar_t>(),
+        out.data<int64_t>(), x.size(1));
   });
 
-  return out;
+  return dist;
 }
 
 // at::Tensor ifp_cuda(at::Tensor x, at::Tensor batch, float ratio) {
...
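Reading the new kernel: each thread block handles one batch element (batch_idx = blockIdx.x), whose points occupy rows [cum_deg[batch_idx], cum_deg[batch_idx + 1]) of x. For every sample slot m, the block first zeroes tmp_dist, then accumulates each point's squared distance to the previously selected point old (the flattened index i addresses point i / dim and coordinate i % dim, and the dim per-coordinate terms are summed concurrently with the atomicAdd from atomics.cuh), and finally folds tmp_dist into the running minimum dist. Note that this revision does not yet select the farthest point (the argmax of dist) or advance old and out inside the loop, and fps_cuda now returns dist rather than out, presumably so the intermediate distances can be inspected; both fit the work-in-progress commit title "compute tmp dists". As a sequential sketch of what one iteration of the m-loop computes (illustrative only; distance_update_cpu does not exist in the repository):

#include <cstddef>
#include <cstdint>

// Sequential equivalent of one m-iteration of fps_kernel: for every point n
// of the batch element, tmp_dist[n] = ||x[old] - x[n]||^2, followed by
// dist[n] = min(dist[n], tmp_dist[n]).
template <typename scalar_t>
void distance_update_cpu(const scalar_t *x, scalar_t *dist, scalar_t *tmp_dist,
                         int64_t old, size_t start_idx, size_t end_idx,
                         size_t dim) {
  for (size_t n = start_idx; n < end_idx; n++) {
    scalar_t acc = 0;
    for (size_t c = 0; c < dim; c++) {
      scalar_t d = x[old * dim + c] - x[n * dim + c];
      acc += d * d; // the kernel adds this term via atomicAdd(&tmp_dist[n], d * d)
    }
    tmp_dist[n] = acc;
    dist[n] = dist[n] < acc ? dist[n] : acc;
  }
}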
test/test_fps.py
...
@@ -12,8 +12,16 @@ devices = [torch.device('cuda')]
 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_fps(dtype, device):
-    x = tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]], dtype, device)
-    x = x.repeat(2, 1)
+    x = tensor([
+        [-1, -1],
+        [-1, +1],
+        [+1, +1],
+        [+1, -1],
+        [-2, -2],
+        [-2, +2],
+        [+2, +2],
+        [+2, -2],
+    ], dtype, device)
     batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
     out = fps_cuda.fps(x, batch, 0.5, False)
...
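The updated test replaces a repeated copy of one unit square with two distinct point clouds: a square at ±1 for batch element 0 and a square at ±2 for batch element 1. This exercises the per-batch offsets (cum_deg) on genuinely different data; with ratio 0.5, the sampler is expected to pick two of the four points from each square.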