Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
5bb8d17b
Unverified
Commit
5bb8d17b
authored
Nov 29, 2022
by
Matthias Fey
Committed by
GitHub
Nov 29, 2022
Browse files
update (#154)
parent
6f222280
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
41 additions
and
18 deletions
+41
-18
test/__init__.py
test/__init__.py
+0
-0
test/test_fps.py
test/test_fps.py
+1
-2
test/test_graclus.py
test/test_graclus.py
+4
-2
test/test_grid.py
test/test_grid.py
+5
-2
test/test_knn.py
test/test_knn.py
+2
-3
test/test_nearest.py
test/test_nearest.py
+1
-2
test/test_radius.py
test/test_radius.py
+2
-3
test/test_rw.py
test/test_rw.py
+9
-4
torch_cluster/testing.py
torch_cluster/testing.py
+17
-0
No files found.
test/__init__.py
deleted
100644 → 0
View file @
6f222280
test/test_fps.py
View file @
5bb8d17b
...
...
@@ -4,8 +4,7 @@ import pytest
import
torch
from
torch
import
Tensor
from
torch_cluster
import
fps
from
.utils
import
grad_dtypes
,
devices
,
tensor
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
@
torch
.
jit
.
script
...
...
test/test_graclus.py
View file @
5bb8d17b
...
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
torch
from
torch_cluster
import
graclus_cluster
from
.utils
import
dtypes
,
devices
,
tensor
from
torch_cluster.testing
import
devices
,
dtypes
,
tensor
tests
=
[{
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
...
...
@@ -42,6 +41,9 @@ def assert_correct(row, col, cluster):
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_graclus_cluster
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
row
=
tensor
(
test
[
'row'
],
torch
.
long
,
device
)
col
=
tensor
(
test
[
'col'
],
torch
.
long
,
device
)
weight
=
tensor
(
test
.
get
(
'weight'
),
dtype
,
device
)
...
...
test/test_grid.py
View file @
5bb8d17b
from
itertools
import
product
import
pytest
import
torch
from
torch_cluster
import
grid_cluster
from
.utils
import
dtypes
,
devices
,
tensor
from
torch_cluster.testing
import
devices
,
dtypes
,
tensor
tests
=
[{
'pos'
:
[
2
,
6
],
...
...
@@ -28,6 +28,9 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_grid_cluster
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
pos
=
tensor
(
test
[
'pos'
],
dtype
,
device
)
size
=
tensor
(
test
[
'size'
],
dtype
,
device
)
start
=
tensor
(
test
.
get
(
'start'
),
dtype
,
device
)
...
...
test/test_knn.py
View file @
5bb8d17b
from
itertools
import
product
import
pytest
import
torch
import
scipy.spatial
import
torch
from
torch_cluster
import
knn
,
knn_graph
from
.utils
import
grad_dtypes
,
devices
,
tensor
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
def
to_set
(
edge_index
):
...
...
test/test_nearest.py
View file @
5bb8d17b
...
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
torch
from
torch_cluster
import
nearest
from
.utils
import
grad_dtypes
,
devices
,
tensor
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
...
...
test/test_radius.py
View file @
5bb8d17b
from
itertools
import
product
import
pytest
import
torch
import
scipy.spatial
import
torch
from
torch_cluster
import
radius
,
radius_graph
from
.utils
import
grad_dtypes
,
devices
,
tensor
from
torch_cluster.testing
import
devices
,
grad_dtypes
,
tensor
def
to_set
(
edge_index
):
...
...
test/test_rw.py
View file @
5bb8d17b
import
pytest
import
torch
from
torch_cluster
import
random_walk
from
.utils
import
devices
,
tensor
from
torch_cluster.testing
import
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
...
...
@@ -41,7 +40,10 @@ def test_rw_large_with_edge_indices(device):
walk_length
=
10
node_seq
,
edge_seq
=
random_walk
(
row
,
col
,
start
,
walk_length
,
row
,
col
,
start
,
walk_length
,
return_edge_indices
=
True
,
)
assert
node_seq
[:,
0
].
tolist
()
==
start
.
tolist
()
...
...
@@ -63,7 +65,10 @@ def test_rw_small_with_edge_indices(device):
walk_length
=
4
node_seq
,
edge_seq
=
random_walk
(
row
,
col
,
start
,
walk_length
,
row
,
col
,
start
,
walk_length
,
num_nodes
=
3
,
return_edge_indices
=
True
,
)
...
...
t
est/utils
.py
→
t
orch_cluster/testing
.py
View file @
5bb8d17b
from
typing
import
Any
import
torch
dtypes
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
dtypes
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
grad_dtypes
=
[
torch
.
half
,
torch
.
float
,
torch
.
double
]
devices
=
[
torch
.
device
(
'cpu'
)]
if
torch
.
cuda
.
is_available
():
devices
+=
[
torch
.
device
(
f
'cuda:
{
torch
.
cuda
.
current_device
()
}
'
)]
devices
+=
[
torch
.
device
(
'cuda:
0
'
)]
def
tensor
(
x
,
dtype
,
device
):
def
tensor
(
x
:
Any
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
dtype
=
dtype
,
device
=
device
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment