# metis.py (committed by rusty1s)
from typing import Tuple

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute


@torch.jit.script
def partition_kway(
        src: SparseTensor,
        num_parts: int) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Split the rows of ``src`` into ``num_parts`` clusters and reorder it.

    The cluster assignment is computed by the
    ``torch.ops.torch_sparse.partition_kway`` kernel from CPU copies of the
    row-pointer and column-index vectors of ``src``. Rows are then sorted by
    cluster id and ``src`` is permuted with that ordering.

    Args:
        src: Sparse matrix to partition.
            NOTE(review): presumably a square adjacency matrix -- confirm
            against callers.
        num_parts: Number of clusters to produce. Must be at least 1.

    Returns:
        A tuple ``(out, partptr, perm)`` where ``out`` is ``src`` permuted by
        ``perm``, ``perm`` is the ordering that sorts rows by cluster id, and
        ``partptr`` is the pointer vector produced by ``ind2ptr`` from the
        sorted cluster ids.

    Raises:
        ValueError: If ``num_parts`` is smaller than 1.
            NOTE(review): when scripted, the exception may surface as
            TorchScript's own error type -- confirm for the targeted
            PyTorch version.
    """
    # Reject nonsensical requests before they reach native code.
    if num_parts < 1:
        raise ValueError("num_parts must be a positive integer")

    # A single partition is trivial: every row belongs to part 0, so the
    # identity permutation is a valid answer and no partitioner call, sort
    # or permutation is needed.
    if num_parts == 1:
        # The row pointer has one more entry than there are rows.
        num_rows = src.storage.rowptr().numel() - 1
        partptr = torch.tensor([0, num_rows], device=src.device())
        perm = torch.arange(num_rows, device=src.device())
        return src, partptr, perm

    # The partitioning kernel is fed CPU tensors; move the graph structure
    # over and bring the result back to the device of ``src`` afterwards.
    rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
    cluster = torch.ops.torch_sparse.partition_kway(rowptr, col, num_parts)
    cluster = cluster.to(src.device())

    # Sorting by cluster id groups the rows of each part contiguously;
    # ``perm`` records the ordering so callers can reorder related data.
    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm