graclus.cpp 2.17 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <torch/torch.h>

rusty1s's avatar
linting  
rusty1s committed
3
4
/// Drops every edge (row[i], col[i]) whose two endpoints are the same node.
///
/// @param row  source index of each edge
/// @param col  target index of each edge (same length as `row`)
/// @return     (row, col) restricted to the edges with row[i] != col[i]
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col) {
  const auto keep = row.ne(col);
  auto kept_row = row.masked_select(keep);
  auto kept_col = col.masked_select(keep);
  return std::make_tuple(kept_row, kept_col);
}

rusty1s's avatar
linting  
rusty1s committed
9
10
/// Randomly permutes an edge list and groups it by source node, with the
/// groups of source nodes appearing in a random order.
///
/// @param row        source index of each edge
/// @param col        target index of each edge
/// @param num_nodes  number of nodes; row values are assumed to lie in
///                   [0, num_nodes)
/// @return           permuted (row, col) in which all edges sharing a source
///                   node are contiguous
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // Shuffle the edges themselves.
  auto edge_perm = at::randperm(torch::CPU(at::kLong), row.size(0));
  row = row.index_select(0, edge_perm);
  col = col.index_select(0, edge_perm);

  // Give every node a random temporary label; sorting the edges by that label
  // groups them by source node while visiting the nodes in random order.
  auto relabel = at::randperm(torch::CPU(at::kLong), num_nodes);
  auto sorted = relabel.index_select(0, row).sort();
  col = col.index_select(0, std::get<1>(sorted));

  // The argsort of a permutation is its inverse; use it to translate the
  // temporary labels back into the original node ids.
  auto inverse = std::get<1>(relabel.sort());
  row = inverse.index_select(0, std::get<0>(sorted));

  return std::make_tuple(row, col);
}

rusty1s's avatar
rusty1s committed
30
31
32
33
34
/// Counts how often each node id occurs in `index`.
///
/// @param index        1-D tensor of node ids, assumed to lie in [0, num_nodes)
/// @param num_nodes    length of the returned count tensor
/// @param scalar_type  dtype of the returned counts
/// @return             CPU tensor `deg` where deg[i] is the number of
///                     occurrences of i in `index`
inline at::Tensor degree(at::Tensor index, int64_t num_nodes,
                         at::ScalarType scalar_type) {
  auto counts = at::full(torch::CPU(scalar_type), {num_nodes}, 0);
  auto ones = at::full(counts.type(), {index.size(0)}, 1);
  counts.scatter_add_(0, index, ones);
  return counts;
}

/// Computes a randomized greedy matching of the graph given by (row, col).
///
/// Procedure:
/// - Self-loops are removed and the edge list is randomly permuted and grouped
///   by source node (see `randperm`).
/// - Source nodes are then visited in that order.
/// - Each still-unmatched node r is paired with its first still-unmatched
///   neighbour c, and both receive the cluster id min(r, c).
/// - A visited node that finds no unmatched neighbour keeps its own id.
/// - Nodes that are never visited as a source (no outgoing non-self-loop
///   edge) and were not matched form singleton clusters as well.
///
/// @param row        source node of each edge (int64, CPU)
/// @param col        target node of each edge (int64, CPU)
/// @param num_nodes  number of nodes; all indices must lie in [0, num_nodes)
/// @return           int64 tensor of length num_nodes holding each node's
///                   cluster id
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = randperm(row, col, num_nodes);

  auto cluster = at::full(row.type(), {num_nodes}, -1);
  auto deg = degree(row, num_nodes, row.type().scalarType());

  const auto *row_data = row.data<int64_t>();
  const auto *col_data = col.data<int64_t>();
  const auto *deg_data = deg.data<int64_t>();
  auto *cluster_data = cluster.data<int64_t>();

  // `row` is grouped by source node, so the edges of node r occupy the
  // contiguous range [e_idx, e_idx + deg[r]).
  const int64_t num_edges = row.size(0);
  int64_t e_idx = 0;
  while (e_idx < num_edges) {
    const int64_t r = row_data[e_idx];
    if (cluster_data[r] < 0) {
      cluster_data[r] = r;
      for (int64_t d_idx = 0; d_idx < deg_data[r]; d_idx++) {
        const int64_t c = col_data[e_idx + d_idx];
        if (cluster_data[c] < 0) {
          cluster_data[r] = std::min(r, c);
          cluster_data[c] = std::min(r, c);
          break;
        }
      }
    }
    e_idx += deg_data[r];
  }

  // Some nodes never occur as a source and are never visited above:
  // - isolated nodes,
  // - nodes whose only edges were self-loops,
  // - unmatched sinks of directed edges.
  // They would otherwise keep the invalid id -1. They cannot take part in any
  // further matching, so give each one its own cluster.
  for (int64_t n = 0; n < num_nodes; n++) {
    if (cluster_data[n] < 0) {
      cluster_data[n] = n;
    }
  }

  return cluster;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Expose the CPU implementation of graclus to Python.
  mod.def("graclus", &graclus, "Graclus (CPU)");
}