Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
7318c1b8
Commit
7318c1b8
authored
Jul 18, 2019
by
rusty1s
Browse files
clean up cosine distance
parent
9947d77e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
32 deletions
+39
-32
cuda/knn_kernel.cu
cuda/knn_kernel.cu
+19
-21
setup.py
setup.py
+1
-1
test/test_knn.py
test/test_knn.py
+5
-0
torch_cluster/knn.py
torch_cluster/knn.py
+14
-10
No files found.
cuda/knn_kernel.cu
View file @
7318c1b8
...
...
@@ -4,26 +4,24 @@
#define THREADS 1024
// Code from https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
template
<
typename
scalar_t
>
__global__
void
dot
(
double
*
a
,
double
*
b
,
size_t
size
)
{
double
result
=
0
;
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
template
<
typename
scalar_t
>
struct
Cosine
{
static
inline
__device__
scalar_t
dot
(
const
scalar_t
*
a
,
const
scalar_t
*
b
,
size_t
size
)
{
scalar_t
result
=
0
;
for
(
ptrdiff_t
i
=
0
;
i
<
size
;
i
++
)
{
result
+=
a
[
i
]
*
b
[
i
];
}
return
result
;
}
}
template
<
typename
scalar_t
>
__global__
void
norm
(
double
*
a
,
size_t
size
)
{
double
result
=
dot
(
a
,
a
,
size
);
result
=
sqrt
(
result
);
return
result
;
}
static
inline
__device__
scalar_t
norm
(
const
scalar_t
*
a
,
size_t
size
)
{
scalar_t
result
=
0
;
for
(
ptrdiff_t
i
=
0
;
i
<
size
;
i
++
)
{
result
+=
a
[
i
]
*
a
[
i
];
}
return
sqrt
(
result
);
}
};
template
<
typename
scalar_t
>
__global__
void
...
...
@@ -52,16 +50,16 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
scalar_t
tmp_dist
=
0
;
if
(
cosine
)
{
tmp_dist
=
norm
(
x
,
dim
)
*
norm
(
y
,
dim
)
-
dot
(
x
,
y
,
dim
)
}
else
{
tmp_dist
=
Cosine
<
scalar_t
>::
norm
(
x
,
dim
)
*
Cosine
<
scalar_t
>::
norm
(
y
,
dim
)
-
Cosine
<
scalar_t
>::
dot
(
x
,
y
,
dim
);
}
else
{
for
(
ptrdiff_t
d
=
0
;
d
<
dim
;
d
++
)
{
tmp_dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
]);
}
}
for
(
ptrdiff_t
k_idx_1
=
0
;
k_idx_1
<
k
;
k_idx_1
++
)
{
if
(
dist
[
n_y
*
k
+
k_idx_1
]
>
tmp_dist
)
{
for
(
ptrdiff_t
k_idx_2
=
k
-
1
;
k_idx_2
>
k_idx_1
;
k_idx_2
--
)
{
...
...
setup.py
View file @
7318c1b8
...
...
@@ -31,7 +31,7 @@ if CUDA_HOME is not None:
__version__
=
'1.4.3a1'
url
=
'https://github.com/rusty1s/pytorch_cluster'
install_requires
=
[
'scipy'
,
'scikit-learn'
]
install_requires
=
[
'scipy'
]
setup_requires
=
[
'pytest-runner'
]
tests_require
=
[
'pytest'
,
'pytest-cov'
]
...
...
test/test_knn.py
View file @
7318c1b8
...
...
@@ -33,6 +33,11 @@ def test_knn(dtype, device):
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
col
.
tolist
()
==
[
2
,
3
,
4
,
5
]
if
x
.
is_cuda
:
row
,
col
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
,
cosine
=
True
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
col
.
tolist
()
==
[
0
,
1
,
4
,
5
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_knn_graph
(
dtype
,
device
):
...
...
torch_cluster/knn.py
View file @
7318c1b8
import
torch
import
scipy.spatial
import
sklearn.neighbors
if
torch
.
cuda
.
is_available
():
import
torch_cluster.knn_cuda
...
...
@@ -22,6 +21,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`)
cosine (boolean, optional): If :obj:`True`, will use the cosine
distance instead of euclidean distance to find nearest neighbors.
(default: :obj:`False`)
:rtype: :class:`LongTensor`
...
...
@@ -57,6 +59,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
if
x
.
is_cuda
:
return
torch_cluster
.
knn_cuda
.
knn
(
x
,
y
,
k
,
batch_x
,
batch_y
,
cosine
)
if
cosine
:
raise
NotImplementedError
(
'Cosine distance not implemented for CPU'
)
# Rescale x and y.
min_xy
=
min
(
x
.
min
().
item
(),
y
.
min
().
item
())
x
,
y
=
x
-
min_xy
,
y
-
min_xy
...
...
@@ -68,14 +73,9 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
x
=
torch
.
cat
([
x
,
2
*
x
.
size
(
1
)
*
batch_x
.
view
(
-
1
,
1
).
to
(
x
.
dtype
)],
dim
=-
1
)
y
=
torch
.
cat
([
y
,
2
*
y
.
size
(
1
)
*
batch_y
.
view
(
-
1
,
1
).
to
(
y
.
dtype
)],
dim
=-
1
)
query_opts
=
dict
(
k
=
k
)
if
cosine
:
tree
=
sklearn
.
neighbors
.
KDTree
(
x
.
detach
().
numpy
(),
metric
=
'cosine'
)
else
:
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
detach
().
numpy
())
query_opts
[
'distance_upper_bound'
]
=
x
.
size
(
1
)
dist
,
col
=
tree
.
query
(
y
.
detach
().
cpu
(),
**
query_opts
)
dist
,
col
=
tree
.
query
(
y
.
detach
().
cpu
(),
k
=
k
,
distance_upper_bound
=
x
.
size
(
1
))
dist
=
torch
.
from_numpy
(
dist
).
to
(
x
.
dtype
)
col
=
torch
.
from_numpy
(
col
).
to
(
torch
.
long
)
row
=
torch
.
arange
(
col
.
size
(
0
),
dtype
=
torch
.
long
).
view
(
-
1
,
1
).
repeat
(
1
,
k
)
...
...
@@ -85,7 +85,8 @@ def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
def
knn_graph
(
x
,
k
,
batch
=
None
,
loop
=
False
,
flow
=
'source_to_target'
,
cosine
=
False
):
def
knn_graph
(
x
,
k
,
batch
=
None
,
loop
=
False
,
flow
=
'source_to_target'
,
cosine
=
False
):
r
"""Computes graph edges to the nearest :obj:`k` points.
Args:
...
...
@@ -100,6 +101,9 @@ def knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=Fals
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
cosine (boolean, optional): If :obj:`True`, will use the cosine
distance instead of euclidean distance to find nearest neighbors.
(default: :obj:`False`)
:rtype: :class:`LongTensor`
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment