# metis.py (814 bytes)
# Blame attribution: rusty1s (with edge-weight support by bwdeng20).
from typing import Tuple

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute


def partition(
    src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Partition ``src`` into ``num_parts`` clusters using the METIS operator.

    The CSR structure (``rowptr``/``col``) and edge weights are moved to CPU
    and passed to ``torch.ops.torch_sparse.partition``. The resulting cluster
    id vector is moved back to ``src``'s device and sorted. The sort
    permutation is then applied to ``src`` so that entries of the same
    cluster become contiguous.

    Args:
        src: Sparse matrix to partition. Its stored values, if present, are
            used as edge weights. If it has no stored values, unit weights
            are used instead.
        num_parts: Number of partitions requested from the operator.
        recursive: Forwarded unchanged to the operator. Presumably it selects
            recursive bisection instead of k-way partitioning; confirm this
            against the C++ op.

    Returns:
        A tuple ``(out, partptr, perm)``:

        - ``out`` is ``permute(src, perm)``.
        - ``perm`` is the index permutation from sorting the cluster ids.
        - ``partptr`` is ``ind2ptr`` applied to the sorted cluster ids. It
          marks where each cluster's block starts in ``out``.
    """
    rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()

    value = src.storage.value()
    if value is None:
        # Without stored values, ``value()`` returns None. The old code then
        # crashed on ``None.cpu()``. Unit edge weights make the weighted
        # METIS call equivalent to an unweighted one.
        # NOTE(review): int64 is assumed to match METIS's idx_t as expected
        # by the C++ operator; confirm.
        adjwgt = torch.ones(col.numel(), dtype=torch.long)
    else:
        adjwgt = value.cpu()

    cluster = torch.ops.torch_sparse.partition(
        rowptr, col, num_parts, adjwgt, recursive
    )
    cluster = cluster.to(src.device())

    # Sorting groups equal cluster ids together. `perm` records the original
    # position of each sorted entry, so it can reorder `src` consistently.
    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
def _sparse_tensor_partition(self, num_parts, recursive=False):
    """Method form of ``partition``: ``tensor.partition(num_parts, recursive)``.

    Delegates directly to the module-level ``partition`` with ``self`` as the
    source tensor.
    """
    return partition(self, num_parts, recursive)


SparseTensor.partition = _sparse_tensor_partition