Commit 9c3519b4 authored by rusty1s's avatar rusty1s
Browse files

update

parent 523e86a3
...@@ -6,26 +6,25 @@ ...@@ -6,26 +6,25 @@
#include "utils.h" #include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> edge_wgt, bool recursive) { torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive) {
#ifdef WITH_METIS #ifdef WITH_METIS
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
if (optional_value.has_value()) {
CHECK_CPU(optional_value.value());
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().numel() == col.numel());
}
int64_t nvtxs = rowptr.numel() - 1; int64_t nvtxs = rowptr.numel() - 1;
auto part = torch::empty(nvtxs, rowptr.options()); auto part = torch::empty(nvtxs, rowptr.options());
auto *xadj = rowptr.data_ptr<int64_t>(); auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>(); auto *adjncy = col.data_ptr<int64_t>();
int64_t *adjwgt = NULL;
int64_t * adjwgt; if (optional_value.has_value())
if (edge_wgt.has_value()){ adjwgt = optional_value.value().data_ptr<int64_t>();
adjwgt=edge_wgt.value().data_ptr<int64_t>();
adjwgt=(idx_t*) adjwgt;
}else{
adjwgt=nullptr;
}
int64_t ncon = 1; int64_t ncon = 1;
int64_t objval = -1; int64_t objval = -1;
auto part_data = part.data_ptr<int64_t>(); auto part_data = part.data_ptr<int64_t>();
......
...@@ -3,5 +3,5 @@ ...@@ -3,5 +3,5 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts,torch::optional<torch::Tensor> edge_wgt, torch::optional<torch::Tensor> optional_value,
bool recursive); int64_t num_parts, bool recursive);
#include "cpu/metis_cpu.h"
#include <Python.h> #include <Python.h>
#include <torch/script.h> #include <torch/script.h>
#include "cpu/metis_cpu.h"
#ifdef _WIN32 #ifdef _WIN32
PyMODINIT_FUNC PyInit__metis(void) { return NULL; } PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif #endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_parts, torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> edge_wgt,bool recursive) { torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
AT_ERROR("No CUDA version supported"); AT_ERROR("No CUDA version supported");
...@@ -15,7 +17,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_part ...@@ -15,7 +17,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_part
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
#endif #endif
} else { } else {
return partition_cpu(rowptr, col,num_parts, edge_wgt, recursive); return partition_cpu(rowptr, col, optional_value, num_parts, recursive);
} }
} }
......
...@@ -3,10 +3,15 @@ from typing import Tuple ...@@ -3,10 +3,15 @@ from typing import Tuple
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute from torch_sparse.permute import permute
from torch_sparse.utils import cartesian1d
def cartesian1d(x, y):
    # Build every (x_i, y_j) combination and return the two coordinate
    # columns as separate (len(x) * len(y), 1) tensors.
    grid_x, grid_y = torch.meshgrid([x, y])
    stacked = torch.stack([grid_x, grid_y])
    pairs = stacked.T.reshape(-1, 2)
    return pairs.split(1, dim=1)
def metis_weight(x):
t1, t2 = cartesian1d(x, x) t1, t2 = cartesian1d(x, x)
diff = t1 - t2 diff = t1 - t2
diff = diff[diff != 0] diff = diff[diff != 0]
...@@ -22,10 +27,12 @@ def metis_wgt(x): ...@@ -22,10 +27,12 @@ def metis_wgt(x):
def partition(src: SparseTensor, num_parts: int, recursive: bool = False def partition(src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu() rowptr, col, value = src.csr()
edge_wgt = src.storage.value().cpu() rowptr, col = rowptr.cpu(), col.cpu()
edge_wgt = metis_wgt(edge_wgt) if value is not None and value.dim() == 1:
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts, edge_wgt, value = value.detach().cpu()
value = metis_weight(value)
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
recursive) recursive)
cluster = cluster.to(src.device()) cluster = cluster.to(src.device())
......
from typing import Any from typing import Any
import torch
try: try:
from typing_extensions import Final # noqa from typing_extensions import Final # noqa
...@@ -9,9 +8,3 @@ except ImportError: ...@@ -9,9 +8,3 @@ except ImportError:
def is_scalar(other: Any) -> bool:
    # "Scalar" here means a plain Python number: int or float.
    return isinstance(other, (int, float))
def cartesian1d(x, y):
    # Cartesian product of two 1-D tensors, returned as a pair of
    # column tensors (first coordinates, second coordinates).
    mesh_a, mesh_b = torch.meshgrid([x, y])
    combined = torch.stack([mesh_a, mesh_b]).T
    flat = combined.reshape(-1, 2)
    return flat.split(1, dim=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment