OpenDAS / torch-cluster

Commit d678ae82, authored Jul 13, 2019 by jlevy44

    changed cuda function names, added sklearn as dependency, fixed cpu computation

Parent: 2388a521
Showing 3 changed files with 16 additions and 7 deletions:

    cuda/knn_kernel.cu      +8  -4
    setup.py                +1  -1
    torch_cluster/knn.py    +7  -2
cuda/knn_kernel.cu

@@ -5,7 +5,9 @@
 #define THREADS 1024
 
 // Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
-__device__ double dotProduct(double *a, double *b, int size) {
+template <typename scalar_t>
+__global__ void
+dot(double *a, double *b, size_t size) {
   double result = 0;
   for (int i = 0; i < size; i++) {
@@ -15,8 +17,10 @@ __device__ double dotProduct(double *a, double *b, int size) {
   return result;
 }
 
-__device__ double calc_norm(double *a, int size) {
-  double result = dotProduct(a, a, size);
+template <typename scalar_t>
+__global__ void
+norm(double *a, size_t size) {
+  double result = dot(a, a, size);
   result = sqrt(result);
   return result;
 }
@@ -48,7 +52,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
   scalar_t tmp_dist = 0;
   if (cosine) {
-    tmp_dist = calc_norm(x, dim) * calc_norm(y, dim) - dotProduct(x, y, dim)
+    tmp_dist = norm(x, dim) * norm(y, dim) - dot(x, y, dim)
   }
   else {
     for (ptrdiff_t d = 0; d < dim; d++) {
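For reference, the quantity the renamed cosine branch computes, norm(x) * norm(y) - dot(x, y), equals the standard cosine distance scaled by the product of the two vector norms. A small numpy check of that identity (illustrative only, not part of the commit):

import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([0.5, -1.0, 2.0])

# The kernel's cosine score: ||x|| * ||y|| - <x, y>
kernel_score = np.linalg.norm(x) * np.linalg.norm(y) - np.dot(x, y)

# The same value written as ||x|| * ||y|| * (1 - cos(theta)),
# i.e. cosine distance scaled by the product of the norms.
cos_theta = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
assert np.isclose(kernel_score,
                  np.linalg.norm(x) * np.linalg.norm(y) * (1 - cos_theta))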
setup.py

@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
 __version__ = '1.4.3a1'
 url = 'https://github.com/rusty1s/pytorch_cluster'
 
-install_requires = ['scipy']
+install_requires = ['scipy', 'scikit-learn']
 setup_requires = ['pytest-runner']
 tests_require = ['pytest', 'pytest-cov']
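With scikit-learn added to install_requires alongside scipy, both libraries the reworked CPU path relies on should resolve after installation; a quick environment sanity check (hypothetical, not part of the commit):

import scipy
import sklearn

# Versions printed are whatever is installed locally.
print('scipy', scipy.__version__)
print('scikit-learn', sklearn.__version__)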
torch_cluster/knn.py

@@ -68,9 +68,14 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
         x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
         y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
 
-        tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine' if cosine else 'minkowski')  #scipy.spatial.cKDTree(x.detach().numpy())
+        query_opts = dict(k=k)
+        if cosine:
+            tree = sklearn.neighbors.KDTree(x.detach().numpy(), metric='cosine')
+        else:
+            tree = scipy.spatial.cKDTree(x.detach().numpy())
+            query_opts['distance_upper_bound'] = x.size(1)
         dist, col = tree.query(
-            y.detach().cpu(), k=k)  #, distance_upper_bound=x.size(1)
+            y.detach().cpu(), **query_opts)
         dist = torch.from_numpy(dist).to(x.dtype)
         col = torch.from_numpy(col).to(torch.long)
         row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
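To exercise the reworked CPU computation end-to-end, here is a standalone sketch of the non-cosine branch with random tensors; variable names mirror the diff, and the snippet is illustrative rather than part of the commit:

import scipy.spatial
import torch

k = 3
x = torch.randn(100, 8)  # reference points
y = torch.randn(10, 8)   # query points

# Build the KD-tree on x and query the k nearest neighbors of each y,
# capping the search radius at x.size(1) as the new code does.
tree = scipy.spatial.cKDTree(x.detach().numpy())
dist, col = tree.query(y.detach().numpy(), k=k,
                       distance_upper_bound=x.size(1))

# Convert back to torch tensors and build the row index, as in the diff.
dist = torch.from_numpy(dist).to(x.dtype)
col = torch.from_numpy(col).to(torch.long)
row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)
print(row.shape, col.shape, dist.shape)  # each is (10, 3)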