Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
c7f73ca2
Commit
c7f73ca2
authored
Apr 07, 2018
by
rusty1s
Browse files
added grid tests
parent
4576030c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
3 deletions
+64
-3
test/test_graclus.py
test/test_graclus.py
+2
-3
test/test_grid.py
test/test_grid.py
+62
-0
No files found.
test/test_graclus.py
View file @
c7f73ca2
...
...
@@ -10,7 +10,6 @@ from .tensor import tensors
tests
=
[{
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
'col'
:
[
1
,
2
,
0
,
2
,
3
,
0
,
1
,
3
,
1
,
2
],
'weight'
:
None
,
},
{
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
'col'
:
[
1
,
2
,
0
,
2
,
3
,
0
,
1
,
3
,
1
,
2
],
...
...
@@ -47,7 +46,7 @@ def test_graclus_cluster_cpu(tensor, i):
row
=
torch
.
LongTensor
(
data
[
'row'
])
col
=
torch
.
LongTensor
(
data
[
'col'
])
weight
=
data
[
'weight'
]
weight
=
data
.
get
(
'weight'
)
weight
=
weight
if
weight
is
None
else
getattr
(
torch
,
tensor
)(
weight
)
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
...
...
@@ -62,7 +61,7 @@ def test_graclus_cluster_gpu(tensor, i): # pragma: no cover
row
=
torch
.
cuda
.
LongTensor
(
data
[
'row'
])
col
=
torch
.
cuda
.
LongTensor
(
data
[
'col'
])
weight
=
data
[
'weight'
]
weight
=
data
.
get
(
'weight'
)
weight
=
weight
if
weight
is
None
else
getattr
(
torch
.
cuda
,
tensor
)(
weight
)
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
...
...
test/test_grid.py
View file @
c7f73ca2
from
itertools
import
product
import
pytest
import
torch
from
torch_cluster
import
grid_cluster
from
.tensor
import
tensors
tests
=
[{
'pos'
:
[
2
,
6
],
'size'
:
[
5
],
'cluster'
:
[
0
,
0
],
},
{
'pos'
:
[
2
,
6
],
'size'
:
[
5
],
'start'
:
[
0
],
'cluster'
:
[
0
,
1
],
},
{
'pos'
:
[[
0
,
0
],
[
11
,
9
],
[
2
,
8
],
[
2
,
2
],
[
8
,
3
]],
'size'
:
[
5
,
5
],
'cluster'
:
[
0
,
5
,
3
,
0
,
1
],
},
{
'pos'
:
[[
0
,
0
],
[
11
,
9
],
[
2
,
8
],
[
2
,
2
],
[
8
,
3
]],
'size'
:
[
5
,
5
],
'end'
:
[
19
,
19
],
'cluster'
:
[
0
,
6
,
4
,
0
,
1
],
}]
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_grid_cluster_cpu
(
tensor
,
i
):
data
=
tests
[
i
]
pos
=
getattr
(
torch
,
tensor
)(
data
[
'pos'
])
size
=
getattr
(
torch
,
tensor
)(
data
[
'size'
])
start
=
data
.
get
(
'start'
)
start
=
start
if
start
is
None
else
getattr
(
torch
,
tensor
)(
start
)
end
=
data
.
get
(
'end'
)
end
=
end
if
end
is
None
else
getattr
(
torch
,
tensor
)(
end
)
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
assert
cluster
.
tolist
()
==
data
[
'cluster'
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_grid_cluster_gpu
(
tensor
,
i
):
data
=
tests
[
i
]
pos
=
getattr
(
torch
.
cuda
,
tensor
)(
data
[
'pos'
])
size
=
getattr
(
torch
.
cuda
,
tensor
)(
data
[
'size'
])
start
=
data
.
get
(
'start'
)
start
=
start
if
start
is
None
else
getattr
(
torch
.
cuda
,
tensor
)(
start
)
end
=
data
.
get
(
'end'
)
end
=
end
if
end
is
None
else
getattr
(
torch
.
cuda
,
tensor
)(
end
)
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
assert
cluster
.
tolist
()
==
data
[
'cluster'
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment