cluster.cpp 526 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <torch/torch.h>

rusty1s's avatar
rusty1s committed
3
4
// Forward declaration of the grid clustering entry point. Only the prototype
// is visible in this file; the definition is provided elsewhere (presumably
// by the CUDA source, given the "Grid (CUDA)" binding docstring -- TODO
// confirm).
//
// Takes four tensors (pos, size, start, end) by value and returns a tensor.
// NOTE(review): expected shapes, dtypes and device placement of the arguments
// are not established here -- check the implementation before relying on them.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end);
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
7
8
9
// Forward declaration of the unweighted graclus entry point. The definition
// is not in this file (presumably in the CUDA source, given the
// "Graclus (CUDA)" binding docstring -- TODO confirm).
//
// Takes two tensors (row, col) and an integer node count, returns a tensor.
// NOTE(review): row/col look like the two halves of an edge list -- verify
// against the implementation.
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes);

// Forward declaration of the weighted graclus entry point. Same parameters as
// graclus plus an additional `weight` tensor; the definition is not in this
// file (presumably in the CUDA source -- TODO confirm).
//
// NOTE(review): `weight` presumably holds one value per (row, col) entry --
// verify against the implementation.
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int num_nodes);
rusty1s's avatar
rusty1s committed
10
11
12

// Python bindings for this extension. The module name is supplied by the
// build through TORCH_EXTENSION_NAME; each m.def call exposes one of the
// functions declared earlier in this file under the same name, together with
// a short docstring shown in Python's help().
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("grid", &grid, "Grid (CUDA)");
  m.def("graclus", &graclus, "Graclus (CUDA)");
  // Docstring typo fixed: previously read "Weightes Graclus (CUDA)".
  m.def("weighted_graclus", &weighted_graclus, "Weighted Graclus (CUDA)");
}