Commit 1e959269 authored by bowendeng's avatar bowendeng
Browse files

fix for unweighted graph partition

parent 49e0a6d6
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "utils.h" #include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts,
torch::optional<torch::Tensor> edge_wgt, bool recursive) { torch::Tensor edge_wgt, bool recursive) {
#ifdef WITH_METIS #ifdef WITH_METIS
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
...@@ -17,11 +17,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_ ...@@ -17,11 +17,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_
auto *xadj = rowptr.data_ptr<int64_t>(); auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>(); auto *adjncy = col.data_ptr<int64_t>();
if (edge_wgt==nullptr){
adjwgt = nullptr
}else{
auto *adjwgt = edge_wgt.data_ptr<int64_t>(); auto *adjwgt = edge_wgt.data_ptr<int64_t>();
}
int64_t ncon = 1; int64_t ncon = 1;
int64_t objval = -1; int64_t objval = -1;
......
...@@ -3,5 +3,5 @@ ...@@ -3,5 +3,5 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts,torch::optional<torch::Tensor> edge_wgt, int64_t num_parts,torch::Tensor edge_wgt,
bool recursive); bool recursive);
...@@ -7,11 +7,11 @@ from torch_sparse.utils import cartesian1d ...@@ -7,11 +7,11 @@ from torch_sparse.utils import cartesian1d
def metis_wgt(x): def metis_wgt(x):
if len(x.unique()) == 1:
return None
t1, t2 = cartesian1d(x, x) t1, t2 = cartesian1d(x, x)
diff = t1 - t2 diff = t1 - t2
diff = diff[diff != 0] diff = diff[diff != 0]
if len(diff) == 0:
return x.long()
res = diff.abs().min() res = diff.abs().min()
bod = x.max() - x.min() bod = x.max() - x.min()
scale = (res / bod).item() scale = (res / bod).item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment