Unverified Commit 4d49d44e authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #123 from Spazierganger/vwgts

add node weights for metis wrapper
parents 54d8418e cb1e30da
...@@ -12,32 +12,46 @@ ...@@ -12,32 +12,46 @@
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive) { int64_t num_parts, bool recursive) {
#ifdef WITH_METIS #ifdef WITH_METIS
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
if (optional_value.has_value()) { if (optional_value.has_value()) {
CHECK_CPU(optional_value.value()); CHECK_CPU(optional_value.value());
CHECK_INPUT(optional_value.value().dim() == 1); CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().numel() == col.numel()); CHECK_INPUT(optional_value.value().numel() == col.numel());
} }
if (optional_node_weight.has_value()) {
CHECK_CPU(optional_node_weight.value());
CHECK_INPUT(optional_node_weight.value().dim() == 1);
CHECK_INPUT(optional_node_weight.value().numel() == rowptr.numel() - 1);
}
int64_t nvtxs = rowptr.numel() - 1; int64_t nvtxs = rowptr.numel() - 1;
int64_t ncon = 1; int64_t ncon = 1;
auto *xadj = rowptr.data_ptr<int64_t>(); auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>(); auto *adjncy = col.data_ptr<int64_t>();
int64_t *adjwgt = NULL; int64_t *adjwgt = NULL;
if (optional_value.has_value()) if (optional_value.has_value())
adjwgt = optional_value.value().data_ptr<int64_t>(); adjwgt = optional_value.value().data_ptr<int64_t>();
int64_t *vwgt = NULL;
if (optional_node_weight.has_value())
vwgt = optional_node_weight.value().data_ptr<int64_t>();
int64_t objval = -1; int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options()); auto part = torch::empty(nvtxs, rowptr.options());
auto part_data = part.data_ptr<int64_t>(); auto part_data = part.data_ptr<int64_t>();
if (recursive) { if (recursive) {
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt, METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data); &num_parts, NULL, NULL, NULL, &objval, part_data);
} else { } else {
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt, METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&num_parts, NULL, NULL, NULL, &objval, part_data); &num_parts, NULL, NULL, NULL, &objval, part_data);
} }
...@@ -50,10 +64,11 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -50,10 +64,11 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
// needs mt-metis installed via: // needs mt-metis installed via:
// ./configure --shared --edges64bit --vertices64bit --weights64bit // ./configure --shared --edges64bit --vertices64bit --weights64bit
// --partitions64bit // --partitions64bit
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor
mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive, torch::optional<torch::Tensor> optional_node_weight,
int64_t num_workers) { int64_t num_parts, bool recursive, int64_t num_workers) {
#ifdef WITH_MTMETIS #ifdef WITH_MTMETIS
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
...@@ -63,13 +78,25 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -63,13 +78,25 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT(optional_value.value().numel() == col.numel()); CHECK_INPUT(optional_value.value().numel() == col.numel());
} }
if (optional_node_weight.has_value()) {
CHECK_CPU(optional_node_weight.value());
CHECK_INPUT(optional_node_weight.value().dim() == 1);
CHECK_INPUT(optional_node_weight.value().numel() == rowptr.numel() - 1);
}
mtmetis_vtx_type nvtxs = rowptr.numel() - 1; mtmetis_vtx_type nvtxs = rowptr.numel() - 1;
mtmetis_vtx_type ncon = 1; mtmetis_vtx_type ncon = 1;
mtmetis_adj_type *xadj = (mtmetis_adj_type *)rowptr.data_ptr<int64_t>(); mtmetis_adj_type *xadj = (mtmetis_adj_type *)rowptr.data_ptr<int64_t>();
mtmetis_vtx_type *adjncy = (mtmetis_vtx_type *)col.data_ptr<int64_t>(); mtmetis_vtx_type *adjncy = (mtmetis_vtx_type *)col.data_ptr<int64_t>();
mtmetis_wgt_type *adjwgt = NULL; mtmetis_wgt_type *adjwgt = NULL;
if (optional_value.has_value()) if (optional_value.has_value())
adjwgt = optional_value.value().data_ptr<int64_t>(); adjwgt = optional_value.value().data_ptr<int64_t>();
mtmetis_wgt_type *vwgt = NULL;
if (optional_node_weight.has_value())
vwgt = optional_node_weight.value().data_ptr<int64_t>();
mtmetis_pid_type nparts = num_parts; mtmetis_pid_type nparts = num_parts;
mtmetis_wgt_type objval = -1; mtmetis_wgt_type objval = -1;
auto part = torch::empty(nvtxs, rowptr.options()); auto part = torch::empty(nvtxs, rowptr.options());
...@@ -79,10 +106,10 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -79,10 +106,10 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
opts[MTMETIS_OPTION_NTHREADS] = num_workers; opts[MTMETIS_OPTION_NTHREADS] = num_workers;
if (recursive) { if (recursive) {
MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt, MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&nparts, NULL, NULL, opts, &objval, part_data); &nparts, NULL, NULL, opts, &objval, part_data);
} else { } else {
MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt, MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&nparts, NULL, NULL, opts, &objval, part_data); &nparts, NULL, NULL, opts, &objval, part_data);
} }
......
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive); int64_t num_parts, bool recursive);
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor
mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
int64_t num_parts, bool recursive, torch::optional<torch::Tensor> optional_node_weight,
int64_t num_workers); int64_t num_parts, bool recursive, int64_t num_workers);
...@@ -13,6 +13,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; } ...@@ -13,6 +13,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive) { int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
...@@ -21,12 +22,14 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col, ...@@ -21,12 +22,14 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
#endif #endif
} else { } else {
return partition_cpu(rowptr, col, optional_value, num_parts, recursive); return partition_cpu(rowptr, col, optional_value, optional_node_weight,
num_parts, recursive);
} }
} }
torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col, torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::optional<torch::Tensor> optional_value,
torch::optional<torch::Tensor> optional_node_weight,
int64_t num_parts, bool recursive, int64_t num_parts, bool recursive,
int64_t num_workers) { int64_t num_workers) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
...@@ -36,8 +39,8 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col, ...@@ -36,8 +39,8 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
#endif #endif
} else { } else {
return mt_partition_cpu(rowptr, col, optional_value, num_parts, recursive, return mt_partition_cpu(rowptr, col, optional_value, optional_node_weight,
num_workers); num_parts, recursive, num_workers);
} }
} }
......
import pytest import pytest
from itertools import product
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
...@@ -12,26 +14,24 @@ except RuntimeError: ...@@ -12,26 +14,24 @@ except RuntimeError:
@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support') @pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support')
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device,weighted', product(devices, [False, True]))
def test_metis(device): def test_metis(device, weighted):
value1 = torch.randn(6 * 6, device=device).view(6, 6) mat1 = torch.randn(6 * 6, device=device).view(6, 6)
value2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6) mat2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
value3 = torch.ones(6 * 6, device=device).view(6, 6) mat3 = torch.ones(6 * 6, device=device).view(6, 6)
for value in [value1, value2, value3]: vec1 = None
mat = SparseTensor.from_dense(value) vec2 = torch.rand(6, device=device)
_, partptr, perm = mat.partition(num_parts=2, recursive=False, for mat, vec in product([mat1, mat2, mat3], [vec1, vec2]):
weighted=True) mat = SparseTensor.from_dense(mat)
assert partptr.numel() == 3
_, partptr, perm = mat.partition(num_parts=1, recursive=False,
weighted=weighted, node_weight=vec)
assert partptr.numel() == 2
assert perm.numel() == 6 assert perm.numel() == 6
_, partptr, perm = mat.partition(num_parts=2, recursive=False, _, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=False) weighted=weighted, node_weight=vec)
assert partptr.numel() == 3 assert partptr.numel() == 3
assert perm.numel() == 6 assert perm.numel() == 6
_, partptr, perm = mat.partition(num_parts=1, recursive=False,
weighted=True)
assert partptr.numel() == 2
assert perm.numel() == 6
...@@ -21,7 +21,7 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]: ...@@ -21,7 +21,7 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
def partition( def partition(
src: SparseTensor, num_parts: int, recursive: bool = False, src: SparseTensor, num_parts: int, recursive: bool = False,
weighted: bool = False weighted: bool = False, node_weight: Optional[torch.Tensor] = None
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
assert num_parts >= 1 assert num_parts >= 1
...@@ -41,8 +41,14 @@ def partition( ...@@ -41,8 +41,14 @@ def partition(
else: else:
value = None value = None
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts, if node_weight is not None:
recursive) assert node_weight.numel() == rowptr.numel() - 1
node_weight = node_weight.view(-1).detach().cpu()
if node_weight.is_floating_point():
node_weight = weight2metis(node_weight)
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, node_weight,
num_parts, recursive)
cluster = cluster.to(src.device()) cluster = cluster.to(src.device())
cluster, perm = cluster.sort() cluster, perm = cluster.sort()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment