cluster.cpp 2.65 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
#include <torch/torch.h>


rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/// Drops every edge (row[i], col[i]) whose two endpoints are identical.
///
/// @param row  source index of each edge
/// @param col  target index of each edge (same length as `row`)
/// @return     the (row, col) pair restricted to edges with row[i] != col[i]
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, at::Tensor col) {
  auto keep = row != col;
  return std::make_tuple(row.masked_select(keep), col.masked_select(keep));
}


/// Randomly shuffles an edge list and groups it by source node, with the
/// groups themselves appearing in a random node order.
///
/// First, the edges are permuted uniformly at random. Next, source ids are
/// relabelled with a random node permutation, and the edges are sorted by the
/// new labels. This makes each node's edges contiguous while the order of the
/// node groups is random. Finally, the labels are mapped back to the original
/// node ids.
///
/// @param row        source index of each edge (used as an index tensor, so
///                   expected to be int64)
/// @param col        target index of each edge (same length as `row`)
/// @param num_nodes  number of nodes; `row` values must lie in [0, num_nodes)
/// @return           the reordered (row, col) pair
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // Shuffle the edges so the neighbour order inside each node is random.
  auto edge_perm = at::randperm(torch::CPU(at::kLong), row.size(0));
  row = row.index_select(0, edge_perm);
  col = col.index_select(0, edge_perm);

  // Relabel source nodes with a random permutation.
  auto relabel = at::randperm(torch::CPU(at::kLong), num_nodes);
  row = relabel.index_select(0, row);

  // Sorting by the random labels groups each node's edges together while
  // visiting the nodes themselves in random order.
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);

  // The argsort of a permutation is its inverse: map labels back to node ids.
  row = std::get<1>(relabel.sort()).index_select(0, row);

  return std::make_tuple(row, col);
}


/// Counts how often each node id occurs in `index`.
///
/// @param index      node ids in [0, num_nodes); used as a scatter index, so
///                   expected to be int64
/// @param num_nodes  length of the returned histogram
/// @return           int64 CPU tensor `deg` with deg[i] = number of j such
///                   that index[j] == i
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
  auto counts = at::zeros(torch::CPU(at::kLong), {num_nodes});
  counts.scatter_add_(0, index, at::ones_like(index));
  return counts;
}


/// Greedy Graclus-style matching of a graph given as an edge list.
///
/// The algorithm proceeds as follows:
/// - Self-loops are discarded, and source nodes are visited in random order.
/// - Each not-yet-assigned node is paired with its first (randomly ordered)
///   neighbour that is also unassigned. Both nodes get the smaller of the two
///   ids as their cluster id.
/// - A node that cannot be paired forms its own singleton cluster.
///
/// @param row        source index of each edge (read as int64)
/// @param col        target index of each edge (read as int64)
/// @param num_nodes  number of nodes; all indices must lie in [0, num_nodes)
/// @return           int64 CPU tensor of length `num_nodes` holding one
///                   cluster id per node
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = randperm(row, col, num_nodes);

  // After randperm, edges are grouped by source node, so deg[r] is the length
  // of r's contiguous run of edges.
  auto deg = degree(row, num_nodes);
  auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);

  auto *row_data = row.data<int64_t>();
  auto *col_data = col.data<int64_t>();
  auto *deg_data = deg.data<int64_t>();
  auto *cluster_data = cluster.data<int64_t>();

  const int64_t num_edges = row.size(0);
  int64_t e_idx = 0;
  while (e_idx < num_edges) {
    const int64_t r = row_data[e_idx];
    if (cluster_data[r] < 0) {
      cluster_data[r] = r;
      for (int64_t d_idx = 0; d_idx < deg_data[r]; d_idx++) {
        const int64_t c = col_data[e_idx + d_idx];
        if (cluster_data[c] < 0) {
          const int64_t id = std::min(r, c);
          cluster_data[r] = id;
          cluster_data[c] = id;
          break;
        }
      }
    }
    // Jump past the remaining edges of r.
    e_idx += deg_data[r];
  }

  // A node never appears as an edge source if it is isolated or had only
  // self-loops. Such a node is never visited above, and unless some neighbour
  // matched it, it still holds the -1 placeholder. Give it a singleton
  // cluster, as the loop does for unmatched visited nodes.
  for (int64_t n = 0; n < num_nodes; n++) {
    if (cluster_data[n] < 0) {
      cluster_data[n] = n;
    }
  }

  return cluster;
}


/// Assigns each point to a cell of a regular voxel grid and returns a linear
/// voxel index per point.
///
/// Along dimension d the grid has floor((end[d] - start[d]) / size[d]) + 1
/// cells. Dimension 0 varies fastest in the linear index.
///
/// @param pos    point coordinates; viewed as (num_points, num_dims) via the
///               `view({1, -1})` broadcasts below
/// @param size   voxel edge length per dimension
/// @param start  grid origin per dimension
/// @param end    grid end per dimension
/// @return       int64 tensor with one linear voxel index per point
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) {
  size = size.toType(pos.type());
  start = start.toType(pos.type());
  end = end.toType(pos.type());

  pos = pos - start.view({1, -1});

  // Linear-index strides are exclusive cumulative PRODUCTS of the per-dimension
  // cell counts: stride[0] = 1, stride[d] = prod_{k<d} cells[k]. (A cumulative
  // sum makes distinct cells collide, e.g. (2,0) and (0,1) in a 3x2 grid.)
  auto strides = ((end - start) / size).toType(at::kLong);
  auto *strides_data = strides.data<int64_t>();
  int64_t stride = 1;
  for (int64_t d = 0; d < strides.size(0); d++) {
    const int64_t cells_in_dim = strides_data[d] + 1;
    strides_data[d] = stride;
    stride *= cells_in_dim;
  }

  auto cluster = (pos / size.view({1, -1})).toType(at::kLong);
  cluster *= strides.view({1, -1});
  cluster = cluster.sum(1);

  return cluster;
}


// Python bindings: exposes the CPU implementations of graclus and grid
// clustering under the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("graclus", &graclus, "Graclus (CPU)");
  m.def("grid", &grid, "Grid (CPU)");
}