Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
d2cc3162
Commit
d2cc3162
authored
Aug 22, 2018
by
rusty1s
Browse files
graclus cpu
parent
0a559a4a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
157 additions
and
45 deletions
+157
-45
cpu/graclus.cpp
cpu/graclus.cpp
+74
-27
cpu/utils.h
cpu/utils.h
+50
-0
test/test_graclus.py
test/test_graclus.py
+7
-2
torch_cluster/graclus.py
torch_cluster/graclus.py
+24
-14
torch_cluster/grid.py
torch_cluster/grid.py
+2
-2
No files found.
cpu/graclus.cpp
View file @
d2cc3162
#include <torch/torch.h>
// #include "../include/degree.cpp"
// #include "../include/loop.cpp"
// #include "../include/perm.cpp"
#include "utils.h"
#define ITERATE_NEIGHBORS(NODE, NAME, ROW, COL, ...) \
{ \
for (int64_t e = ROW[NODE]; e < ROW[NODE + 1]; e++) { \
auto NAME = COL[e]; \
__VA_ARGS__; \
} \
}
at
::
Tensor
graclus
(
at
::
Tensor
row
,
at
::
Tensor
col
,
int64_t
num_nodes
)
{
// std::tie(row, col) = remove_self_loops(row, col);
// std::tie(row, col) = randperm(row, col, num_nodes);
// auto deg = degree(row, num_nodes, row.type().scalarType());
std
::
tie
(
row
,
col
)
=
remove_self_loops
(
row
,
col
);
std
::
tie
(
row
,
col
)
=
rand
(
row
,
col
);
std
::
tie
(
row
,
col
)
=
to_csr
(
row
,
col
);
auto
row_data
=
row
.
data
<
int64_t
>
(),
col_data
=
col
.
data
<
int64_t
>
();
auto
perm
=
randperm
(
num_nodes
);
auto
perm_data
=
perm
.
data
<
int64_t
>
();
auto
cluster
=
at
::
full
(
num_nodes
,
-
1
,
row
.
options
());
auto
cluster_data
=
cluster
.
data
<
int64_t
>
();
for
(
int64_t
i
=
0
;
i
<
num_nodes
;
i
++
)
{
auto
u
=
perm_data
[
i
];
if
(
cluster_data
[
u
]
>=
0
)
continue
;
cluster_data
[
u
]
=
u
;
// auto *row_data = row.data<int64_t>();
// auto *col_data = col.data<int64_t>();
// auto *deg_data = deg.data<int64_t>();
// auto *cluster_data = cluster.data<int64_t>();
// int64_t e_idx = 0, d_idx, r, c;
// while (e_idx < row.size(0)) {
// r = row_data[e_idx];
// if (cluster_data[r] < 0) {
// cluster_data[r] = r;
// for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
// c = col_data[e_idx + d_idx];
// if (cluster_data[c] < 0) {
// cluster_data[r] = std::min(r, c);
// cluster_data[c] = std::min(r, c);
// break;
// }
// }
// }
// e_idx += deg_data[r];
// }
ITERATE_NEIGHBORS
(
u
,
v
,
row_data
,
col_data
,
{
if
(
cluster_data
[
v
]
>=
0
)
continue
;
cluster_data
[
u
]
=
std
::
min
(
u
,
v
);
cluster_data
[
v
]
=
std
::
min
(
u
,
v
);
break
;
});
}
return
cluster
;
}
at
::
Tensor
weighted_graclus
(
at
::
Tensor
row
,
at
::
Tensor
col
,
at
::
Tensor
weight
,
int64_t
num_nodes
)
{
std
::
tie
(
row
,
col
)
=
remove_self_loops
(
row
,
col
,
weight
);
std
::
tie
(
row
,
col
,
weight
)
=
to_csr
(
row
,
col
,
weight
);
auto
row_data
=
row
.
data
<
int64_t
>
(),
col_data
=
col
.
data
<
int64_t
>
();
auto
perm
=
randperm
(
num_nodes
);
auto
perm_data
=
perm
.
data
<
int64_t
>
();
auto
cluster
=
at
::
full
(
num_nodes
,
-
1
,
row
.
options
());
auto
cluster_data
=
cluster
.
data
<
int64_t
>
();
AT_DISPATCH_ALL_TYPES
(
weight
.
type
(),
"weighted_graclus"
,
[
&
]
{
auto
weight_data
=
weight
.
data
<
scalar_t
>
();
auto
weight_data
=
weight
.
data
<
scalar_t
>
();
for
(
int64_t
i
=
0
;
i
<
num_nodes
;
i
++
)
{
auto
u
=
perm_data
[
i
];
if
(
cluster_data
[
u
]
>=
0
)
continue
;
cluster_data
[
u
]
=
u
;
int64_t
v_max
;
scalar_t
w_max
=
0
;
ITERATE_NEIGHBORS
(
u
,
v
,
row_data
,
col_data
,
{
if
(
cluster_data
[
v
]
>=
0
)
continue
;
auto
w
=
weight_data
[
e
];
if
(
w
>=
w_max
)
{
v_max
=
v
;
w_max
=
w
;
}
});
cluster_data
[
u
]
=
std
::
min
(
u
,
v_max
);
cluster_data
[
v_max
]
=
std
::
min
(
u
,
v_max
);
}
});
return
cluster
;
}
...
...
cpu/utils.h
0 → 100644
View file @
d2cc3162
#pragma once
#include <torch/torch.h>
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
remove_self_loops
(
at
::
Tensor
row
,
at
::
Tensor
col
)
{
auto
mask
=
row
!=
col
;
return
make_tuple
(
row
.
masked_select
(
mask
),
col
.
masked_select
(
mask
));
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
remove_self_loops
(
at
::
Tensor
row
,
at
::
Tensor
col
,
at
::
Tensor
weight
)
{
auto
mask
=
row
!=
col
;
return
make_tuple
(
row
.
masked_select
(
mask
),
col
.
masked_select
(
mask
),
weight
.
masked_select
(
mask
));
}
at
::
Tensor
randperm
(
int64_t
n
)
{
auto
out
=
at
::
empty
(
n
,
torch
::
CPU
(
at
::
kLong
));
at
::
randperm_out
(
out
,
n
);
return
out
;
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
rand
(
at
::
Tensor
row
,
at
::
Tensor
col
)
{
auto
perm
=
randperm
(
row
.
size
(
0
));
return
make_tuple
(
row
.
index_select
(
perm
),
col
.
index_select
(
perm
));
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
sort_by_row
(
at
::
Tensor
row
,
at
::
Tensor
col
)
{
Tensor
perm
;
tie
(
row
,
perm
)
=
row
.
sort
();
col
=
col
.
index_select
(
0
,
perm
);
return
stack
({
row
,
col
},
0
);
}
inline
Tensor
degree
(
Tensor
row
,
int64_t
num_nodes
)
{
auto
zero
=
zeros
(
num_nodes
,
row
.
type
());
auto
one
=
ones
(
row
.
size
(
0
),
row
.
type
());
return
zero
.
scatter_add_
(
0
,
row
,
one
);
}
inline
tuple
<
Tensor
,
Tensor
>
to_csr
(
Tensor
index
,
int64_t
num_nodes
)
{
index
=
sort_by_row
(
index
);
auto
row
=
degree
(
index
[
0
],
num_nodes
).
cumsum
(
0
);
row
=
cat
({
zeros
(
1
,
row
.
type
()),
row
},
0
);
// Prepend zero.
return
make_tuple
(
row
,
index
[
1
]);
}
// std::tie(row, col) = randperm(row, col);
// std::tie(row, col) = to_csr(row, col);
test/graclus.py
→
test/
test_
graclus.py
View file @
d2cc3162
...
...
@@ -15,8 +15,12 @@ tests = [{
'weight'
:
[
1
,
2
,
1
,
3
,
2
,
2
,
3
,
1
,
2
,
1
],
}]
devices
=
[
torch
.
device
(
'cpu'
)]
dtypes
=
[
torch
.
float
]
tests
=
[
tests
[
0
]]
def
assert_correct_graclus
(
row
,
col
,
cluster
):
def
assert_correct
(
row
,
col
,
cluster
):
row
,
col
,
cluster
=
row
.
to
(
'cpu'
),
col
.
to
(
'cpu'
),
cluster
.
to
(
'cpu'
)
n
=
cluster
.
size
(
0
)
...
...
@@ -47,4 +51,5 @@ def test_graclus_cluster(test, dtype, device):
weight
=
tensor
(
test
.
get
(
'weight'
),
dtype
,
device
)
cluster
=
graclus_cluster
(
row
,
col
,
weight
)
assert_correct_graclus
(
row
,
col
,
cluster
)
print
(
cluster
)
# assert_correct(row, col, cluster)
torch_cluster/graclus.py
View file @
d2cc3162
from
.utils.loop
import
remove_self_loops
from
.utils.perm
import
randperm
,
sort_row
,
randperm_sort_row
from
.utils.ffi
import
graclus
# from .utils.loop import remove_self_loops
# from .utils.perm import randperm, sort_row, randperm_sort_row
# from .utils.ffi import graclus
import
torch
import
graclus_cpu
if
torch
.
cuda
.
is_available
():
import
graclus_cuda
def
graclus_cluster
(
row
,
col
,
weight
=
None
,
num_nodes
=
None
):
...
...
@@ -15,22 +21,26 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
Examples::
>>> row = torch.
LongT
ensor([0, 1, 1, 2])
>>> col = torch.
LongT
ensor([1, 0, 2, 1])
>>> row = torch.
t
ensor([0, 1, 1, 2])
>>> col = torch.
t
ensor([1, 0, 2, 1])
>>> weight = torch.Tensor([1, 1, 1, 1])
>>> cluster = graclus_cluster(row, col, weight)
"""
num_nodes
=
row
.
max
().
item
()
+
1
if
num_nodes
is
None
else
num_nodes
if
num_nodes
is
None
:
num_nodes
=
max
(
row
.
max
().
item
(),
col
.
max
().
item
())
+
1
if
row
.
is_cuda
:
row
,
col
=
sort_row
(
row
,
col
)
else
:
row
,
col
=
randperm
(
row
,
col
)
row
,
col
=
randperm_sort_row
(
row
,
col
,
num_nodes
)
op
=
graclus_cuda
if
row
.
is_cuda
else
graclus_cpu
row
,
col
=
remove_self_loops
(
row
,
col
)
cluster
=
row
.
new_empty
((
num_nodes
,
))
graclus
(
cluster
,
row
,
col
,
weight
)
if
weight
is
None
:
cluster
=
op
.
graclus
(
row
,
col
,
num_nodes
)
else
:
cluster
=
op
.
weighted_graclus
(
row
,
col
,
weight
,
num_nodes
)
return
cluster
# if row.is_cuda:
# row, col = sort_row(row, col)
# else:
# row, col = randperm(row, col)
# row, col = randperm_sort_row(row, col, num_nodes)
torch_cluster/grid.py
View file @
d2cc3162
...
...
@@ -28,7 +28,7 @@ def grid_cluster(pos, size, start=None, end=None):
start
=
pos
.
t
().
min
(
dim
=
1
)[
0
]
if
start
is
None
else
start
end
=
pos
.
t
().
max
(
dim
=
1
)[
0
]
if
end
is
None
else
end
op
=
grid_cuda
.
grid
if
pos
.
is_cuda
else
grid_cpu
.
grid
cluster
=
op
(
pos
,
size
,
start
,
end
)
op
=
grid_cuda
if
pos
.
is_cuda
else
grid_cpu
cluster
=
op
.
grid
(
pos
,
size
,
start
,
end
)
return
cluster
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment