cluster.cpp 2.71 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <torch/torch.h>

rusty1s's avatar
linting  
rusty1s committed
3
4
// Drops every edge (row[i], col[i]) whose two endpoints are the same node.
//
// @param row  source node index of each edge.
// @param col  target node index of each edge, aligned element-wise with row.
// @return     (row, col) restricted to the edges with row != col, keeping
//             their original relative order.
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col) {
  auto is_not_loop = row != col;
  auto kept_row = row.masked_select(is_not_loop);
  auto kept_col = col.masked_select(is_not_loop);
  return std::make_tuple(kept_row, kept_col);
}

rusty1s's avatar
linting  
rusty1s committed
11
12
// Shuffles the edge list, then groups it by source node so that the order in
// which the per-node groups appear is itself random.
//
// @param row        source node index of each edge (used as an index, so
//                   expected to be int64).
// @param col        target node index of each edge, aligned with row.
// @param num_nodes  number of nodes; node indices assumed to lie in
//                   [0, num_nodes).
// @return           (row, col) with all edges of one source node contiguous;
//                   row holds the original node ids again.
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // Shuffle the edges themselves.
  auto edge_order = at::randperm(torch::CPU(at::kLong), row.size(0));
  auto shuffled_row = row.index_select(0, edge_order);
  auto shuffled_col = col.index_select(0, edge_order);

  // Give every node a random temporary label; sorting by these labels below
  // visits the per-node groups in random order.
  auto relabel = at::randperm(torch::CPU(at::kLong), num_nodes);
  auto relabelled_row = relabel.index_select(0, shuffled_row);

  // Group edges by (relabelled) source node and carry col along.
  at::Tensor grouped_row, group_order;
  std::tie(grouped_row, group_order) = relabelled_row.sort();
  auto grouped_col = shuffled_col.index_select(0, group_order);

  // Sorting a permutation yields its inverse as the index output; use it to
  // map the temporary labels back to the original node ids.
  auto unlabel = std::get<1>(relabel.sort());
  auto restored_row = unlabel.index_select(0, grouped_row);

  return std::make_tuple(restored_row, grouped_col);
}

// Counts how often each node id occurs in index.
//
// @param index      node indices (must be int64: they are scattered into an
//                   int64 buffer); assumed to lie in [0, num_nodes).
// @param num_nodes  length of the returned histogram.
// @return           int64 tensor of shape {num_nodes} with per-node counts.
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
  auto counts = at::zeros(torch::CPU(at::kLong), {num_nodes});
  auto increments = at::ones_like(index);
  counts.scatter_add_(0, index, increments);
  return counts;
}

// Greedy randomized matching of graph nodes (Graclus-style clustering).
//
// Self-loops are removed, and edges are shuffled and grouped by source node
// (see randperm). Source nodes are then visited in that random group order.
// Each still-unassigned node is matched with its first still-unassigned
// neighbour, and both receive the smaller of the two ids as cluster id. A node
// without such a neighbour forms a singleton cluster.
//
// @param row        source node index of each edge (int64).
// @param col        target node index of each edge (int64), aligned with row.
// @param num_nodes  number of nodes; indices assumed to lie in [0, num_nodes).
// @return           int64 tensor of shape {num_nodes} with the cluster id of
//                   every node (never -1).
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = randperm(row, col, num_nodes);

  // After randperm all edges of one source node are contiguous, so deg[r] is
  // the length of r's run inside row/col.
  auto deg = degree(row, num_nodes);
  auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);

  auto *row_data = row.data<int64_t>();
  auto *col_data = col.data<int64_t>();
  auto *deg_data = deg.data<int64_t>();
  auto *cluster_data = cluster.data<int64_t>();

  const int64_t num_edges = row.size(0);
  int64_t e_idx = 0;
  while (e_idx < num_edges) {
    const int64_t r = row_data[e_idx];
    if (cluster_data[r] < 0) {
      cluster_data[r] = r;
      for (int64_t d_idx = 0; d_idx < deg_data[r]; d_idx++) {
        const int64_t c = col_data[e_idx + d_idx];
        if (cluster_data[c] < 0) {
          const int64_t id = std::min(r, c);
          cluster_data[r] = id;
          cluster_data[c] = id;
          break;
        }
      }
    }
    // Jump over the remainder of r's run of edges (deg[r] >= 1 here).
    e_idx += deg_data[r];
  }

  // Nodes that never occur in row (isolated nodes, or nodes whose only edges
  // were self-loops) are not visited above. Unless a neighbour matched them,
  // make each one its own cluster instead of leaking the -1 sentinel.
  for (int64_t n = 0; n < num_nodes; n++) {
    if (cluster_data[n] < 0) {
      cluster_data[n] = n;
    }
  }

  return cluster;
}

rusty1s's avatar
linting  
rusty1s committed
69
70
// Assigns every point to the cell of a regular voxel grid that contains it.
//
// @param pos    point coordinates; the {1, -1} broadcasts and sum(1) below
//               treat it as (num_points, num_dims).
// @param size   voxel edge length per dimension.
// @param start  grid origin per dimension.
// @param end    grid end per dimension.
// @return       int64 tensor with one row-major linear voxel index per point.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  size = size.toType(pos.type());
  start = start.toType(pos.type());
  end = end.toType(pos.type());

  pos = pos - start.view({1, -1});

  // Number of voxels along each dimension.
  auto num_voxels = ((end - start) / size).toType(at::kLong) + 1;

  // Row-major strides: stride[0] = 1, stride[d] = prod_{j<d} num_voxels[j].
  // This must be a product. A running sum of the counts lets distinct voxels
  // share an index, e.g. counts (3, 2): (2,0) and (0,1) would both map to 2.
  auto stride = at::ones_like(num_voxels);
  auto *num_voxels_data = num_voxels.data<int64_t>();
  auto *stride_data = stride.data<int64_t>();
  for (int64_t d = 1; d < stride.size(0); d++) {
    stride_data[d] = stride_data[d - 1] * num_voxels_data[d - 1];
  }

  // Per-dimension voxel coordinate, then linearize.
  auto cluster = (pos / size.view({1, -1})).toType(at::kLong);
  cluster *= stride.view({1, -1});
  return cluster.sum(1);
}

// Python bindings exposing the CPU clustering routines defined in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("graclus", &graclus, "Graclus (CPU)")
      .def("grid", &grid, "Grid (CPU)");
}