metis.py 1.2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
from typing import Tuple

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch_sparse.utils import cartesian1d


def metis_wgt(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    """Rescale weights to positive integers that keep their relative spacing.

    The smallest non-zero gap between two weights is divided by the total
    value range, and that ratio is expressed as the exact integer fraction
    ``tick / arange``. Every weight is then mapped affinely so that the
    minimum becomes ``tick`` and the maximum becomes ``tick + arange``.

    Args:
        x: Tensor of weights.

    Returns:
        A tuple ``(weights, tick, arange)`` where ``weights`` is a
        ``torch.long`` tensor with the same shape as ``x``.
        If ``x`` holds fewer than two distinct values (empty, single element
        or all entries equal) there is no gap to preserve, so unit weights
        are returned together with ``tick = arange = 1``.
    """
    # `torch.unique` returns the distinct values in ascending order, so the
    # smallest non-zero pairwise gap is the smallest gap between neighbours.
    # This avoids materialising all n**2 pairwise differences.
    uniq = torch.unique(x)
    if uniq.numel() < 2:
        # Degenerate input: the value range is zero, so the rescaling below
        # would divide by zero (or take the min of an empty tensor).
        return torch.ones_like(x, dtype=torch.long), 1, 1

    res = (uniq[1:] - uniq[:-1]).min()
    lowest = uniq[0]
    bod = uniq[-1] - lowest

    # `float.as_integer_ratio` yields the exact fraction of the Python float.
    scale = (res / bod).item()
    tick, arange = scale.as_integer_ratio()

    x_ratio = (x - lowest) / bod
    return (x_ratio * arange + tick).long(), tick, arange
rusty1s's avatar
rusty1s committed
19
20


rusty1s's avatar
rusty1s committed
21
def partition(
        src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Cluster the entries of ``src`` into ``num_parts`` parts and reorder it.

    Args:
        src: Sparse matrix to partition. Its stored values, if any, are
            converted to integer edge weights with ``metis_wgt``.
        num_parts: Number of parts requested from the partitioning op.
        recursive: Forwarded unchanged to the partitioning op.

    Returns:
        A tuple ``(out, partptr, perm)``:
        ``perm`` is the permutation that sorts the cluster assignment,
        ``out`` is ``src`` permuted by ``perm``, and ``partptr`` is the
        pointer vector built by ``ind2ptr`` from the sorted cluster ids.
    """
    # The partitioning op is fed CPU tensors; the result is moved back to
    # the device of `src` afterwards.
    rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()

    value = src.storage.value()
    if value is None:
        # A SparseTensor without stored values has no weights to rescale;
        # calling `.cpu()` on `None` would raise. Use unit weights instead.
        # NOTE(review): assumes the op expects one long weight per entry of
        # `col`, matching what `metis_wgt` returns for a value vector —
        # confirm against the native signature.
        edge_wgt = torch.ones(col.numel(), dtype=torch.long)
    else:
        edge_wgt = metis_wgt(value.cpu())[0]

    cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
                                               edge_wgt, recursive)
    cluster = cluster.to(src.device())

    # Sorting groups entries of the same cluster together; `perm` is the
    # ordering that achieves this.
    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
rusty1s's avatar
rusty1s committed
36
37


rusty1s's avatar
rusty1s committed
38
39
# Attach `partition` to `SparseTensor` so it can be called as a method:
# `mat.partition(num_parts, recursive)` forwards to
# `partition(mat, num_parts, recursive)`.
SparseTensor.partition = lambda self, num_parts, recursive=False: partition(
    self, num_parts, recursive)