"vscode:/vscode.git/clone" did not exist on "7b56e494bed2246daaaabbdbe12462d91010135f"
graclus.cpp 1.46 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include <torch/torch.h>

// #include "../include/degree.cpp"
// #include "../include/loop.cpp"
// #include "../include/perm.cpp"

at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // std::tie(row, col) = remove_self_loops(row, col);
  // std::tie(row, col) = randperm(row, col, num_nodes);
  // auto deg = degree(row, num_nodes, row.type().scalarType());

  auto cluster = at::full(num_nodes, -1, row.options());

  // auto *row_data = row.data<int64_t>();
  // auto *col_data = col.data<int64_t>();
  // auto *deg_data = deg.data<int64_t>();
  // auto *cluster_data = cluster.data<int64_t>();

  // int64_t e_idx = 0, d_idx, r, c;
  // while (e_idx < row.size(0)) {
  //   r = row_data[e_idx];
  //   if (cluster_data[r] < 0) {
  //     cluster_data[r] = r;
  //     for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
  //       c = col_data[e_idx + d_idx];
  //       if (cluster_data[c] < 0) {
  //         cluster_data[r] = std::min(r, c);
  //         cluster_data[c] = std::min(r, c);
  //         break;
  //       }
  //     }
  //   }
  //   e_idx += deg_data[r];
  // }

  return cluster;
}

at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int64_t num_nodes) {
  auto cluster = at::full(num_nodes, -1, row.options());
  return cluster;
}

// Python entry points for this extension: registers the two CPU matching
// kernels on the module named by TORCH_EXTENSION_NAME, each with a short
// docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("graclus", graclus, "Graclus (CPU)");
  mod.def("weighted_graclus", weighted_graclus, "Weighted Graclus (CPU)");
}