Unverified Commit afbfdc97 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #55 from bwdeng20/master

torch_sparse.partition supports weighted graphs
parents 056c0bab 7f8aac48
......@@ -6,3 +6,4 @@ dist/
*.egg-info/
.coverage
*.so
.idea/
......@@ -7,25 +7,33 @@
#include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
CHECK_CPU(col);
if (optional_value.has_value()) {
CHECK_CPU(optional_value.value());
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().numel() == col.numel());
}
int64_t nvtxs = rowptr.numel() - 1;
auto part = torch::empty(nvtxs, rowptr.options());
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
int64_t *adjwgt = NULL;
if (optional_value.has_value())
adjwgt = optional_value.value().data_ptr<int64_t>();
int64_t ncon = 1;
int64_t objval = -1;
auto part_data = part.data_ptr<int64_t>();
if (recursive) {
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data);
} else {
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data);
}
......
......@@ -3,4 +3,5 @@
#include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive);
......@@ -8,6 +8,7 @@ PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
......@@ -16,7 +17,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, num_parts, recursive);
return partition_cpu(rowptr, col, optional_value, num_parts, recursive);
}
}
......
......@@ -7,11 +7,19 @@ from .utils import devices
@pytest.mark.parametrize('device', devices)
def test_metis(device):
mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
mat, partptr, perm = mat.partition(num_parts=2, recursive=False)
assert partptr.numel() == 3
assert perm.numel() == 6
value1 = torch.randn(6 * 6, device=device).view(6, 6)
value2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
value3 = torch.ones(6 * 6, device=device).view(6, 6)
mat, partptr, perm = mat.partition(num_parts=2, recursive=True)
assert partptr.numel() == 3
assert perm.numel() == 6
for value in [value1, value2, value3]:
mat = SparseTensor.from_dense(value)
_, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=True)
assert partptr.numel() == 3
assert perm.numel() == 6
_, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=False)
assert partptr.numel() == 3
assert perm.numel() == 6
from typing import Tuple
from typing import Tuple, Optional
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
def partition(src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
    """Convert a floating-point edge-weight vector into the positive
    integer weights expected by METIS.

    Returns ``None`` when all weights are equal (weighting then carries
    no information and the caller can partition unweighted).  Otherwise
    the weights are affinely mapped to ``long`` values such that any two
    *distinct* input weights map to distinct integers and every output
    weight is >= 1 (METIS requires strictly positive edge weights).

    :param weight: 1-dimensional floating-point tensor of edge weights.
    :return: ``torch.long`` tensor of the same shape, or ``None``.
    """
    sorted_weight = weight.sort()[0]
    diff = sorted_weight[1:] - sorted_weight[:-1]
    if diff.sum() == 0:  # all weights identical -> nothing to encode
        return None
    weight_min, weight_max = sorted_weight[0], sorted_weight[-1]
    srange = weight_max - weight_min  # > 0 since not all weights are equal
    # Smallest *non-zero* gap between distinct weights.  Using diff.min()
    # directly returns 0 as soon as any duplicate weight exists, which
    # collapses the scale below and maps most edges to the invalid METIS
    # edge weight 0.  diff > 0 is non-empty because diff.sum() != 0.
    min_diff = diff[diff > 0].min()
    scale = (min_diff / srange).item()
    # scale == num / den exactly (binary-float ratio).  Mapping
    #   w -> (w - min) / range * den + num
    # separates distinct weights by at least `num` (>= 1) and makes the
    # minimum output exactly `num`, i.e. strictly positive.
    num, den = scale.as_integer_ratio()
    weight_ratio = (weight - weight_min).div_(srange).mul_(den).add_(num)
    return weight_ratio.to(torch.long)
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
weighted=False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
rowptr, col = rowptr.cpu(), col.cpu()
if value is not None and weighted:
assert value.numel() == col.numel()
value = value.view(-1).detach().cpu()
if value.is_floating_point():
value = weight2metis(value)
else:
value = None
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
recursive)
cluster = cluster.to(src.device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment