Commit c30d2e13 authored by bowendeng

Add utils to support METIS partitioning with edge weights (adj_wgt)

parent ab0cee58
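In short, partition() now forwards the SparseTensor's stored values to METIS as edge weights. A minimal usage sketch, assuming the Python hunk below lives in torch_sparse/metis.py as in upstream torch_sparse, that the extension was built WITH_METIS, and with a toy graph invented for illustration:

import torch
from torch_sparse import SparseTensor
from torch_sparse.metis import partition

# Toy undirected 4-node graph; each edge is stored in both directions with the
# same float weight, carried in the tensor's value field.
row = torch.tensor([0, 1, 1, 2, 2, 3])
col = torch.tensor([1, 0, 2, 1, 3, 2])
val = torch.tensor([1.0, 1.0, 1.5, 1.5, 3.0, 3.0])
src = SparseTensor(row=row, col=col, value=val, sparse_sizes=(4, 4))

# The float weights are mapped to positive integers internally (see metis_wgt
# below) before being handed to METIS as adjwgt. Output names follow upstream
# torch_sparse: the permuted tensor, the partition pointer, and the permutation.
out, partptr, perm = partition(src, num_parts=2, recursive=False)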
......@@ -6,9 +6,8 @@
#include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, torch::Tensor adjwgt,
bool recursive) {
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, int64_t num_parts,
                            torch::Tensor edge_wgt, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
CHECK_CPU(col);
......@@ -18,7 +17,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
auto *adjwgt = adjwgt.data_ptr<int64_t>();
auto *adjwgt = edge_wgt.data_ptr<int64_t>();
int64_t ncon = 1;
int64_t objval = -1;
auto part_data = part.data_ptr<int64_t>();
......
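For reference, the CPU kernel reads all three arrays as int64: rowptr and col in CSR layout, plus one weight per entry of col (METIS' adjwgt layout; for an undirected graph each edge appears once per direction and should carry the same weight in both). A hand-built sketch of the expected inputs, using a made-up 3-node path graph:

import torch

# CSR of the path graph 0-1-2, with both directions of every edge stored.
rowptr = torch.tensor([0, 1, 3, 4])    # node i's neighbours are col[rowptr[i]:rowptr[i+1]]
col = torch.tensor([1, 0, 2, 1])
edge_wgt = torch.tensor([3, 3, 1, 1])  # already integer-valued, one entry per entry of col

# Direct call into the registered op (requires a build with METIS available).
cluster = torch.ops.torch_sparse.partition(rowptr, col, 2, edge_wgt, False)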
......@@ -7,8 +7,8 @@
PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, bool recursive) {
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col, int64_t num_parts,
                        torch::Tensor edge_wgt, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
......@@ -16,7 +16,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, num_parts, recursive);
return partition_cpu(rowptr, col, num_parts, edge_wgt, recursive);
}
}
......
......@@ -80,7 +80,7 @@ tests_require = ['pytest', 'pytest-cov']
setup(
name='torch_sparse',
version='0.6.0',
version='0.6.1',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_sparse',
......
......@@ -3,15 +3,28 @@ from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
from torch_sparse.utils import cartesian1d
def metis_wgt(x):
    # Map float edge weights onto positive integers, as METIS' adjwgt requires.
    # Assumes x contains at least two distinct values.
    t1, t2 = cartesian1d(x, x)
    diff = t1 - t2
    diff = diff[diff != 0]
    res = diff.abs().min()                   # smallest nonzero gap between weights
    bod = x.max() - x.min()                  # total spread of the weights
    scale = (res / bod).item()
    tick, arange = scale.as_integer_ratio()  # gap/spread as an exact integer ratio
    x_ratio = (x - x.min()) / bod            # weights normalised to [0, 1]
    return (x_ratio * arange + tick).long(), tick, arange
def partition(
src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
adjwgt=src.storage.value().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,adjwgt,
edge_wgt = src.storage.value().cpu()
edge_wgt = metis_wgt(edge_wgt)[0]
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts, edge_wgt,
recursive)
cluster = cluster.to(src.device())
......
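Why metis_wgt: METIS takes integer edge weights, so the float values stored on the SparseTensor are rescaled to positive integers before the call. A small worked example, assuming metis_wgt ends up importable from torch_sparse.metis:

import torch
from torch_sparse.metis import metis_wgt  # assumed location of the hunk above

w = torch.tensor([1.0, 1.5, 3.0])
int_w, tick, arange = metis_wgt(w)
# smallest nonzero gap res = 0.5, spread bod = 2.0, scale = 0.25, so tick, arange = 1, 4
# int_w == tensor([1, 2, 5]): (w - 1.0) / 2.0 * 4 + 1, truncated to long

Note that float.as_integer_ratio() is exact on the binary float, so a scale such as 1/3 yields very large tick and arange values rather than (1, 3).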
from typing import Any
import torch
try:
from typing_extensions import Final # noqa
......@@ -8,3 +9,9 @@ except ImportError:
def is_scalar(other: Any) -> bool:
return isinstance(other, int) or isinstance(other, float)
def cartesian1d(x, y):
    # Cartesian product of two 1-D tensors: the two columns of all (x[i], y[j])
    # pairs, each of shape (x.numel() * y.numel(), 1).
    a1, a2 = torch.meshgrid([x, y])
    coos = torch.stack([a1, a2]).T.reshape(-1, 2)
    return coos.split(1, dim=1)
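A quick illustration of cartesian1d above, which metis_wgt uses to form all pairwise differences:

import torch
from torch_sparse.utils import cartesian1d

x = torch.tensor([1, 2])
y = torch.tensor([10, 20, 30])
a, b = cartesian1d(x, y)
# a.squeeze() == tensor([1, 2, 1, 2, 1, 2])          (x value of every pair)
# b.squeeze() == tensor([10, 10, 20, 20, 30, 30])    (matching y value)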