// graclus_kernel.cu
#include <ATen/ATen.h>

// Project-local CUDA helpers: coloring, proposal and response steps, plus
// shared graph utilities.
#include "coloring.cuh"
#include "proposal.cuh"
#include "response.cuh"
#include "utils.cuh"

// Computes a Graclus-style clustering of an unweighted graph on the GPU.
//
// row, col:  edge index tensors (assumed to be COO source/target indices of
//            equal length -- TODO confirm against callers).
// num_nodes: number of nodes; determines the length of the result.
//
// Returns a tensor with num_nodes entries, using row's dtype and device.
// Every entry starts at -1 (presumably "not yet assigned") and is updated by
// the propose/respond rounds below.
at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // Graph preprocessing. The semantics are inferred from the helper names:
  // drop self loops, randomize the edge order, then re-encode as CSR (after
  // which `row` presumably holds row pointers).
  // NOTE(review): confirm against utils.cuh.
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = rand(row, col);
  std::tie(row, col) = to_csr(row, col, num_nodes);

  // Both state buffers share row's dtype/device and start filled with -1.
  const auto options = row.options();
  auto cluster = at::full(num_nodes, -1, options);
  auto proposal = at::full(num_nodes, -1, options);

  // Alternate proposal and response rounds until colorize() reports that it
  // is done. colorize receives `cluster` and may update it.
  // NOTE(review): confirm in coloring.cuh.
  while (!colorize(cluster)) {
    propose(cluster, proposal, row, col);
    respond(cluster, proposal, row, col);
  }

  return cluster;
}

// Computes a Graclus-style clustering of an edge-weighted graph on the GPU.
//
// row, col:  edge index tensors (assumed to be COO source/target indices of
//            equal length -- TODO confirm against callers).
// weight:    per-edge weights, carried through preprocessing alongside
//            row/col (assumed to have one entry per edge -- TODO confirm).
// num_nodes: number of nodes; determines the length of the result.
//
// Returns a tensor with num_nodes entries, using row's dtype and device.
// Every entry starts at -1 (presumably "not yet assigned") and is updated by
// the weighted propose/respond rounds below.
at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
                                 at::Tensor weight, int64_t num_nodes) {
  // Graph preprocessing, keeping weight aligned with row/col. Unlike the
  // unweighted variant, no rand() shuffle is applied here.
  // NOTE(review): the helper semantics (self-loop removal, CSR conversion)
  // are inferred from their names; confirm against utils.cuh.
  std::tie(row, col, weight) = remove_self_loops(row, col, weight);
  std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);

  // Both state buffers share row's dtype/device and start filled with -1.
  const auto options = row.options();
  auto cluster = at::full(num_nodes, -1, options);
  auto proposal = at::full(num_nodes, -1, options);

  // Alternate weighted proposal and response rounds until colorize() reports
  // that it is done. colorize receives `cluster` and may update it.
  // NOTE(review): confirm in coloring.cuh.
  while (!colorize(cluster)) {
    propose(cluster, proposal, row, col, weight);
    respond(cluster, proposal, row, col, weight);
  }

  return cluster;
}