Commit df2ed804 authored by rusty1s's avatar rusty1s
Browse files

graclus done

parent dcd88f5a
#include <torch/torch.h> #include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col) { inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, at::Tensor col) {
/* at::Tensor perm; */ auto mask = row != col;
/* std::tie(row, perm) = row.sort(); */ row = row.masked_select(mask);
/* col = col.index_select(0, perm); */ col = col.masked_select(mask);
return {row, col};
/* TODO: randperm */ }
/* TODO: randperm_sort_row */
return { row, col };
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
// Randomly reorder row and column indices.
auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
row = row.index_select(0, perm);
col = col.index_select(0, perm);
// Randomly swap row values.
auto node_rid = at::randperm(torch::CPU(at::kLong), num_nodes);
row = node_rid.index_select(0, row);
// Sort row and column indices row-wise.
std::tie(row, perm) = row.sort();
col = col.index_select(0, perm);
// Revert row value swaps.
row = std::get<1>(node_rid.sort()).index_select(0, row);
return {row, col};
} }
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) { inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes }); auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes});
return zero.scatter_add_(0, index, at::ones_like(index)); return zero.scatter_add_(0, index, at::ones_like(index));
} }
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
std::tie(row, col) = randperm(row, col); std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = randperm(row, col, num_nodes);
auto deg = degree(row, num_nodes); auto deg = degree(row, num_nodes);
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1); auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);
auto *row_data = row.data<int64_t>(); auto *row_data = row.data<int64_t>();
auto *col_data = col.data<int64_t>(); auto *col_data = col.data<int64_t>();
auto *deg_data = deg.data<int64_t>(); auto *deg_data = deg.data<int64_t>();
auto *cluster_data = cluster.data<int64_t>(); auto *cluster_data = cluster.data<int64_t>();
int64_t n_idx = 0, e_idx = 0, d_idx, r, c; int64_t e_idx = 0, d_idx, r, c;
while (e_idx < row.size(0)) { while (e_idx < row.size(0)) {
r = row_data[e_idx]; r = row_data[e_idx];
if (cluster_data[r] < 0) { if (cluster_data[r] < 0) {
...@@ -42,8 +62,7 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { ...@@ -42,8 +62,7 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
} }
} }
} }
e_idx += deg_data[n_idx]; e_idx += deg_data[r];
n_idx++;
} }
return cluster; return cluster;
...@@ -55,15 +74,15 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en ...@@ -55,15 +74,15 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
start = start.toType(pos.type()); start = start.toType(pos.type());
end = end.toType(pos.type()); end = end.toType(pos.type());
pos = pos - start.view({ 1, -1 }); pos = pos - start.view({1, -1});
auto num_voxels = ((end - start) / size).toType(at::kLong); auto num_voxels = ((end - start) / size).toType(at::kLong);
num_voxels = (num_voxels + 1).cumsum(0); num_voxels = (num_voxels + 1).cumsum(0);
num_voxels -= num_voxels.data<int64_t>()[0]; num_voxels -= num_voxels.data<int64_t>()[0];
num_voxels.data<int64_t>()[0] = 1; num_voxels.data<int64_t>()[0] = 1;
auto cluster = pos / size.view({ 1, -1 }); auto cluster = pos / size.view({1, -1});
cluster = cluster.toType(at::kLong); cluster = cluster.toType(at::kLong);
cluster *= num_voxels.view({ 1, -1 }); cluster *= num_voxels.view({1, -1});
cluster = cluster.sum(1); cluster = cluster.sum(1);
return cluster; return cluster;
......
...@@ -25,5 +25,7 @@ def graclus_cluster(row, col, num_nodes): ...@@ -25,5 +25,7 @@ def graclus_cluster(row, col, num_nodes):
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3]) row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2]) col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
print(row) print(row)
print(col)
print(graclus_cluster(row, col, 4)) print('-----------------')
cluster = graclus_cluster(row, col, 4)
print(cluster)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment