#include "metis_wrapper_cpu.h"

#include <metis.h>

#include <cstdint>
#include <type_traits>

#include "utils.h"

// The int64 tensor buffers are handed to METIS without any conversion, so the
// linked METIS must use 64-bit indices (built with IDXTYPEWIDTH=64).
static_assert(std::is_same<idx_t, int64_t>::value,
              "partition_kway_cpu requires METIS built with IDXTYPEWIDTH=64");

/// @brief Partitions a graph stored in CSR form into `num_parts` parts using
///        METIS_PartGraphKway with default options.
///
/// @param rowptr    CPU int64 tensor of length nvtxs + 1 (METIS `xadj`); the
///                  neighbours of vertex v are col[rowptr[v] .. rowptr[v+1]).
/// @param col       CPU int64 tensor of neighbour indices (METIS `adjncy`).
/// @param num_parts Number of partitions requested; must be > 0.
/// @return CPU int64 tensor of length nvtxs holding the partition id assigned
///         to each vertex.
/// @throws c10::Error (via TORCH_CHECK) on wrong dtype, empty `rowptr`,
///         `col` shorter than rowptr[-1], non-positive `num_parts`, or a
///         non-OK METIS status. Device checks use CHECK_CPU from utils.h.
torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col,
                                 int64_t num_parts) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  TORCH_CHECK(rowptr.scalar_type() == torch::kLong,
              "rowptr must be an int64 tensor");
  TORCH_CHECK(col.scalar_type() == torch::kLong,
              "col must be an int64 tensor");
  TORCH_CHECK(rowptr.numel() >= 1,
              "rowptr must contain at least one element");
  TORCH_CHECK(num_parts > 0, "num_parts must be positive, got ", num_parts);

  // data_ptr() exposes raw storage, which is only laid out as a flat array
  // when the tensor is contiguous (no copy is made if it already is).
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  idx_t nvtxs = rowptr.numel() - 1;
  auto part = torch::empty(nvtxs, rowptr.options());
  if (nvtxs == 0) {
    return part;  // Nothing to partition; do not call into METIS.
  }

  idx_t *xadj = rowptr.data_ptr<int64_t>();
  idx_t *adjncy = col.data_ptr<int64_t>();
  // METIS reads adjncy up to index xadj[nvtxs]; guard against an OOB read.
  TORCH_CHECK(col.numel() >= xadj[nvtxs],
              "col has fewer entries than rowptr[-1] = ", xadj[nvtxs]);

  idx_t ncon = 1;              // single balancing constraint (vwgt is null,
                               // so METIS uses unit vertex weights)
  idx_t nparts = num_parts;    // METIS takes the part count by pointer
  idx_t objval = -1;           // objective value written by METIS; unused
  idx_t *part_data = part.data_ptr<int64_t>();

  const int status = METIS_PartGraphKway(
      &nvtxs, &ncon, xadj, adjncy,
      /*vwgt=*/nullptr, /*vsize=*/nullptr, /*adjwgt=*/nullptr, &nparts,
      /*tpwgts=*/nullptr, /*ubvec=*/nullptr, /*options=*/nullptr, &objval,
      part_data);
  // On failure `part` would still hold uninitialized memory from
  // torch::empty, so surface the error instead of returning garbage.
  TORCH_CHECK(status == METIS_OK,
              "METIS_PartGraphKway failed with status ", status);
  return part;
}