Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
8c8014b9
Commit
8c8014b9
authored
Mar 13, 2020
by
rusty1s
Browse files
fps fixes
parent
388a2e2b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
9 deletions
+32
-9
csrc/cuda/atomics.cuh
csrc/cuda/atomics.cuh
+22
-0
csrc/cuda/fps_cuda.cu
csrc/cuda/fps_cuda.cu
+10
-9
No files found.
csrc/cuda/atomics.cuh
0 → 100644
View file @
8c8014b9
#pragma once
// Single-precision overload of atomAdd.
//
// Thin forwarding wrapper around the hardware atomicAdd intrinsic, which
// supports `float` on every CUDA architecture. It exists so callers can
// write `atomAdd(...)` uniformly for float and double (the double overload
// below needs a software fallback on older architectures).
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}
// Double-precision overload of atomAdd.
//
// Native atomicAdd(double*, double) only exists on compute capability >= 6.0
// built with CUDA >= 8.0. For older targets we emulate it with an atomicCAS
// retry loop on the value's 64-bit integer representation — the pattern
// recommended in the CUDA C++ Programming Guide.
//
// NOTE(review): CUDA_VERSION is defined by cuda.h; this header assumes the
// including translation unit pulls in the CUDA headers first (an undefined
// CUDA_VERSION evaluates to 0 here and conservatively selects the CAS
// fallback, which is correct everywhere) — confirm against the .cu includes.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
  // Reinterpret the double's storage as a 64-bit integer so atomicCAS,
  // which has no double overload on these architectures, can operate on it.
  unsigned long long int *raw =
      reinterpret_cast<unsigned long long int *>(address);
  unsigned long long int observed = *raw;
  unsigned long long int expected;
  do {
    expected = observed;
    observed = atomicCAS(
        raw, expected,
        __double_as_longlong(val + __longlong_as_double(expected)));
    // If another thread changed *address between our read and the CAS,
    // observed != expected and we retry with the freshly observed bits.
  } while (expected != observed);
}
#else
// Modern path: the hardware provides a native double-precision atomic add.
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif
csrc/cuda/fps_cuda.cu
View file @
8c8014b9
...
...
@@ -2,11 +2,12 @@
#include <ATen/cuda/CUDAContext.h>
#include "atomics.cuh"
#include "utils.cuh"
#define THREADS 1024
template
<
typename
scalar_t
>
struct
Dist
<
scalar_t
>
{
template
<
typename
scalar_t
>
struct
Dist
{
static
inline
__device__
void
compute
(
int64_t
idx
,
int64_t
start_idx
,
int64_t
end_idx
,
int64_t
old
,
scalar_t
*
best
,
int64_t
*
best_idx
,
...
...
@@ -20,7 +21,7 @@ template <typename scalar_t> struct Dist<scalar_t> {
__syncthreads
();
for
(
int64_t
i
=
start_idx
*
dim
+
idx
;
i
<
end_idx
*
dim
;
i
+=
THREADS
)
{
scalar_t
d
=
src
[(
old
*
dim
)
+
(
i
%
dim
)]
-
src
[
i
];
atom
ic
Add
(
&
tmp_dist
[
i
/
dim
],
d
*
d
);
atomAdd
(
&
tmp_dist
[
i
/
dim
],
d
*
d
);
}
__syncthreads
();
...
...
@@ -58,11 +59,11 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
int64_t
best_idx
=
0
;
__syncthreads
();
Dist
<
scalar_t
,
Dim
>::
compute
(
thread_idx
,
start_idx
,
end_idx
,
out
[
m
-
1
],
&
best
,
&
best_idx
,
src
,
dist
,
tmp_dist
,
dim
);
Dist
<
scalar_t
>::
compute
(
thread_idx
,
start_idx
,
end_idx
,
out
[
m
-
1
],
&
best
,
&
best_idx
,
src
,
dist
,
tmp_dist
,
dim
);
best_dist
[
idx
]
=
best
;
best_dist_idx
[
idx
]
=
best_idx
;
best_dist
[
thread_
idx
]
=
best
;
best_dist_idx
[
thread_
idx
]
=
best_idx
;
for
(
int64_t
u
=
0
;
(
1
<<
u
)
<
THREADS
;
u
++
)
{
__syncthreads
();
...
...
@@ -77,7 +78,7 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
}
__syncthreads
();
if
(
idx
==
0
)
{
if
(
thread_
idx
==
0
)
{
out
[
m
]
=
best_dist_idx
[
0
];
}
}
...
...
@@ -99,7 +100,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
auto
deg
=
ptr
.
narrow
(
0
,
1
,
batch_size
)
-
ptr
.
narrow
(
0
,
0
,
batch_size
);
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
(
float
)
ratio
;
out_ptr
=
out_ptr
.
ceil
().
toType
(
torch
::
kLong
).
cumsum
(
0
);
out_ptr
=
torch
::
cat
({
torch
.
zeros
(
1
,
ptr
.
options
()),
out_ptr
},
0
);
out_ptr
=
torch
::
cat
({
torch
::
zeros
(
1
,
ptr
.
options
()),
out_ptr
},
0
);
torch
::
Tensor
start
;
if
(
random_start
)
{
...
...
@@ -120,7 +121,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
src
.
scalar_type
(),
"fps_kernel"
,
[
&
]
{
fps_kernel
<
scalar_t
><<<
batch_size
,
THREADS
,
0
,
stream
>>>
(
src
.
data_ptr
<
scalar_t
>
(),
row
ptr
.
data_ptr
<
int64_t
>
(),
src
.
data_ptr
<
scalar_t
>
(),
ptr
.
data_ptr
<
int64_t
>
(),
out_ptr
.
data_ptr
<
int64_t
>
(),
start
.
data_ptr
<
int64_t
>
(),
dist
.
data_ptr
<
scalar_t
>
(),
tmp_dist
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int64_t
>
(),
src
.
size
(
1
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment