// graclus.cpp -- Graclus graph clustering (CPU extension).
#include <torch/torch.h>

#include "utils.h"

// Greedy, unweighted Graclus matching on the CPU.
//
// Pipeline:
//   1. `remove_self_loops` drops self-loops from the edge list.
//   2. `rand` reorders the edges (presumably a random shuffle, so that the
//      "first free neighbour" below is a random one -- defined in utils.h,
//      confirm).
//   3. `to_csr` converts the edges; afterwards `row` is indexed as a
//      row-pointer array: the neighbours of node u are col[row[u]] up to
//      (excluding) col[row[u + 1]].
//   4. Nodes are visited in the order given by `randperm(num_nodes)`.
//
// Each visited node that has no cluster id yet first becomes its own
// cluster. It is then paired with its first neighbour that is also still
// unassigned, if any. Both nodes receive the smaller of the two indices as
// their cluster id.
//
// Returns a tensor of `num_nodes` cluster ids created with `row.options()`.
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = rand(row, col);
  std::tie(row, col) = to_csr(row, col, num_nodes);

  auto rowptr = row.data<int64_t>();
  auto nbr = col.data<int64_t>();

  auto order = randperm(num_nodes);
  auto visit = order.data<int64_t>();

  // -1 marks "not yet assigned to any cluster".
  auto cluster = at::full(num_nodes, -1, row.options());
  auto assign = cluster.data<int64_t>();

  for (int64_t step = 0; step < num_nodes; ++step) {
    const int64_t node = visit[step];
    if (assign[node] < 0) {
      // Claim the node first so it can never be picked as its own partner.
      assign[node] = node;

      // Advance to the first neighbour that is still unassigned.
      int64_t e = rowptr[node];
      const int64_t end = rowptr[node + 1];
      while (e < end && assign[nbr[e]] >= 0)
        ++e;

      if (e < end) {
        const int64_t mate = nbr[e];
        const int64_t id = std::min(node, mate);
        assign[node] = id;
        assign[mate] = id;
      }
    }
  }

  return cluster;
}

// Greedy, weighted Graclus matching on the CPU.
//
// Pipeline:
//   1. `remove_self_loops` drops self-loops from (row, col, weight).
//   2. `to_csr` converts the edges; afterwards `row` is indexed as a
//      row-pointer array. The neighbours of node u are col[row[u]] up to
//      (excluding) col[row[u + 1]], with the matching entries of `weight`.
//   3. Nodes are visited in the order given by `randperm(num_nodes)`.
//
// For each visited node that has no cluster id yet, the scan looks for an
// unassigned neighbour whose edge weight is >= the best seen so far. The
// running best starts at 0, and ties go to the later neighbour. The node and
// the chosen neighbour both receive the smaller of their two indices as
// cluster id. If no neighbour qualifies, the node forms its own cluster.
// NOTE(review): because the running best starts at 0, neighbours reached
// only through negative weights are never chosen.
//
// Returns a tensor of `num_nodes` cluster ids created with `row.options()`.
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int64_t num_nodes) {
  std::tie(row, col, weight) = remove_self_loops(row, col, weight);
  std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);

  auto rowptr = row.data<int64_t>();
  auto nbr = col.data<int64_t>();

  auto order = randperm(num_nodes);
  auto visit = order.data<int64_t>();

  // -1 marks "not yet assigned to any cluster".
  auto cluster = at::full(num_nodes, -1, row.options());
  auto assign = cluster.data<int64_t>();

  AT_DISPATCH_ALL_TYPES(weight.type(), "weighted_graclus", [&] {
    auto w = weight.data<scalar_t>();

    for (int64_t step = 0; step < num_nodes; ++step) {
      const int64_t node = visit[step];
      if (assign[node] >= 0)
        continue;

      // Fallback partner is the node itself (singleton cluster).
      int64_t best = node;
      scalar_t best_w = 0;

      for (int64_t e = rowptr[node], end = rowptr[node + 1]; e < end; ++e) {
        const int64_t cand = nbr[e];
        // Only free neighbours compete; `>=` lets later ties win.
        if (assign[cand] < 0 && w[e] >= best_w) {
          best = cand;
          best_w = w[e];
        }
      }

      const int64_t id = std::min(node, best);
      assign[node] = id;
      assign[best] = id;
    }
  });

  return cluster;
}

// Python bindings: expose both CPU clustering entry points under their C++
// names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Unweighted variant: (row, col, num_nodes) -> cluster ids.
  mod.def("graclus", &graclus, "Graclus (CPU)");
  // Weighted variant: (row, col, weight, num_nodes) -> cluster ids.
  mod.def("weighted_graclus", &weighted_graclus, "Weighted Graclus (CPU)");
}