Commit 24a770f3 authored by rusty1s's avatar rusty1s
Browse files

graclus cpu implementation

parent f0a4d21d
#include <torch/torch.h> #include <torch/torch.h>
at::Tensor graclus(at::Tensor row, at::Tensor col, at::Tensor weight) { at::Tensor degree(at::Tensor index, int64_t num_nodes) {
return row; auto one = at::ones_like(index);
auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes });
return zero.scatter_add_(0, index, one);
}
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
auto deg = degree(row, num_nodes);
/* at::Tensor perm; */
/* std::tie(row, perm) = row.sort(); */
/* col = col.index_select(0, perm); */
/* TODO: randperm */
/* TODO: randperm_sort_row */
auto *row_data = row.data<int64_t>();
auto *col_data = col.data<int64_t>();
auto *deg_data = deg.data<int64_t>();
auto *cluster_data = cluster.data<int64_t>();
int64_t n_idx = 0, e_idx = 0, d_idx, r, c;
while (e_idx < row.size(0)) {
r = row_data[e_idx];
if (cluster_data[r] < 0) {
cluster_data[r] = r;
for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
c = col_data[e_idx + d_idx];
if (cluster_data[c] < 0) {
cluster_data[c] = r;
break;
}
}
}
e_idx += deg_data[n_idx];
n_idx++;
}
return cluster;
} }
...@@ -14,8 +53,8 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en ...@@ -14,8 +53,8 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
pos = pos - start.view({ 1, -1 }); pos = pos - start.view({ 1, -1 });
auto num_voxels = ((end - start) / size).toType(at::kLong); auto num_voxels = ((end - start) / size).toType(at::kLong);
num_voxels = (num_voxels + 1).cumsum(0); num_voxels = (num_voxels + 1).cumsum(0);
num_voxels = num_voxels - num_voxels[0]; num_voxels -= num_voxels.data<int64_t>()[0];
num_voxels[0] = 1; num_voxels.data<int64_t>()[0] = 1;
auto cluster = pos / size.view({ 1, -1 }); auto cluster = pos / size.view({ 1, -1 });
cluster = cluster.toType(at::kLong); cluster = cluster.toType(at::kLong);
......
...@@ -9,11 +9,21 @@ def grid_cluster(pos, size, start=None, end=None): ...@@ -9,11 +9,21 @@ def grid_cluster(pos, size, start=None, end=None):
return cluster_cpu.grid(pos, size, start, end) return cluster_cpu.grid(pos, size, start, end)
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]]) def graclus_cluster(row, col, num_nodes):
size = torch.tensor([2, 2]) return cluster_cpu.graclus(row, col, num_nodes)
start = torch.tensor([0, 0])
end = torch.tensor([7, 7])
print('pos', pos.tolist()) # pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
print('size', size.tolist()) # size = torch.tensor([2, 2])
cluster = grid_cluster(pos, size) # start = torch.tensor([0, 0])
print('result', cluster.tolist(), cluster.dtype) # end = torch.tensor([7, 7])
# print('pos', pos.tolist())
# print('size', size.tolist())
# cluster = grid_cluster(pos, size)
# print('result', cluster.tolist(), cluster.dtype)
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
print(row)
print(graclus_cluster(row, col, 4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment