// cluster.cpp: CPU clustering kernels (graclus, grid) exposed to Python via pybind11.
#include <torch/torch.h>


// Graph helpers.
/// Count, for every node id in [0, num_nodes), how many entries of `index`
/// refer to it.
///
/// @param index      node ids; expected to be int64, since it is used as the
///                   scatter index and its `ones_like` must match the int64
///                   accumulator.
/// @param num_nodes  length of the returned count vector.
/// @return           int64 CPU tensor of shape [num_nodes] holding the counts.
at::Tensor degree(at::Tensor index, int64_t num_nodes) {
  auto counts = at::zeros(torch::CPU(at::kLong), { num_nodes });
  // Add 1 at position index[i] for every i.
  counts.scatter_add_(0, index, at::ones_like(index));
  return counts;
}


/// Greedy pairwise matching of graph nodes (CPU), given edges in COO form.
///
/// Source nodes are visited in increasing id order. Each still-unassigned
/// node `u` that has outgoing edges opens its own cluster (`cluster[u] = u`)
/// and is paired with the first neighbour `v`, in edge order, that is also
/// still unassigned (`cluster[v] = u`). Any node left unassigned afterwards
/// (no outgoing edges and never picked as a neighbour) becomes a singleton
/// cluster labelled with its own id.
///
/// @param row        source node of each edge (int64). Must be sorted in
///                   ascending order so each node's edges are contiguous;
///                   this function does not sort it.
/// @param col        target node of each edge (int64), aligned with `row`.
/// @param num_nodes  number of nodes. Ids in `row`/`col` are assumed to lie
///                   in [0, num_nodes).
/// @return           int64 CPU tensor of shape [num_nodes], one cluster id
///                   per node.
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // NOTE(review): the matching is deterministic. Randomising the node order
  // and the neighbour order within a row is still open:
  // TODO: randperm
  // TODO: randperm_sort_row

  // `col` is read through a raw pointer below, which requires dense storage.
  col = col.contiguous();

  auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
  auto deg = degree(row, num_nodes);

  auto *col_data = col.data<int64_t>();
  auto *deg_data = deg.data<int64_t>();
  auto *cluster_data = cluster.data<int64_t>();

  // With `row` sorted, node u's edges occupy [begin, begin + deg[u]), where
  // `begin` is the running sum of the degrees of all lower-numbered nodes
  // (a CSR offset). Because these offsets come from `deg`, edge indices stay
  // below row.size(0) even if `row` is unsorted. The result is then
  // meaningless, but the edge array is never read out of bounds.
  int64_t begin = 0;
  for (int64_t u = 0; u < num_nodes; u++) {
    const int64_t end = begin + deg_data[u];
    if (deg_data[u] > 0 && cluster_data[u] < 0) {
      cluster_data[u] = u;
      for (int64_t e = begin; e < end; e++) {
        const int64_t v = col_data[e];
        if (cluster_data[v] < 0) {
          cluster_data[v] = u;
          break;
        }
      }
    }
    begin = end;
  }

  // This pass runs after matching so it cannot change who was paired with
  // whom. It only replaces the -1 placeholder, which is not a valid cluster
  // id, for nodes that were never assigned.
  for (int64_t u = 0; u < num_nodes; u++) {
    if (cluster_data[u] < 0) {
      cluster_data[u] = u;
    }
  }

  return cluster;
}


/// Assign each point to a cell of a regular voxel grid and return a flat
/// voxel id per point (CPU).
///
/// @param pos    point coordinates, indexed as [point, dim].
/// @param size   voxel edge length per dimension (1-D, one entry per dim).
/// @param start  grid origin per dimension.
/// @param end    grid upper bound per dimension. Together with `start` and
///               `size`, it fixes the voxel count along each dimension.
/// @return       int64 tensor with one flattened voxel id per point.
///
/// NOTE(review): points are presumably expected to lie within [start, end].
/// Coordinates outside that range yield per-dimension indices outside
/// [0, count), so their flat ids can collide with those of other voxels.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) {
  size = size.toType(pos.type());
  start = start.toType(pos.type());
  end = end.toType(pos.type());

  pos = pos - start.view({ 1, -1 });

  // Number of voxels along each dimension.
  auto num_voxels = (((end - start) / size).toType(at::kLong) + 1).contiguous();

  // Flattening strides are the exclusive running product of the counts,
  // [1, n0, n0*n1, ...], so each voxel gets a unique id. A running *sum*
  // would produce colliding ids as soon as there are two or more dimensions.
  auto stride = num_voxels.clone();
  auto *count_data = num_voxels.data<int64_t>();
  auto *stride_data = stride.data<int64_t>();
  int64_t step = 1;
  for (int64_t d = 0; d < stride.size(0); d++) {
    stride_data[d] = step;
    step *= count_data[d];
  }

  auto cluster = pos / size.view({ 1, -1 });
  cluster = cluster.toType(at::kLong);
  cluster *= stride.view({ 1, -1 });
  cluster = cluster.sum(1);

  return cluster;
}


// Python module exposing the CPU clustering kernels defined in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("graclus", &graclus, "Graclus (CPU)")
   .def("grid", &grid, "Grid (CPU)");
}