graclus.cpp 1007 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <torch/torch.h>

rusty1s's avatar
rusty1s committed
3
4
5
#include "degree.cpp"
#include "loop.cpp"
#include "perm.cpp"
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
to int  
rusty1s committed
7
// Greedy Graclus matching on a graph given as an edge list (row[i] -> col[i]).
//
// Self-loops are removed and the edges are shuffled via randperm() so that
// source nodes are visited in random order. Each still-unassigned source node
// is paired with its first still-unassigned neighbour; both receive the
// smaller of the two node ids as their cluster id. A node that ends up with no
// partner -- including one that is never visited as a source at all -- forms
// a singleton cluster carrying its own id.
//
// Args:
//   row, col:  edge endpoints. They are read through data<int64_t>(), so they
//              must hold int64 values.
//   num_nodes: number of nodes; every row/col entry is used as an index into
//              a buffer of this size, so it must lie in [0, num_nodes).
// Returns:
//   A tensor with num_nodes entries (same type as row) holding one cluster id
//   per node.
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  // NOTE(review): the scan below assumes randperm() returns the edges grouped
  // by source node (all edges of a node stored contiguously) -- confirm
  // against perm.cpp.
  std::tie(row, col) = randperm(row, col, num_nodes);

  // The raw-pointer scan indexes the buffers linearly, which is only valid for
  // contiguous memory; contiguous() is a no-op when that already holds.
  row = row.contiguous();
  col = col.contiguous();

  auto cluster = at::full(row.type(), {num_nodes}, -1);
  // Presumably deg[n] is the number of occurrences of n in row, i.e. the
  // length of node n's edge run -- TODO confirm against degree.cpp.
  auto deg = degree(row, num_nodes, row.type().scalarType()).contiguous();

  const auto *row_data = row.data<int64_t>();
  const auto *col_data = col.data<int64_t>();
  const auto *deg_data = deg.data<int64_t>();
  auto *cluster_data = cluster.data<int64_t>();

  const int64_t num_edges = row.size(0);
  int64_t e_idx = 0;
  while (e_idx < num_edges) {
    const int64_t r = row_data[e_idx];
    if (cluster_data[r] < 0) {
      // Unmatched so far: default to a singleton, then try to pair r with its
      // first neighbour that is still unmatched.
      cluster_data[r] = r;
      for (int64_t d_idx = 0; d_idx < deg_data[r]; d_idx++) {
        const int64_t c = col_data[e_idx + d_idx];
        if (cluster_data[c] < 0) {
          const int64_t id = std::min(r, c);
          cluster_data[r] = id;
          cluster_data[c] = id;
          break;
        }
      }
    }
    // Skip the rest of r's edge run to reach the next source node.
    e_idx += deg_data[r];
  }

  // Nodes never reached by the scan above (isolated nodes, nodes whose only
  // edges were self-loops, or unmatched nodes with only incoming edges) would
  // otherwise keep the invalid id -1; give them singleton clusters, exactly
  // like visited-but-unmatched nodes.
  for (int64_t n = 0; n < num_nodes; n++) {
    if (cluster_data[n] < 0) {
      cluster_data[n] = n;
    }
  }

  return cluster;
}