Commit ab0cee58 authored by bwdeng20's avatar bwdeng20
Browse files

weighted undirected graph support

parent a1ae9033
...@@ -7,7 +7,8 @@ ...@@ -7,7 +7,8 @@
#include "utils.h" #include "utils.h"
// Partition a CSR graph into `num_parts` clusters with METIS.
//
// rowptr/col: CSR row pointers and column indices (int64, CPU).
// adjwgt:     per-edge weights aligned with `col` (int64, CPU) — enables
//             weighted partitioning; METIS treats NULL as all-ones, so this
//             tensor is passed directly as the edge-weight array.
// recursive:  choose METIS_PartGraphRecursive over METIS_PartGraphKway.
// Returns a length-nvtxs tensor mapping each vertex to its partition id.
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            int64_t num_parts, torch::Tensor adjwgt,
                            bool recursive) {
#ifdef WITH_METIS
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(adjwgt);  // keep input checks consistent with rowptr/col

  // NOTE(review): the two lines below fall in a diff gap and are
  // reconstructed from context — confirm against the full source file.
  int64_t nvtxs = rowptr.numel() - 1;
  auto part = torch::empty(nvtxs, rowptr.options());

  auto *xadj = rowptr.data_ptr<int64_t>();
  auto *adjncy = col.data_ptr<int64_t>();
  // BUG FIX: the original wrote `auto *adjwgt = adjwgt.data_ptr<int64_t>();`,
  // shadowing the parameter and reading the still-incomplete local inside its
  // own initializer — that does not compile. Use a distinct local name.
  auto *adjwgt_data = adjwgt.data_ptr<int64_t>();
  int64_t ncon = 1;     // number of balancing constraints
  int64_t objval = -1;  // METIS output: edge-cut / communication volume
  auto part_data = part.data_ptr<int64_t>();
  if (recursive) {
    METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL,
                             adjwgt_data, &num_parts, NULL, NULL, NULL,
                             &objval, part_data);
  } else {
    METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt_data,
                        &num_parts, NULL, NULL, NULL, &objval, part_data);
  }

  return part;
#else
  AT_ERROR("Not compiled with METIS support");
#endif
}
...@@ -3,4 +3,4 @@ ...@@ -3,4 +3,4 @@
#include <torch/extension.h>

// Weighted METIS graph partitioning; `adjwgt` carries per-edge weights
// aligned with `col`. Declaration kept in sync with the definition
// (fixed the missing space after `num_parts,`).
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            int64_t num_parts, torch::Tensor adjwgt,
                            bool recursive);
...@@ -10,7 +10,8 @@ def partition( ...@@ -10,7 +10,8 @@ def partition(
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu() rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts, adjwgt=src.storage.value().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,adjwgt,
recursive) recursive)
cluster = cluster.to(src.device()) cluster = cluster.to(src.device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment