Commit a06899bb authored by rusty1s's avatar rusty1s
Browse files

recursive

parent 45d29d1a
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
#include "utils.h" #include "utils.h"
torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts) { int64_t num_parts, bool recursive) {
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
...@@ -17,8 +17,13 @@ torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -17,8 +17,13 @@ torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col,
auto part = torch::empty(nvtxs, rowptr.options()); auto part = torch::empty(nvtxs, rowptr.options());
auto part_data = part.data_ptr<int64_t>(); auto part_data = part.data_ptr<int64_t>();
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL, &num_parts, if (recursive) {
NULL, NULL, NULL, &objval, part_data); METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
&num_parts, NULL, NULL, NULL, &objval, part_data);
} else {
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
&num_parts, NULL, NULL, NULL, &objval, part_data);
}
return part; return part;
} }
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts); int64_t num_parts, bool recursive);
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
PyMODINIT_FUNC PyInit__metis_wrapper(void) { return NULL; } PyMODINIT_FUNC PyInit__metis_wrapper(void) { return NULL; }
#endif #endif
torch::Tensor partition_kway(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts) { int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
AT_ERROR("No CUDA version supported"); AT_ERROR("No CUDA version supported");
...@@ -18,9 +18,9 @@ torch::Tensor partition_kway(torch::Tensor rowptr, torch::Tensor col, ...@@ -18,9 +18,9 @@ torch::Tensor partition_kway(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
#endif #endif
} else { } else {
return partition_kway_cpu(rowptr, col, num_parts); return partition_kway_cpu(rowptr, col, num_parts, recursive);
} }
} }
static auto registry = torch::RegisterOperators().op( static auto registry =
"torch_sparse::partition_kway", &partition_kway); torch::RegisterOperators().op("torch_sparse::partition", &partition);
...@@ -5,12 +5,13 @@ from torch_sparse.tensor import SparseTensor ...@@ -5,12 +5,13 @@ from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute from torch_sparse.permute import permute
def partition_kway( def partition(
src: SparseTensor, src: SparseTensor, num_parts: int, recursive: bool = False
num_parts: int) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu() rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition_kway(rowptr, col, num_parts) cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
recursive)
cluster = cluster.to(src.device()) cluster = cluster.to(src.device())
cluster, perm = cluster.sort() cluster, perm = cluster.sort()
...@@ -20,5 +21,4 @@ def partition_kway( ...@@ -20,5 +21,4 @@ def partition_kway(
return out, partptr, perm return out, partptr, perm
SparseTensor.partition_kway = lambda self, num_parts: partition_kway( SparseTensor.partition = lambda self, num_parts, recursive=False: partition(self, num_parts, recursive)
self, num_parts)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment