Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
932b96e2
Commit
932b96e2
authored
Apr 01, 2018
by
rusty1s
Browse files
generic test
parent
385e1bca
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
35 additions
and
7 deletions
+35
-7
aten/THCC/THCCGreedy.h
aten/THCC/THCCGreedy.h
+1
-0
aten/THCC/THCCGrid.h
aten/THCC/THCCGrid.h
+1
-0
test/__init__.py
test/__init__.py
+0
-0
test/tensor.py
test/tensor.py
+6
-0
test/test_graclus.py
test/test_graclus.py
+27
-7
test/utils/test_perm.py
test/utils/test_perm.py
+0
-0
No files found.
aten/THCC/THCCGreedy.h
View file @
932b96e2
...
...
@@ -5,5 +5,6 @@ void THCCCharGreedy(THCudaLongTensor *cluster, THCudaLongTensor *row, THCudaLo
void
THCCShortGreedy
(
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaShortTensor
*
weight
);
void
THCCIntGreedy
(
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaIntTensor
*
weight
);
void
THCCLongGreedy
(
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
weight
);
void
THCCHalfGreedy
(
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaHalfTensor
*
weight
);
void
THCCFloatGreedy
(
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaTensor
*
weight
);
void
THCCDoubleGreedy
(
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaDoubleTensor
*
weight
);
aten/THCC/THCCGrid.h
View file @
932b96e2
...
...
@@ -3,5 +3,6 @@ void THCCCharGrid(THCudaLongTensor *cluster, THCudaCharTensor *pos, THCuda
void
THCCShortGrid
(
THCudaLongTensor
*
cluster
,
THCudaShortTensor
*
pos
,
THCudaShortTensor
*
size
,
THCudaLongTensor
*
count
);
void
THCCIntGrid
(
THCudaLongTensor
*
cluster
,
THCudaIntTensor
*
pos
,
THCudaIntTensor
*
size
,
THCudaLongTensor
*
count
);
void
THCCLongGrid
(
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
pos
,
THCudaLongTensor
*
size
,
THCudaLongTensor
*
count
);
void
THCCHalfGrid
(
THCudaLongTensor
*
cluster
,
THCudaHalfTensor
*
pos
,
THCudaHalfTensor
*
size
,
THCudaLongTensor
*
count
);
void
THCCFloatGrid
(
THCudaLongTensor
*
cluster
,
THCudaTensor
*
pos
,
THCudaTensor
*
size
,
THCudaLongTensor
*
count
);
void
THCCDoubleGrid
(
THCudaLongTensor
*
cluster
,
THCudaDoubleTensor
*
pos
,
THCudaDoubleTensor
*
size
,
THCudaLongTensor
*
count
);
test/__init__.py
0 → 100644
View file @
932b96e2
test/tensor.py
0 → 100644
View file @
932b96e2
cpu_tensors
=
[
'ByteTensor'
,
'CharTensor'
,
'ShortTensor'
,
'IntTensor'
,
'LongTensor'
,
'FloatTensor'
,
'DoubleTensor'
]
cuda_tensors
=
[
'cuda.{}'
.
format
(
t
)
for
t
in
cpu_tensors
+
[
'HalfTensor'
]]
test/test_graclus.py
View file @
932b96e2
from
itertools
import
product
import
pytest
import
torch
import
numpy
as
np
from
torch_cluster
import
graclus_cluster
from
.tensor
import
cpu_tensors
,
cuda_tensors
tests
=
[{
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
'col'
:
[
1
,
2
,
0
,
2
,
3
,
0
,
1
,
3
,
1
,
2
],
'weight'
:
None
,
},
{
'row'
:
[
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
],
'col'
:
[
1
,
2
,
0
,
2
,
3
,
0
,
1
,
3
,
1
,
2
],
'weight'
:
[
1
,
2
,
1
,
3
,
2
,
2
,
3
,
1
,
2
,
1
],
}]
def
assert_correct_graclus
(
row
,
col
,
cluster
):
row
,
col
,
cluster
=
row
.
numpy
(),
col
.
numpy
(),
cluster
.
numpy
()
...
...
@@ -15,16 +30,21 @@ def assert_correct_graclus(row, col, cluster):
# Corresponding clusters must be adjacent.
for
n
in
range
(
cluster
.
shape
[
0
]):
assert
(
cluster
[
col
[
row
==
n
]]
==
cluster
[
n
]).
max
()
==
1
x
=
cluster
[
col
[
row
==
n
]]
==
cluster
[
n
]
# Neighbors with same cluster
y
=
cluster
==
cluster
[
n
]
# Nodes with same cluster
y
[
n
]
=
0
# Do not look at cluster of node `n`.
assert
x
.
sum
()
==
y
.
sum
()
def
test_graclus_cluster_cpu
():
row
=
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
])
col
=
torch
.
LongTensor
([
1
,
2
,
0
,
2
,
3
,
0
,
1
,
3
,
1
,
2
])
weight
=
torch
.
Tensor
([
1
,
2
,
1
,
3
,
2
,
2
,
3
,
1
,
2
,
1
])
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
cpu_tensors
,
range
(
len
(
tests
))))
def
test_graclus_cluster_cpu
(
tensor
,
i
):
data
=
tests
[
i
]
cluster
=
graclus_cluster
(
row
,
col
)
assert_correct_graclus
(
row
,
col
,
cluster
)
row
=
torch
.
LongTensor
(
data
[
'row'
])
col
=
torch
.
LongTensor
(
data
[
'col'
])
weight
=
data
[
'weight'
]
weight
=
weight
if
weight
is
None
else
getattr
(
torch
,
tensor
)(
weight
)
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
assert_correct_graclus
(
row
,
col
,
cluster
)
test/utils/test_perm.py
0 → 100644
View file @
932b96e2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment