Commit 523e86a3 authored by bowendeng's avatar bowendeng
Browse files

csrc: partition_cpu supports weighted graph

parent 203d69f2
...@@ -7,17 +7,24 @@ ...@@ -7,17 +7,24 @@
#include "utils.h" #include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts,
torch::Tensor edge_wgt, bool recursive) { torch::optional<torch::Tensor> edge_wgt, bool recursive) {
#ifdef WITH_METIS #ifdef WITH_METIS
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
int64_t nvtxs = rowptr.numel() - 1; int64_t nvtxs = rowptr.numel() - 1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part = torch::empty(nvtxs, rowptr.options());
auto *xadj = rowptr.data_ptr<int64_t>(); auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>(); auto *adjncy = col.data_ptr<int64_t>();
auto *adjwgt = edge_wgt.data_ptr<int64_t>();
int64_t * adjwgt;
if (edge_wgt.has_value()){
adjwgt=edge_wgt.value().data_ptr<int64_t>();
adjwgt=(idx_t*) adjwgt;
}else{
adjwgt=nullptr;
}
int64_t ncon = 1; int64_t ncon = 1;
int64_t objval = -1; int64_t objval = -1;
......
...@@ -3,5 +3,5 @@ ...@@ -3,5 +3,5 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts,torch::Tensor edge_wgt, int64_t num_parts,torch::optional<torch::Tensor> edge_wgt,
bool recursive); bool recursive);
...@@ -7,7 +7,7 @@ PyMODINIT_FUNC PyInit__metis(void) { return NULL; } ...@@ -7,7 +7,7 @@ PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif #endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts, torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts,
torch::Tensor edge_wgt,bool recursive) { torch::optional<torch::Tensor> edge_wgt,bool recursive) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
AT_ERROR("No CUDA version supported"); AT_ERROR("No CUDA version supported");
......
...@@ -11,7 +11,7 @@ def metis_wgt(x): ...@@ -11,7 +11,7 @@ def metis_wgt(x):
diff = t1 - t2 diff = t1 - t2
diff = diff[diff != 0] diff = diff[diff != 0]
if len(diff) == 0: if len(diff) == 0:
return torch.ones(x.shape, dtype=torch.long) return None
res = diff.abs().min() res = diff.abs().min()
bod = x.max() - x.min() bod = x.max() - x.min()
scale = (res / bod).item() scale = (res / bod).item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment