Commit 24a770f3 authored by rusty1s's avatar rusty1s
Browse files

graclus cpu implementation

parent f0a4d21d
#include <torch/torch.h>
at::Tensor graclus(at::Tensor row, at::Tensor col, at::Tensor weight) {
return row;
at::Tensor degree(at::Tensor index, int64_t num_nodes) {
auto one = at::ones_like(index);
auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes });
return zero.scatter_add_(0, index, one);
}
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
auto deg = degree(row, num_nodes);
/* at::Tensor perm; */
/* std::tie(row, perm) = row.sort(); */
/* col = col.index_select(0, perm); */
/* TODO: randperm */
/* TODO: randperm_sort_row */
auto *row_data = row.data<int64_t>();
auto *col_data = col.data<int64_t>();
auto *deg_data = deg.data<int64_t>();
auto *cluster_data = cluster.data<int64_t>();
int64_t n_idx = 0, e_idx = 0, d_idx, r, c;
while (e_idx < row.size(0)) {
r = row_data[e_idx];
if (cluster_data[r] < 0) {
cluster_data[r] = r;
for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
c = col_data[e_idx + d_idx];
if (cluster_data[c] < 0) {
cluster_data[c] = r;
break;
}
}
}
e_idx += deg_data[n_idx];
n_idx++;
}
return cluster;
}
......@@ -14,8 +53,8 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
pos = pos - start.view({ 1, -1 });
auto num_voxels = ((end - start) / size).toType(at::kLong);
num_voxels = (num_voxels + 1).cumsum(0);
num_voxels = num_voxels - num_voxels[0];
num_voxels[0] = 1;
num_voxels -= num_voxels.data<int64_t>()[0];
num_voxels.data<int64_t>()[0] = 1;
auto cluster = pos / size.view({ 1, -1 });
cluster = cluster.toType(at::kLong);
......
......@@ -9,11 +9,21 @@ def grid_cluster(pos, size, start=None, end=None):
return cluster_cpu.grid(pos, size, start, end)
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
size = torch.tensor([2, 2])
start = torch.tensor([0, 0])
end = torch.tensor([7, 7])
print('pos', pos.tolist())
print('size', size.tolist())
cluster = grid_cluster(pos, size)
print('result', cluster.tolist(), cluster.dtype)
def graclus_cluster(row, col, num_nodes):
return cluster_cpu.graclus(row, col, num_nodes)
# pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
# size = torch.tensor([2, 2])
# start = torch.tensor([0, 0])
# end = torch.tensor([7, 7])
# print('pos', pos.tolist())
# print('size', size.tolist())
# cluster = grid_cluster(pos, size)
# print('result', cluster.tolist(), cluster.dtype)
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
print(row)
print(graclus_cluster(row, col, 4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment