Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
fefd2cbe
Commit
fefd2cbe
authored
Jun 04, 2019
by
rusty1s
Browse files
add flow to knn call
parent
52214143
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
6 deletions
+15
-6
test/test_knn.py
test/test_knn.py
+7
-1
torch_cluster/knn.py
torch_cluster/knn.py
+8
-5
No files found.
test/test_knn.py
View file @
fefd2cbe
...
@@ -43,8 +43,14 @@ def test_knn_graph(dtype, device):
...
@@ -43,8 +43,14 @@ def test_knn_graph(dtype, device):
[
+
1
,
-
1
],
[
+
1
,
-
1
],
],
dtype
,
device
)
],
dtype
,
device
)
row
,
col
=
knn_graph
(
x
,
k
=
2
)
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'target_to_source'
)
col
=
col
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
col
=
col
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
col
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
col
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
row
=
row
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
assert
row
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
col
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
torch_cluster/knn.py
View file @
fefd2cbe
...
@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
...
@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
def
knn_graph
(
x
,
k
,
batch
=
None
,
loop
=
False
):
def
knn_graph
(
x
,
k
,
batch
=
None
,
loop
=
False
,
flow
=
'source_to_target'
):
r
"""Computes graph edges to the nearest :obj:`k` points.
r
"""Computes graph edges to the nearest :obj:`k` points.
Args:
Args:
...
@@ -91,6 +91,9 @@ def knn_graph(x, k, batch=None, loop=False):
...
@@ -91,6 +91,9 @@ def knn_graph(x, k, batch=None, loop=False):
node to a specific example. (default: :obj:`None`)
node to a specific example. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
self-loops. (default: :obj:`False`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -106,10 +109,10 @@ def knn_graph(x, k, batch=None, loop=False):
...
@@ -106,10 +109,10 @@ def knn_graph(x, k, batch=None, loop=False):
>>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
>>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
"""
"""
edge_index
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
)
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
)
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
if
not
loop
:
if
not
loop
:
row
,
col
=
edge_index
mask
=
row
!=
col
mask
=
row
!=
col
row
,
col
=
row
[
mask
],
col
[
mask
]
row
,
col
=
row
[
mask
],
col
[
mask
]
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
edge_index
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment