Commit 49e0a6d6 authored by bowendeng's avatar bowendeng
Browse files

fix for unweighted graph partition

parent 7b83a608
......@@ -7,7 +7,7 @@
#include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts,
torch::Tensor edge_wgt, bool recursive) {
torch::optional<torch::Tensor> edge_wgt, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
CHECK_CPU(col);
......@@ -17,7 +17,12 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
auto *adjwgt = edge_wgt.data_ptr<int64_t>();
if (edge_wgt==nullptr){
adjwgt = nullptr
}else{
auto *adjwgt = edge_wgt.data_ptr<int64_t>();
}
int64_t ncon = 1;
int64_t objval = -1;
auto part_data = part.data_ptr<int64_t>();
......@@ -34,4 +39,4 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_
#else
AT_ERROR("Not compiled with METIS support");
#endif
}
}
\ No newline at end of file
......@@ -3,4 +3,5 @@
#include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts,torch::Tensor adjwgt, bool recursive);
int64_t num_parts,torch::optional<torch::Tensor> edge_wgt,
bool recursive);
#include <Python.h>
#include <torch/script.h>
#include "cpu/metis_cpu.h"
#ifdef _WIN32
......@@ -8,7 +7,7 @@ PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts,
torch::Tensor edge_wgt,bool recursive) {
torch::optional<torch::Tensor> edge_wgt,bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
......
......@@ -7,6 +7,8 @@ from torch_sparse.utils import cartesian1d
def metis_wgt(x):
if len(x.unique()) == 1:
return None
t1, t2 = cartesian1d(x, x)
diff = t1 - t2
diff = diff[diff != 0]
......@@ -15,14 +17,14 @@ def metis_wgt(x):
scale = (res / bod).item()
tick, arange = scale.as_integer_ratio()
x_ratio = (x - x.min()) / bod
return (x_ratio * arange + tick).long(), tick, arange
return (x_ratio * arange + tick).long()
def partition(src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
edge_wgt = src.storage.value().cpu()
edge_wgt = metis_wgt(edge_wgt)[0]
edge_wgt = metis_wgt(edge_wgt)
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts, edge_wgt,
recursive)
cluster = cluster.to(src.device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment