OpenDAS / torch-cluster / Commits

Commit b8ece051 (parent 9da489a5), authored Mar 28, 2018 by rusty1s
Commit message: tests
Showing 7 changed files with 116 additions and 12 deletions.
test/utils/test_consecutive.py                  +14   -0
test/utils/test_degree.py                       +15   -0
test/utils/test_permute.py                      +56   -0
torch_cluster/functions/grid.py                  +1   -1
torch_cluster/functions/serial.py                +2   -2
torch_cluster/functions/utils/consecutive.py     +6   -4
torch_cluster/functions/utils/permute.py        +22   -5
test/utils/test_consecutive.py  (new file, 0 → 100644)
import torch
from torch_cluster.functions.utils.consecutive import consecutive


def test_consecutive():
    vec = torch.LongTensor([0, 2, 3])
    assert consecutive(vec).tolist() == [0, 1, 2]

    vec = torch.LongTensor([0, 3, 2, 2, 3])
    assert consecutive(vec).tolist() == [0, 2, 1, 1, 2]

    vec = torch.LongTensor([0, 3, 2, 2, 3])
    assert consecutive(vec, return_unique=True)[0].tolist() == [0, 2, 1, 1, 2]
    assert consecutive(vec, return_unique=True)[1].tolist() == [0, 2, 3]
test/utils/test_degree.py  (new file, 0 → 100644)
import torch
from torch_cluster.functions.utils.degree import node_degree


def test_node_degree():
    row = torch.LongTensor([0, 1, 1, 0, 0, 3, 0])
    expected_degree = [4, 2, 0, 1]

    degree = node_degree(row, 4)
    assert degree.type() == torch.FloatTensor().type()
    assert degree.tolist() == expected_degree

    degree = node_degree(row, 4, out=torch.LongTensor())
    assert degree.type() == torch.LongTensor().type()
    assert degree.tolist() == expected_degree
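The expected values above suggest that node_degree counts how often each node index occurs in row, returning a FloatTensor by default and matching the type of an explicit out tensor. A rough stand-in built on torch.bincount, not the library's actual implementation (the helper name here is illustrative only):

import torch

def node_degree_sketch(row, num_nodes, out=None):
    # Count occurrences of each node index in `row`, padded to `num_nodes` bins.
    deg = torch.bincount(row, minlength=num_nodes)
    # Default to float output; otherwise match the type of the provided `out` tensor.
    return deg.float() if out is None else deg.type_as(out)

row = torch.LongTensor([0, 1, 1, 0, 0, 3, 0])
print(node_degree_sketch(row, 4).tolist())                          # [4.0, 2.0, 0.0, 1.0]
print(node_degree_sketch(row, 4, out=torch.LongTensor()).tolist())  # [4, 2, 0, 1]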
test/utils/test_permute.py  (new file, 0 → 100644)
import pytest
import torch
from torch_cluster.functions.utils.permute import sort, permute


def test_sort_cpu():
    edge_index = torch.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    expected_edge_index = [
        [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
        [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
    ]
    assert sort(edge_index).tolist() == expected_edge_index


def test_permute_cpu():
    edge_index = torch.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    node_rid = torch.LongTensor([2, 1, 3, 0])
    edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    edge_index = permute(edge_index, 4, node_rid, edge_rid)
    expected_edge_index = [
        [3, 3, 1, 1, 1, 0, 0, 2, 2, 2],
        [1, 2, 0, 2, 3, 1, 2, 0, 1, 3],
    ]
    assert edge_index.tolist() == expected_edge_index


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_sort_gpu():  # pragma: no cover
    edge_index = torch.cuda.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
    assert sort(edge_index)[0].cpu().tolist() == expected_row


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_permute_gpu():  # pragma: no cover
    edge_index = torch.cuda.LongTensor([
        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
    ])
    node_rid = torch.cuda.LongTensor([2, 1, 3, 0])
    edge_index = permute(edge_index, 4, node_rid)
    expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
    assert edge_index[0].cpu().tolist() == expected_row
torch_cluster/functions/grid.py
...
@@ -81,7 +81,7 @@ def sparse_grid_cluster(position, size, batch=None, start=None):
     position, size, start = _preprocess(position, size, batch, start)
     cluster_size = _minimal_cluster_size(position, size)
     cluster, C = _grid_cluster(position, size, cluster_size)
-    cluster, u = consecutive(cluster)
+    cluster, u = consecutive(cluster, return_unique=True)

     if batch is None:
         return cluster
...
torch_cluster/functions/serial.py
-from .utils.permute import random_permute
+from .utils.permute import permute
 from .utils.degree import node_degree
 from .utils.ffi import get_func
 from .utils.consecutive import consecutive
...
@@ -7,7 +7,7 @@ from .utils.consecutive import consecutive
 def serial_cluster(edge_index, batch=None, num_nodes=None):
     num_nodes = edge_index.max() + 1 if num_nodes is None else num_nodes
-    row, col = random_permute(edge_index, num_nodes)
+    row, col = permute(edge_index, num_nodes)
     degree = node_degree(row, num_nodes, out=row.new())
     cluster = edge_index.new(num_nodes).fill_(-1)
...
torch_cluster/functions/utils/consecutive.py
...
@@ -2,7 +2,7 @@ import torch
 from torch_unique import unique


-def get_type(max, cuda):
+def _get_type(max, cuda):
     if max <= 255:
         return torch.cuda.ByteTensor if cuda else torch.ByteTensor
     elif max <= 32767:  # pragma: no cover
...
@@ -13,13 +13,15 @@ def get_type(max, cuda):
         return torch.cuda.LongTensor if cuda else torch.LongTensor


-def consecutive(tensor):
+def consecutive(tensor, return_unique=False):
     size = tensor.size()
     u = unique(tensor.view(-1))
     len = u[-1] + 1
     max = u.size(0)
-    type = get_type(max, tensor.is_cuda)
+    type = _get_type(max, tensor.is_cuda)
     arg = type(len)
     arg[u] = torch.arange(0, max, out=type(max))
     tensor = arg[tensor.view(-1)]
-    return tensor.view(size).long(), u
+    tensor = tensor.view(size).long()
+    return (tensor, u) if return_unique else tensor
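As the new tests suggest, consecutive relabels the values of a tensor to 0..k-1 in sorted order, and the new return_unique flag controls whether the sorted unique values come back alongside the result. A short usage sketch (assuming torch_cluster and torch_unique are installed; values mirror test_consecutive above):

import torch
from torch_cluster.functions.utils.consecutive import consecutive

vec = torch.LongTensor([0, 3, 2, 2, 3])

# Default: only the relabeled tensor is returned.
out = consecutive(vec)
print(out.tolist())  # [0, 2, 1, 1, 2]

# Opt in to also receiving the sorted unique values.
out, u = consecutive(vec, return_unique=True)
print(out.tolist(), u.tolist())  # [0, 2, 1, 1, 2] [0, 2, 3]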
torch_cluster/functions/utils/permute.py
 import torch


-def random_permute(edge_index, num_nodes):
+def sort(edge_index):
     row, col = edge_index
+    row, perm = row.sort()
+    col = col[perm]
+    return torch.stack([row, col], dim=0)
+
+
+def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
+    num_edges = edge_index.size(1)
+
+    # Randomly reorder row and column indices.
+    if edge_rid is None:
+        edge_rid = torch.randperm(num_edges).type_as(edge_index)
+    row, col = edge_index[:, edge_rid]

     # Randomly change row indices to new values.
+    if node_rid is None:
+        node_rid = torch.randperm(num_nodes).type_as(edge_index)
+    row = node_rid[row]
-    rid = torch.randperm(row.size(0))
-    row, col = row[rid], col[rid]

     # Sort row and column indices based on changed values.
+    row, col = sort(torch.stack([row, col], dim=0))
-    _, perm = rid[torch.randperm(num_nodes)].sort()
-    row, col = row[perm], col[perm]

     # Revert previous row value changes to old indices.
+    row = node_rid.sort()[1][row]

     return torch.stack([row, col], dim=0)
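For reference, a small usage sketch of the two helpers on the graph from the new tests; passing fixed node_rid and edge_rid makes permute deterministic (values mirror test_sort_cpu and test_permute_cpu above):

import torch
from torch_cluster.functions.utils.permute import sort, permute

edge_index = torch.LongTensor([
    [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
    [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
])

# `sort` orders edges by row index and reorders the columns accordingly.
print(sort(edge_index).tolist())
# [[0, 0, 1, 1, 1, 2, 2, 2, 3, 3], [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]]

# `permute` reorders the edges via edge_rid, relabels rows via node_rid, sorts
# by the relabeled rows, and maps the rows back to the original node ids.
node_rid = torch.LongTensor([2, 1, 3, 0])
edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
print(permute(edge_index, 4, node_rid, edge_rid).tolist())
# [[3, 3, 1, 1, 1, 0, 0, 2, 2, 2], [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]]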