Commit 523e86a3 authored by bowendeng's avatar bowendeng
Browse files

csrc: partition_cpu supports weighted graph

parent 203d69f2
......@@ -7,17 +7,24 @@
#include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts,
torch::Tensor edge_wgt, bool recursive) {
torch::optional<torch::Tensor> edge_wgt, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
CHECK_CPU(col);
int64_t nvtxs = rowptr.numel() - 1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part = torch::empty(nvtxs, rowptr.options());
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
auto *adjwgt = edge_wgt.data_ptr<int64_t>();
int64_t * adjwgt;
if (edge_wgt.has_value()){
adjwgt=edge_wgt.value().data_ptr<int64_t>();
adjwgt=(idx_t*) adjwgt;
}else{
adjwgt=nullptr;
}
int64_t ncon = 1;
int64_t objval = -1;
......
......@@ -3,5 +3,5 @@
#include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts,torch::Tensor edge_wgt,
int64_t num_parts,torch::optional<torch::Tensor> edge_wgt,
bool recursive);
......@@ -7,7 +7,7 @@ PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts,
torch::Tensor edge_wgt,bool recursive) {
torch::optional<torch::Tensor> edge_wgt,bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
......
......@@ -11,7 +11,7 @@ def metis_wgt(x):
diff = t1 - t2
diff = diff[diff != 0]
if len(diff) == 0:
return torch.ones(x.shape, dtype=torch.long)
return None
res = diff.abs().min()
bod = x.max() - x.min()
scale = (res / bod).item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment