Commit df2ed804 authored by rusty1s's avatar rusty1s
Browse files

graclus done

parent dcd88f5a
#include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col) {
/* at::Tensor perm; */
/* std::tie(row, perm) = row.sort(); */
/* col = col.index_select(0, perm); */
/* TODO: randperm */
/* TODO: randperm_sort_row */
return { row, col };
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, at::Tensor col) {
auto mask = row != col;
row = row.masked_select(mask);
col = col.masked_select(mask);
return {row, col};
}
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
// Randomly reorder row and column indices.
auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
row = row.index_select(0, perm);
col = col.index_select(0, perm);
// Randomly swap row values.
auto node_rid = at::randperm(torch::CPU(at::kLong), num_nodes);
row = node_rid.index_select(0, row);
// Sort row and column indices row-wise.
std::tie(row, perm) = row.sort();
col = col.index_select(0, perm);
// Revert row value swaps.
row = std::get<1>(node_rid.sort()).index_select(0, row);
return {row, col};
}
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes });
auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes});
return zero.scatter_add_(0, index, at::ones_like(index));
}
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
std::tie(row, col) = randperm(row, col);
std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = randperm(row, col, num_nodes);
auto deg = degree(row, num_nodes);
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);
auto *row_data = row.data<int64_t>();
auto *col_data = col.data<int64_t>();
auto *deg_data = deg.data<int64_t>();
auto *cluster_data = cluster.data<int64_t>();
int64_t n_idx = 0, e_idx = 0, d_idx, r, c;
int64_t e_idx = 0, d_idx, r, c;
while (e_idx < row.size(0)) {
r = row_data[e_idx];
if (cluster_data[r] < 0) {
......@@ -42,8 +62,7 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
}
}
}
e_idx += deg_data[n_idx];
n_idx++;
e_idx += deg_data[r];
}
return cluster;
......@@ -55,15 +74,15 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
start = start.toType(pos.type());
end = end.toType(pos.type());
pos = pos - start.view({ 1, -1 });
pos = pos - start.view({1, -1});
auto num_voxels = ((end - start) / size).toType(at::kLong);
num_voxels = (num_voxels + 1).cumsum(0);
num_voxels -= num_voxels.data<int64_t>()[0];
num_voxels.data<int64_t>()[0] = 1;
auto cluster = pos / size.view({ 1, -1 });
auto cluster = pos / size.view({1, -1});
cluster = cluster.toType(at::kLong);
cluster *= num_voxels.view({ 1, -1 });
cluster *= num_voxels.view({1, -1});
cluster = cluster.sum(1);
return cluster;
......
......@@ -25,5 +25,7 @@ def graclus_cluster(row, col, num_nodes):
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
print(row)
print(graclus_cluster(row, col, 4))
print(col)
print('-----------------')
cluster = graclus_cluster(row, col, 4)
print(cluster)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment