Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
9947d77e
Unverified
Commit
9947d77e
authored
Jul 18, 2019
by
Matthias Fey
Committed by
GitHub
Jul 18, 2019
Browse files
Merge pull request #28 from jlevy44/master
Attempt at adding cosine similarity metric
parents
453531a5
d678ae82
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
49 additions
and
16 deletions
+49
-16
cuda/knn.cpp
cuda/knn.cpp
+3
-3
cuda/knn_kernel.cu
cuda/knn_kernel.cu
+33
-6
setup.py
setup.py
+1
-1
torch_cluster/knn.py
torch_cluster/knn.py
+12
-6
No files found.
cuda/knn.cpp
View file @
9947d77e
...
...
@@ -4,17 +4,17 @@
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
at
::
Tensor
knn_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
size_t
k
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
);
at
::
Tensor
batch_y
,
bool
cosine
);
at
::
Tensor
knn
(
at
::
Tensor
x
,
at
::
Tensor
y
,
size_t
k
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
)
{
at
::
Tensor
batch_y
,
bool
cosine
)
{
CHECK_CUDA
(
x
);
IS_CONTIGUOUS
(
x
);
CHECK_CUDA
(
y
);
IS_CONTIGUOUS
(
y
);
CHECK_CUDA
(
batch_x
);
CHECK_CUDA
(
batch_y
);
return
knn_cuda
(
x
,
y
,
k
,
batch_x
,
batch_y
);
return
knn_cuda
(
x
,
y
,
k
,
batch_x
,
batch_y
,
cosine
);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
...
cuda/knn_kernel.cu
View file @
9947d77e
...
...
@@ -4,13 +4,34 @@
#define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
template
<
typename
scalar_t
>
__global__
void
dot
(
double
*
a
,
double
*
b
,
size_t
size
)
{
double
result
=
0
;
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
result
+=
a
[
i
]
*
b
[
i
];
}
return
result
;
}
template
<
typename
scalar_t
>
__global__
void
norm
(
double
*
a
,
size_t
size
)
{
double
result
=
dot
(
a
,
a
,
size
);
result
=
sqrt
(
result
);
return
result
;
}
template
<
typename
scalar_t
>
__global__
void
knn_kernel
(
const
scalar_t
*
__restrict__
x
,
const
scalar_t
*
__restrict__
y
,
const
int64_t
*
__restrict__
batch_x
,
const
int64_t
*
__restrict__
batch_y
,
scalar_t
*
__restrict__
dist
,
int64_t
*
__restrict__
row
,
int64_t
*
__restrict__
col
,
size_t
k
,
size_t
dim
)
{
size_t
dim
,
bool
cosine
)
{
const
ptrdiff_t
batch_idx
=
blockIdx
.
x
;
const
ptrdiff_t
idx
=
threadIdx
.
x
;
...
...
@@ -30,10 +51,16 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
for
(
ptrdiff_t
n_x
=
start_idx_x
;
n_x
<
end_idx_x
;
n_x
++
)
{
scalar_t
tmp_dist
=
0
;
if
(
cosine
)
{
tmp_dist
=
norm
(
x
,
dim
)
*
norm
(
y
,
dim
)
-
dot
(
x
,
y
,
dim
)
}
else
{
for
(
ptrdiff_t
d
=
0
;
d
<
dim
;
d
++
)
{
tmp_dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
]);
}
}
for
(
ptrdiff_t
k_idx_1
=
0
;
k_idx_1
<
k
;
k_idx_1
++
)
{
if
(
dist
[
n_y
*
k
+
k_idx_1
]
>
tmp_dist
)
{
...
...
@@ -51,7 +78,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
}
at
::
Tensor
knn_cuda
(
at
::
Tensor
x
,
at
::
Tensor
y
,
size_t
k
,
at
::
Tensor
batch_x
,
at
::
Tensor
batch_y
)
{
at
::
Tensor
batch_y
,
bool
cosine
)
{
cudaSetDevice
(
x
.
get_device
());
auto
batch_sizes
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
cudaMemcpy
(
batch_sizes
,
batch_x
[
-
1
].
data
<
int64_t
>
(),
sizeof
(
int64_t
),
...
...
@@ -71,7 +98,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
knn_kernel
<
scalar_t
><<<
batch_size
,
THREADS
>>>
(
x
.
data
<
scalar_t
>
(),
y
.
data
<
scalar_t
>
(),
batch_x
.
data
<
int64_t
>
(),
batch_y
.
data
<
int64_t
>
(),
dist
.
data
<
scalar_t
>
(),
row
.
data
<
int64_t
>
(),
col
.
data
<
int64_t
>
(),
k
,
x
.
size
(
1
));
col
.
data
<
int64_t
>
(),
k
,
x
.
size
(
1
)
,
cosine
);
});
auto
mask
=
col
!=
-
1
;
...
...
setup.py
View file @
9947d77e
...
...
@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
__version__
=
'1.4.3a1'
url
=
'https://github.com/rusty1s/pytorch_cluster'
install_requires
=
[
'scipy'
]
install_requires
=
[
'scipy'
,
'scikit-learn'
]
setup_requires
=
[
'pytest-runner'
]
tests_require
=
[
'pytest'
,
'pytest-cov'
]
...
...
torch_cluster/knn.py
View file @
9947d77e
import
torch
import
scipy.spatial
import
sklearn.neighbors
if
torch
.
cuda
.
is_available
():
import
torch_cluster.knn_cuda
def
knn
(
x
,
y
,
k
,
batch_x
=
None
,
batch_y
=
None
):
def
knn
(
x
,
y
,
k
,
batch_x
=
None
,
batch_y
=
None
,
cosine
=
False
):
r
"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
...
...
@@ -54,7 +55,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
if
x
.
is_cuda
:
return
torch_cluster
.
knn_cuda
.
knn
(
x
,
y
,
k
,
batch_x
,
batch_y
)
return
torch_cluster
.
knn_cuda
.
knn
(
x
,
y
,
k
,
batch_x
,
batch_y
,
cosine
)
# Rescale x and y.
min_xy
=
min
(
x
.
min
().
item
(),
y
.
min
().
item
())
...
...
@@ -67,9 +68,14 @@ def knn(x, y, k, batch_x=None, batch_y=None):
x
=
torch
.
cat
([
x
,
2
*
x
.
size
(
1
)
*
batch_x
.
view
(
-
1
,
1
).
to
(
x
.
dtype
)],
dim
=-
1
)
y
=
torch
.
cat
([
y
,
2
*
y
.
size
(
1
)
*
batch_y
.
view
(
-
1
,
1
).
to
(
y
.
dtype
)],
dim
=-
1
)
query_opts
=
dict
(
k
=
k
)
if
cosine
:
tree
=
sklearn
.
neighbors
.
KDTree
(
x
.
detach
().
numpy
(),
metric
=
'cosine'
)
else
:
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
detach
().
numpy
())
query_opts
[
'distance_upper_bound'
]
=
x
.
size
(
1
)
dist
,
col
=
tree
.
query
(
y
.
detach
().
cpu
(),
k
=
k
,
distance_upper_bound
=
x
.
size
(
1
)
)
y
.
detach
().
cpu
(),
**
query_opts
)
dist
=
torch
.
from_numpy
(
dist
).
to
(
x
.
dtype
)
col
=
torch
.
from_numpy
(
col
).
to
(
torch
.
long
)
row
=
torch
.
arange
(
col
.
size
(
0
),
dtype
=
torch
.
long
).
view
(
-
1
,
1
).
repeat
(
1
,
k
)
...
...
@@ -79,7 +85,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
def
knn_graph
(
x
,
k
,
batch
=
None
,
loop
=
False
,
flow
=
'source_to_target'
):
def
knn_graph
(
x
,
k
,
batch
=
None
,
loop
=
False
,
flow
=
'source_to_target'
,
cosine
=
False
):
r
"""Computes graph edges to the nearest :obj:`k` points.
Args:
...
...
@@ -110,7 +116,7 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
"""
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
)
row
,
col
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
,
cosine
=
cosine
)
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
if
not
loop
:
mask
=
row
!=
col
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment