OpenDAS / torch-cluster / Commits

Commit 1edd387e
authored Apr 26, 2018 by rusty1s

pytorch 0.4.0

parent b87eab0b
Showing 7 changed files with 42 additions and 75 deletions.
test/tensor.py                 +0   -4
test/test_graclus.py           +16  -34
test/test_grid.py              +8   -33
test/utils.py                  +14  -0
torch_cluster/graclus.py       +1   -1
torch_cluster/utils/ffi.py     +1   -1
torch_cluster/utils/perm.py    +2   -2
test/tensor.py (deleted, 100644 → 0)

-tensors = [
-    'ByteTensor', 'CharTensor', 'ShortTensor', 'IntTensor', 'LongTensor',
-    'FloatTensor', 'DoubleTensor'
-]
test/test_graclus.py

@@ -2,10 +2,9 @@ from itertools import product
 import pytest
 import torch
-import numpy as np
 from torch_cluster import graclus_cluster
 
-from .tensor import tensors
+from .utils import dtypes, devices, tensor
 
 tests = [{
     'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],

@@ -18,51 +17,34 @@ tests = [{
 def assert_correct_graclus(row, col, cluster):
-    row, col = row.cpu().numpy(), col.cpu().numpy()
-    cluster, n_nodes = cluster.cpu().numpy(), cluster.size(0)
+    row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
+    n = cluster.size(0)
 
     # Every node was assigned a cluster.
     assert cluster.min() >= 0
 
     # There are no more than two nodes in each cluster.
-    _, count = np.unique(cluster, return_counts=True)
+    _, index = torch.unique(cluster, return_inverse=True)
+    count = torch.zeros_like(cluster)
+    count.scatter_add_(0, index, torch.ones_like(cluster))
     assert (count > 2).max() == 0
 
     # Cluster value is minimal.
-    assert (cluster <= np.arange(n_nodes, dtype=row.dtype)).sum() == n_nodes
+    assert (cluster <= torch.arange(n, dtype=cluster.dtype)).sum() == n
 
     # Corresponding clusters must be adjacent.
-    for n in range(cluster.shape[0]):
-        x = cluster[col[row == n]] == cluster[n]  # Neighbors with same cluster
-        y = cluster == cluster[n]  # Nodes with same cluster.
-        y[n] = 0  # Do not look at cluster of node `n`.
+    for i in range(n):
+        x = cluster[col[row == i]] == cluster[i]  # Neighbors with same cluster
+        y = cluster == cluster[i]  # Nodes with same cluster.
+        y[i] = 0  # Do not look at cluster of `i`.
         assert x.sum() == y.sum()
 
 
-@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
-def test_graclus_cluster_cpu(tensor, i):
-    data = tests[i]
-
-    row = torch.LongTensor(data['row'])
-    col = torch.LongTensor(data['col'])
-    weight = data.get('weight')
-    weight = weight if weight is None else getattr(torch, tensor)(weight)
-
-    cluster = graclus_cluster(row, col, weight)
-    assert_correct_graclus(row, col, cluster)
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
-@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
-def test_graclus_cluster_gpu(tensor, i):  # pragma: no cover
-    data = tests[i]
-
-    row = torch.cuda.LongTensor(data['row'])
-    col = torch.cuda.LongTensor(data['col'])
-    weight = data.get('weight')
-    weight = weight if weight is None else getattr(torch.cuda, tensor)(weight)
+@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
+def test_graclus_cluster_cpu(test, dtype, device):
+    row = tensor(test['row'], torch.long, device)
+    col = tensor(test['col'], torch.long, device)
+    weight = tensor(test.get('weight'), dtype, device)
 
     cluster = graclus_cluster(row, col, weight)
     assert_correct_graclus(row, col, cluster)
test/test_grid.py

 from itertools import product
 
 import pytest
-import torch
 from torch_cluster import grid_cluster
 
-from .tensor import tensors
+from .utils import dtypes, devices, tensor
 
 tests = [{
     'pos': [2, 6],

@@ -27,36 +26,12 @@ tests = [{
 }]
 
 
-@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
-def test_grid_cluster_cpu(tensor, i):
-    data = tests[i]
-
-    pos = getattr(torch, tensor)(data['pos'])
-    size = getattr(torch, tensor)(data['size'])
-    start = data.get('start')
-    start = start if start is None else getattr(torch, tensor)(start)
-    end = data.get('end')
-    end = end if end is None else getattr(torch, tensor)(end)
-
-    cluster = grid_cluster(pos, size, start, end)
-    assert cluster.tolist() == data['cluster']
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
-@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
-def test_grid_cluster_gpu(tensor, i):  # pragma: no cover
-    data = tests[i]
-
-    pos = getattr(torch.cuda, tensor)(data['pos'])
-    size = getattr(torch.cuda, tensor)(data['size'])
-    start = data.get('start')
-    start = start if start is None else getattr(torch.cuda, tensor)(start)
-    end = data.get('end')
-    end = end if end is None else getattr(torch.cuda, tensor)(end)
+@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
+def test_grid_cluster_cpu(test, dtype, device):
+    pos = tensor(test['pos'], dtype, device)
+    size = tensor(test['size'], dtype, device)
+    start = tensor(test.get('start'), dtype, device)
+    end = tensor(test.get('end'), dtype, device)
 
     cluster = grid_cluster(pos, size, start, end)
-    assert cluster.tolist() == data['cluster']
+    assert cluster.tolist() == test['cluster']
test/utils.py (new file, 0 → 100644)

+import torch
+from torch.testing import get_all_dtypes
+
+dtypes = get_all_dtypes()
+dtypes.remove(torch.half)
+
+devices = [torch.device('cpu')]
+if torch.cuda.is_available():
+    devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
+
+
+def tensor(x, dtype, device):
+    return None if x is None else torch.tensor(x, dtype=dtype, device=device)
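A short usage sketch of the new helper, re-defined locally here so the snippet is self-contained:

import torch


def tensor(x, dtype, device):  # mirrors the helper added in test/utils.py
    return None if x is None else torch.tensor(x, dtype=dtype, device=device)


# Optional test inputs such as 'weight', 'start' or 'end' may be missing from
# a test case; the helper passes None straight through instead of raising.
assert tensor(None, torch.float, torch.device('cpu')) is None

# Otherwise it builds the tensor with the requested dtype on the requested device.
w = tensor([1, 2, 1], torch.float, torch.device('cpu'))
print(w.dtype, w.device)  # torch.float32 cpu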
torch_cluster/graclus.py

@@ -29,7 +29,7 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
     row, col = randperm_sort_row(row, col, num_nodes)
     row, col = remove_self_loops(row, col)
 
-    cluster = row.new(num_nodes)
+    cluster = row.new_empty((num_nodes, ))
     graclus(cluster, row, col, weight)
 
     return cluster
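This swaps the legacy Tensor.new(n) constructor for the new_* factory methods introduced in PyTorch 0.4.0. A small sketch with toy values (not the library's actual call site):

import torch

row = torch.tensor([0, 0, 1, 1, 2])
num_nodes = 3

# new_empty allocates an uninitialized tensor that inherits row's dtype and
# device; graclus() fills it afterwards, so no initialization is needed.
cluster = row.new_empty((num_nodes, ))
print(cluster.dtype, cluster.device, cluster.shape)  # torch.int64 cpu torch.Size([3])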
torch_cluster/utils/ffi.py

@@ -3,7 +3,7 @@ from .._ext import ffi
 def get_func(name, is_cuda, tensor=None):
     prefix = 'THCC' if is_cuda else 'TH'
-    prefix += 'Tensor' if tensor is None else type(tensor).__name__
+    prefix += 'Tensor' if tensor is None else tensor.type().split('.')[-1]
     return getattr(ffi, '{}_{}'.format(prefix, name))
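Background for this one-liner: after the 0.4.0 Tensor/Variable merge every tensor is an instance of torch.Tensor, so the class name no longer encodes the scalar type, while Tensor.type() still returns the legacy type string. A quick illustration:

import torch

t = torch.zeros(2, 3)

print(type(t).__name__)         # 'Tensor' -- no longer 'FloatTensor'
print(t.type())                 # 'torch.FloatTensor' ('torch.cuda.FloatTensor' on the GPU)
print(t.type().split('.')[-1])  # 'FloatTensor' -- the suffix get_func() appends to the TH/THCC prefix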
torch_cluster/utils/perm.py

@@ -3,7 +3,7 @@ import torch
 def randperm(row, col):
     # Randomly reorder row and column indices.
-    edge_rid = torch.randperm(row.size(0)).type_as(row)
+    edge_rid = torch.randperm(row.size(0))
     return row[edge_rid], col[edge_rid]

@@ -16,7 +16,7 @@ def sort_row(row, col):
 def randperm_sort_row(row, col, num_nodes):
     # Randomly change row indices to new values.
-    node_rid = torch.randperm(num_nodes).type_as(row)
+    node_rid = torch.randperm(num_nodes)
     row = node_rid[row]
 
     # Sort row and column indices row-wise.
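For reference, torch.randperm returns a LongTensor, which integer indexing accepts directly, so the dropped .type_as(row) casts are not needed. A minimal sketch with toy indices:

import torch

row = torch.tensor([0, 0, 1, 2, 2])
col = torch.tensor([1, 2, 0, 0, 1])

# randperm yields a LongTensor permutation that can be used as an index as-is.
edge_rid = torch.randperm(row.size(0))
print(edge_rid.dtype)  # torch.int64
row, col = row[edge_rid], col[edge_rid]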