// graclus.cpp -- Graclus clustering kernels (CPU). Author: rusty1s.
#include <algorithm>
#include <tuple>

#include <torch/extension.h>

#include "compat.h"
#include "utils.h"

rusty1s's avatar
rusty1s committed
6
/// Greedy, unweighted Graclus matching.
///
/// Nodes are visited in a random order (`at::randperm`). Each node `u` that
/// is still unassigned is paired with the first still-unassigned neighbour in
/// its adjacency list. Both endpoints of a pair receive the smaller of the two
/// node indices as their cluster id. A node without an unassigned neighbour
/// becomes a singleton cluster labelled with its own index.
///
/// @param row        source node index of every edge (read as int64).
/// @param col        target node index of every edge (read as int64).
/// @param num_nodes  number of nodes in the graph.
/// @return tensor of length `num_nodes` with the cluster id of every node.
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  // NOTE(review): `rand` (utils.h) is presumably an edge shuffle, which makes
  // the "first unassigned neighbour" below a random one -- confirm.
  std::tie(row, col) = rand(row, col);
  // From here on `row` is consumed as a CSR row-pointer array: the neighbours
  // of node u are col[row[u]] .. col[row[u + 1] - 1].
  std::tie(row, col) = to_csr(row, col, num_nodes);

  // The raw-pointer loops below assume dense, stride-1 storage;
  // `.contiguous()` returns the tensor itself when that already holds.
  row = row.contiguous();
  col = col.contiguous();
  auto row_data = row.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  // Random visiting order of the nodes.
  auto perm = at::randperm(num_nodes, row.options());
  auto perm_data = perm.DATA_PTR<int64_t>();

  // -1 marks a node that has not been assigned to a cluster yet.
  auto cluster = at::full(num_nodes, -1, row.options());
  auto cluster_data = cluster.DATA_PTR<int64_t>();

  for (int64_t i = 0; i < num_nodes; i++) {
    auto u = perm_data[i];

    if (cluster_data[u] >= 0)
      continue;  // already paired by an earlier node

    // Default: u is a singleton unless an unassigned neighbour is found.
    cluster_data[u] = u;

    for (int64_t j = row_data[u]; j < row_data[u + 1]; j++) {
      auto v = col_data[j];

      if (cluster_data[v] >= 0)
        continue;  // neighbour is already taken

      // Pair u with v; the smaller index is the shared cluster id.
      cluster_data[u] = std::min(u, v);
      cluster_data[v] = std::min(u, v);
      break;
    }
  }

  return cluster;
}

/// Greedy, weighted Graclus matching.
///
/// Nodes are visited in a random order (`at::randperm`). Each node `u` that
/// is still unassigned is paired with its heaviest still-unassigned
/// neighbour. Selection uses `weight >= w_max`, with `w_max` starting at 0:
/// - ties go to the edge that appears later in the adjacency list;
/// - zero-weight edges can be selected;
/// - negative (or NaN) weights are never selected.
///
/// Both endpoints of a pair receive the smaller node index as their cluster
/// id. A node with no eligible neighbour becomes a singleton labelled with
/// its own index.
///
/// @param row        source node index of every edge (read as int64).
/// @param col        target node index of every edge (read as int64).
/// @param weight     per-edge weight (any type accepted by
///                   AT_DISPATCH_ALL_TYPES).
/// @param num_nodes  number of nodes in the graph.
/// @return tensor of length `num_nodes` with the cluster id of every node.
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int64_t num_nodes) {
  std::tie(row, col, weight) = remove_self_loops(row, col, weight);
  // From here on `row` is consumed as a CSR row-pointer array: the edges of
  // node u are the indices row[u] .. row[u + 1] - 1 into `col` and `weight`.
  std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);

  // The raw-pointer loops below assume dense, stride-1 storage;
  // `.contiguous()` returns the tensor itself when that already holds.
  row = row.contiguous();
  col = col.contiguous();
  weight = weight.contiguous();
  auto row_data = row.DATA_PTR<int64_t>();
  auto col_data = col.DATA_PTR<int64_t>();

  // Random visiting order of the nodes.
  auto perm = at::randperm(num_nodes, row.options());
  auto perm_data = perm.DATA_PTR<int64_t>();

  // -1 marks a node that has not been assigned to a cluster yet.
  auto cluster = at::full(num_nodes, -1, row.options());
  auto cluster_data = cluster.DATA_PTR<int64_t>();

  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
    auto weight_data = weight.DATA_PTR<scalar_t>();

    for (int64_t i = 0; i < num_nodes; i++) {
      auto u = perm_data[i];

      if (cluster_data[u] >= 0)
        continue;  // already paired by an earlier node

      // Best partner found so far; v_max == u means "none found".
      int64_t v_max = u;
      scalar_t w_max = 0;

      for (int64_t j = row_data[u]; j < row_data[u + 1]; j++) {
        auto v = col_data[j];

        if (cluster_data[v] >= 0)
          continue;  // neighbour is already taken

        // `>=` lets later edges win ties and admits zero-weight edges.
        // Because w_max never drops below 0, negative weights never win.
        if (weight_data[j] >= w_max) {
          v_max = v;
          w_max = weight_data[j];
        }
      }

      // With no partner, v_max == u and both writes make u a singleton.
      cluster_data[u] = std::min(u, v_max);
      cluster_data[v_max] = std::min(u, v_max);
    }
  });

  return cluster;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Python-visible entry points of this extension; both run on the CPU.
  mod.def("graclus", &graclus, "Graclus (CPU)");
  mod.def("weighted_graclus", &weighted_graclus, "Weighted Graclus (CPU)");
}