Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
b199bcb0
Commit
b199bcb0
authored
Apr 07, 2018
by
rusty1s
Browse files
fixed gpu tests
parent
9a1f7817
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
8 deletions
+6
-8
test/tensor.py
test/tensor.py
+1
-3
test/test_graclus.py
test/test_graclus.py
+5
-5
No files found.
test/tensor.py
View file @
b199bcb0
cpu_
tensors
=
[
tensors
=
[
'ByteTensor'
,
'CharTensor'
,
'ShortTensor'
,
'IntTensor'
,
'LongTensor'
,
'ByteTensor'
,
'CharTensor'
,
'ShortTensor'
,
'IntTensor'
,
'LongTensor'
,
'FloatTensor'
,
'DoubleTensor'
'FloatTensor'
,
'DoubleTensor'
]
]
gpu_tensors
=
[
'cuda.{}'
.
format
(
t
)
for
t
in
cpu_tensors
+
[
'HalfTensor'
]]
test/test_graclus.py
View file @
b199bcb0
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
import
numpy
as
np
import
numpy
as
np
from
torch_cluster
import
graclus_cluster
from
torch_cluster
import
graclus_cluster
from
.tensor
import
cpu_tensors
,
gpu_
tensors
from
.tensor
import
tensors
tests
=
[{
tests
=
[{
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
...
@@ -19,8 +19,8 @@ tests = [{
...
@@ -19,8 +19,8 @@ tests = [{
def
assert_correct_graclus
(
row
,
col
,
cluster
):
def
assert_correct_graclus
(
row
,
col
,
cluster
):
row
,
col
,
cluster
=
row
.
numpy
(),
col
.
numpy
(),
c
luster
.
numpy
()
row
,
col
=
row
.
cpu
()
.
numpy
(),
c
ol
.
cpu
()
.
numpy
()
n_nodes
=
cluster
.
shape
[
0
]
cluster
,
n_nodes
=
cluster
.
cpu
().
numpy
(),
cluster
.
size
(
0
)
# Every node was assigned a cluster.
# Every node was assigned a cluster.
assert
cluster
.
min
()
>=
0
assert
cluster
.
min
()
>=
0
...
@@ -40,7 +40,7 @@ def assert_correct_graclus(row, col, cluster):
...
@@ -40,7 +40,7 @@ def assert_correct_graclus(row, col, cluster):
assert
x
.
sum
()
==
y
.
sum
()
assert
x
.
sum
()
==
y
.
sum
()
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
cpu_
tensors
,
range
(
len
(
tests
))))
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_graclus_cluster_cpu
(
tensor
,
i
):
def
test_graclus_cluster_cpu
(
tensor
,
i
):
data
=
tests
[
i
]
data
=
tests
[
i
]
...
@@ -55,7 +55,7 @@ def test_graclus_cluster_cpu(tensor, i):
...
@@ -55,7 +55,7 @@ def test_graclus_cluster_cpu(tensor, i):
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
gpu_
tensors
,
range
(
len
(
tests
))))
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_graclus_cluster_gpu
(
tensor
,
i
):
def
test_graclus_cluster_gpu
(
tensor
,
i
):
data
=
tests
[
i
]
data
=
tests
[
i
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment