"vscode:/vscode.git/clone" did not exist on "d041dd504058ac6b0fde3eb767eb6844d8d577b8"
Commit ab0cee58 authored by bwdeng20's avatar bwdeng20
Browse files

weighted undirected graph support

parent a1ae9033
......@@ -7,7 +7,8 @@
#include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, bool recursive) {
int64_t num_parts, torch::Tensor adjwgt,
bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
CHECK_CPU(col);
......@@ -17,15 +18,16 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
auto *adjwgt = adjwgt.data_ptr<int64_t>();
int64_t ncon = 1;
int64_t objval = -1;
auto part_data = part.data_ptr<int64_t>();
if (recursive) {
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data);
} else {
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data);
}
......
......@@ -3,4 +3,4 @@
#include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, bool recursive);
int64_t num_parts,torch::Tensor adjwgt, bool recursive);
......@@ -10,7 +10,8 @@ def partition(
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
adjwgt=src.storage.value().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,adjwgt,
recursive)
cluster = cluster.to(src.device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment