Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
e3c3b133
"src/libtorio/ffmpeg/pybind/pybind.cpp" did not exist on "59f067b78838ef49b8b8399496b2a745ad9b2b92"
Commit
e3c3b133
authored
Jun 04, 2019
by
rusty1s
Browse files
flow arg for radius
parent
4047c05d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
24 additions
and
13 deletions
+24
-13
setup.py
setup.py
+1
-1
test/test_knn.py
test/test_knn.py
+0
-2
test/test_radius.py
test/test_radius.py
+9
-3
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
torch_cluster/radius.py
torch_cluster/radius.py
+13
-6
No files found.
setup.py
View file @
e3c3b133
...
@@ -28,7 +28,7 @@ if CUDA_HOME is not None:
...
@@ -28,7 +28,7 @@ if CUDA_HOME is not None:
[
'cuda/rw.cpp'
,
'cuda/rw_kernel.cu'
]),
[
'cuda/rw.cpp'
,
'cuda/rw_kernel.cu'
]),
]
]
__version__
=
'1.4.
1
'
__version__
=
'1.4.
2
'
url
=
'https://github.com/rusty1s/pytorch_cluster'
url
=
'https://github.com/rusty1s/pytorch_cluster'
install_requires
=
[
'scipy'
]
install_requires
=
[
'scipy'
]
...
...
test/test_knn.py
View file @
e3c3b133
...
@@ -45,12 +45,10 @@ def test_knn_graph(dtype, device):
...
@@ -45,12 +45,10 @@ def test_knn_graph(dtype, device):
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'target_to_source'
)
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'target_to_source'
)
col
=
col
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
col
=
col
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
col
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
col
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
row
,
col
=
knn_graph
(
x
,
k
=
2
,
flow
=
'source_to_target'
)
row
=
row
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
row
=
row
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
assert
row
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
row
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
col
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
col
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
test/test_radius.py
View file @
e3c3b133
...
@@ -47,6 +47,12 @@ def test_radius_graph(dtype, device):
...
@@ -47,6 +47,12 @@ def test_radius_graph(dtype, device):
[
+
1
,
-
1
],
[
+
1
,
-
1
],
],
dtype
,
device
)
],
dtype
,
device
)
out
=
radius_graph
(
x
,
r
=
2
)
row
,
col
=
radius_graph
(
x
,
r
=
2
,
flow
=
'target_to_source'
)
assert
coalesce
(
out
).
tolist
()
==
[[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
],
col
=
col
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]]
assert
row
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
assert
col
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
row
,
col
=
radius_graph
(
x
,
r
=
2
,
flow
=
'source_to_target'
)
row
=
row
.
view
(
-
1
,
2
).
sort
(
dim
=-
1
)[
0
].
view
(
-
1
)
assert
row
.
tolist
()
==
[
1
,
3
,
0
,
2
,
1
,
3
,
0
,
2
]
assert
col
.
tolist
()
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
]
torch_cluster/__init__.py
View file @
e3c3b133
...
@@ -7,7 +7,7 @@ from .radius import radius, radius_graph
...
@@ -7,7 +7,7 @@ from .radius import radius, radius_graph
from
.sampler
import
neighbor_sampler
from
.sampler
import
neighbor_sampler
from
.rw
import
random_walk
from
.rw
import
random_walk
__version__
=
'1.4.
1
'
__version__
=
'1.4.
2
'
__all__
=
[
__all__
=
[
'graclus_cluster'
,
'graclus_cluster'
,
...
...
torch_cluster/radius.py
View file @
e3c3b133
...
@@ -73,7 +73,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
...
@@ -73,7 +73,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
return
torch
.
stack
([
row
[
mask
],
col
[
mask
]],
dim
=
0
)
return
torch
.
stack
([
row
[
mask
],
col
[
mask
]],
dim
=
0
)
def
radius_graph
(
x
,
r
,
batch
=
None
,
loop
=
False
,
max_num_neighbors
=
32
):
def
radius_graph
(
x
,
r
,
batch
=
None
,
loop
=
False
,
max_num_neighbors
=
32
,
flow
=
'source_to_target'
):
r
"""Computes graph edges to all points within a given distance.
r
"""Computes graph edges to all points within a given distance.
Args:
Args:
...
@@ -87,6 +92,9 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
...
@@ -87,6 +92,9 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
self-loops. (default: :obj:`False`)
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`)
return for each element in :obj:`y`. (default: :obj:`32`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -102,11 +110,10 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
...
@@ -102,11 +110,10 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
>>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
>>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
"""
"""
edge_index
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
+
1
)
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
edge_index
row
,
col
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
+
1
)
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
if
not
loop
:
if
not
loop
:
row
,
col
=
edge_index
mask
=
row
!=
col
mask
=
row
!=
col
row
,
col
=
row
[
mask
],
col
[
mask
]
row
,
col
=
row
[
mask
],
col
[
mask
]
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
edge_index
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment