Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
4056bf63
Commit
4056bf63
authored
Mar 25, 2018
by
rusty1s
Browse files
added random cluster
parent
65846a61
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
23 deletions
+43
-23
test/test_random.py
test/test_random.py
+4
-3
torch_cluster/functions/permute.py
torch_cluster/functions/permute.py
+7
-6
torch_cluster/functions/random.py
torch_cluster/functions/random.py
+2
-2
torch_cluster/src/cpu.c
torch_cluster/src/cpu.c
+30
-12
No files found.
test/test_random.py
View file @
4056bf63
...
...
@@ -6,7 +6,8 @@ def test_random():
edge_index
=
torch
.
LongTensor
([[
0
,
0
,
0
,
1
,
2
,
3
,
3
,
3
,
4
,
5
,
5
,
5
,
6
,
6
],
[
2
,
3
,
6
,
5
,
0
,
0
,
4
,
5
,
3
,
1
,
3
,
6
,
0
,
3
]])
# edge_attr = torch.Tensor([2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2])
node_
rid
=
torch
.
arange
(
edge_index
.
max
()
+
1
,
out
=
edge_index
.
new
())
edge_rid
=
torch
.
arange
(
edge_index
.
size
(
0
),
out
=
edge_index
.
new
()
)
rid
=
torch
.
arange
(
edge_index
.
max
()
+
1
,
out
=
edge_index
.
new
())
output
=
random_cluster
(
edge_index
,
rid
,
perm_edges
=
False
)
random_cluster
(
edge_index
,
node_rid
,
edge_rid
)
expected_output
=
[
0
,
1
,
2
,
0
,
4
,
1
,
6
]
assert
output
.
tolist
()
==
expected_output
torch_cluster/functions/permute.py
View file @
4056bf63
import
torch
def
permute
(
edge_index
,
num_nodes
,
node_
rid
=
None
,
edge_rid
=
Non
e
):
def
permute
(
edge_index
,
num_nodes
,
rid
=
None
,
perm_edges
=
Tru
e
):
row
,
col
=
edge_index
edge_rid
=
torch
.
randperm
(
row
.
size
(
0
))
if
edge_rid
is
None
else
edge_rid
row
,
col
=
row
[
edge_rid
],
col
[
edge_rid
]
if
perm_edges
:
edge_rid
=
torch
.
randperm
(
row
.
size
(
0
))
row
,
col
=
row
[
edge_rid
],
col
[
edge_rid
]
node_
rid
=
torch
.
randperm
(
num_nodes
)
if
node_
rid
is
None
else
node_
rid
_
,
perm
=
node_
rid
[
row
].
sort
()
rid
=
torch
.
randperm
(
num_nodes
)
if
rid
is
None
else
rid
_
,
perm
=
rid
[
row
].
sort
()
row
,
col
=
row
[
perm
],
col
[
perm
]
return
row
,
col
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
torch_cluster/functions/random.py
View file @
4056bf63
...
...
@@ -3,9 +3,9 @@ from .degree import node_degree
from
.permute
import
permute
def
random_cluster
(
edge_index
,
node_
rid
=
None
,
edge_rid
=
Non
e
,
num_nodes
=
None
):
def
random_cluster
(
edge_index
,
rid
=
None
,
perm_edges
=
Tru
e
,
num_nodes
=
None
):
num_nodes
=
edge_index
.
max
()
+
1
if
num_nodes
is
None
else
num_nodes
row
,
col
=
permute
(
edge_index
,
num_nodes
,
node_rid
,
edge_rid
)
row
,
col
=
permute
(
edge_index
,
num_nodes
,
rid
,
perm_edges
)
degree
=
node_degree
(
row
,
num_nodes
,
out
=
row
.
new
())
cluster
=
edge_index
.
new
(
num_nodes
).
fill_
(
-
1
)
...
...
torch_cluster/src/cpu.c
View file @
4056bf63
...
...
@@ -3,18 +3,36 @@
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _, Real)
void
cluster_random
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
)
{
/* int64_t *output_data = output->storage->data + output->storageOffset; */
/* int64_t *row_data = row->storage->data + row->storageOffset; */
/* int64_t *col_data = col->storage->data + col->storageOffset; */
/* int64_t e, E = THLongTensor_nElement(row), r, c, value; */
/* for (e = 0; e < E; e++) { */
/* r = row_data[e]; c = col_data[e]; */
/* if (output_data[r] == -1 && output_data[c] == -1) { */
/* value = r < c ? r : c; */
/* output_data[r] = value; */
/* output_data[c] = value; */
/* } */
/* } */
int64_t
*
output_data
=
output
->
storage
->
data
+
output
->
storageOffset
;
int64_t
*
row_data
=
row
->
storage
->
data
+
row
->
storageOffset
;
int64_t
*
col_data
=
col
->
storage
->
data
+
col
->
storageOffset
;
int64_t
*
degree_data
=
degree
->
storage
->
data
+
degree
->
storageOffset
;
int64_t
e
=
0
,
row_value
,
col_value
,
i
;
while
(
e
<
THLongTensor_nElement
(
row
))
{
row_value
=
row_data
[
e
];
if
(
output_data
[
row_value
]
<
0
)
{
// Node is unmatched.
// Find next unmatched neighbor.
col_value
=
-
1
;
for
(
i
=
0
;
i
<
degree_data
[
row_value
];
i
++
)
{
col_value
=
col_data
[
e
+
i
];
if
(
output_data
[
col_value
]
<
0
)
break
;
// Neighbor found.
else
col_value
=
-
1
;
}
// Set output.
if
(
col_value
<
0
)
{
output_data
[
row_value
]
=
row_value
;
}
else
{
i
=
row_value
<
col_value
?
row_value
:
col_value
;
output_data
[
row_value
]
=
i
;
output_data
[
col_value
]
=
i
;
}
}
e
+=
degree_data
[
row_value
];
}
}
#include "generic/cpu.c"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment