Commit f0609836 authored by rusty1s's avatar rusty1s
Browse files

mt-metis support (experimental)

parent f577fcee
......@@ -4,6 +4,10 @@
#include <metis.h>
#endif
#ifdef WITH_MTMETIS
#include <mtmetis.h>
#endif
#include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
......@@ -19,14 +23,14 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
}
int64_t nvtxs = rowptr.numel() - 1;
auto part = torch::empty(nvtxs, rowptr.options());
int64_t ncon = 1;
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
int64_t *adjwgt = NULL;
if (optional_value.has_value())
adjwgt = optional_value.value().data_ptr<int64_t>();
int64_t ncon = 1;
int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part_data = part.data_ptr<int64_t>();
if (recursive) {
......@@ -42,3 +46,48 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with METIS support");
#endif
}
// needs mt-metis installed via:
// ./configure --shared --edges64bit --vertices64bit --weights64bit
// --partitions64bit
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive,
int64_t num_workers) {
#ifdef WITH_MTMETIS
CHECK_CPU(rowptr);
CHECK_CPU(col);
if (optional_value.has_value()) {
CHECK_CPU(optional_value.value());
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().numel() == col.numel());
}
mtmetis_vtx_type nvtxs = rowptr.numel() - 1;
mtmetis_vtx_type ncon = 1;
mtmetis_adj_type *xadj = (mtmetis_adj_type *)rowptr.data_ptr<int64_t>();
mtmetis_vtx_type *adjncy = (mtmetis_vtx_type *)col.data_ptr<int64_t>();
mtmetis_wgt_type *adjwgt = NULL;
if (optional_value.has_value())
adjwgt = optional_value.value().data_ptr<int64_t>();
mtmetis_pid_type nparts = num_parts;
mtmetis_wgt_type objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
mtmetis_pid_type *part_data = (mtmetis_pid_type *)part.data_ptr<int64_t>();
double *opts = mtmetis_init_options();
opts[MTMETIS_OPTION_NTHREADS] = num_workers;
if (recursive) {
MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
&nparts, NULL, NULL, opts, &objval, part_data);
} else {
MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
&nparts, NULL, NULL, opts, &objval, part_data);
}
return part;
#else
AT_ERROR("Not compiled with MTMETIS support");
#endif
}
......@@ -5,3 +5,8 @@
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive);
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive,
int64_t num_workers);
......@@ -21,5 +21,22 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::partition", &partition);
torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive,
int64_t num_workers) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return mt_partition_cpu(rowptr, col, optional_value, num_parts, recursive,
num_workers);
}
}
static auto registry = torch::RegisterOperators()
.op("torch_sparse::partition", &partition)
.op("torch_sparse::mt_partition", &mt_partition);
......@@ -11,6 +11,10 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive);
torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive);
std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
torch::Tensor idx);
......
......@@ -17,9 +17,8 @@ if os.getenv('FORCE_CPU', '0') == '1':
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
WITH_METIS = False
if os.getenv('WITH_METIS', '0') == '1':
WITH_METIS = True
WITH_METIS = True if os.getenv('WITH_METIS', '0') == '1' else False
WITH_MTMETIS = True if os.getenv('WITH_MTMETIS', '0') == '1' else False
def get_extensions():
......@@ -29,6 +28,13 @@ def get_extensions():
if WITH_METIS:
define_macros += [('WITH_METIS', None)]
libraries += ['metis']
if WITH_MTMETIS:
define_macros += [('WITH_MTMETIS', None)]
define_macros += [('MTMETIS_64BIT_VERTICES', None)]
define_macros += [('MTMETIS_64BIT_EDGES', None)]
define_macros += [('MTMETIS_64BIT_WEIGHTS', None)]
define_macros += [('MTMETIS_64BIT_PARTITIONS', None)]
libraries += ['mtmetis', 'wildriver']
extra_compile_args = {'cxx': []}
extra_link_args = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment