"llm/vscode:/vscode.git/clone" did not exist on "58d95cc9bd446a8209e7388a96c70367cbafd653"
metis.py 1.16 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
from typing import Tuple

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
6
7
8
9
from torch_sparse.utils import cartesian1d


def metis_wgt(x):
    """Convert edge values into positive integer edge weights for METIS.

    ``x`` is rescaled linearly: with ``scale = min_gap / (max - min)`` written
    as the exact fraction ``tick / arange``, each value becomes
    ``(x - min) / (max - min) * arange + tick``. The smallest value therefore
    maps to ``tick`` (>= 1). The smallest gap between two distinct values maps
    to a gap of roughly ``tick``, so distinct values should remain distinct
    after the cast to ``long``.

    Args:
        x: Tensor of edge values (assumed 1-D numeric — TODO confirm for
            multi-dimensional ``value`` tensors).

    Returns:
        A ``long`` tensor with the same shape as ``x``. Returns ``None`` when
        ``x`` holds fewer than two distinct values (empty or constant input),
        meaning the partition should be unweighted.
    """
    # Distinct values in ascending order.
    uniq = torch.unique(x, sorted=True)
    if uniq.numel() < 2:
        return None

    # The smallest |x_i - x_j| over distinct values is always attained by two
    # neighbours in sorted order, so consecutive differences of the sorted
    # unique values give the same minimum as an all-pairs comparison without
    # materialising n^2 elements.
    res = (uniq[1:] - uniq[:-1]).min()
    lo = uniq[0]
    bod = uniq[-1] - lo
    scale = (res / bod).item()
    # Exact binary fraction: tick / arange == scale, arange a power of two.
    # NOTE(review): for very small ``scale`` arange can exceed the int64 range
    # and the final cast overflows — confirm expected value ranges.
    tick, arange = scale.as_integer_ratio()
    x_ratio = (x - lo) / bod
    return (x_ratio * arange + tick).long()
rusty1s's avatar
rusty1s committed
21
22


rusty1s's avatar
rusty1s committed
23
24
def partition(src: SparseTensor, num_parts: int, recursive: bool = False
              ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Cluster the graph stored in ``src`` into ``num_parts`` parts.

    The CSR structure (and the edge values, if present) are moved to the CPU
    and handed to the ``torch_sparse.partition`` op. The resulting cluster
    assignment is sorted, and ``src`` is reordered with ``permute(src, perm)``
    so that entries belonging to the same cluster become contiguous.

    Args:
        src: Sparse matrix describing the graph to partition.
        num_parts: Number of clusters requested from the partition op.
        recursive: Forwarded unchanged to the partition op (presumably
            selects recursive bisection instead of k-way — TODO confirm).

    Returns:
        ``(out, partptr, perm)``:
        - ``out`` is the permuted matrix.
        - ``partptr`` is the ``ind2ptr`` pointer tensor over the sorted
          cluster ids (presumably of length ``num_parts + 1``).
        - ``perm`` is the permutation that sorts the cluster ids.
    """
    rowptr = src.storage.rowptr().cpu()
    col = src.storage.col().cpu()

    # ``value()`` is None for a SparseTensor without edge values; calling
    # ``.cpu()`` on it used to raise AttributeError. Partition unweighted in
    # that case (the op already accepts ``None`` weights from ``metis_wgt``).
    value = src.storage.value()
    edge_wgt = None if value is None else metis_wgt(value.cpu())

    cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts, edge_wgt,
                                               recursive)
    cluster = cluster.to(src.device())

    # Sorting the assignment groups members of each cluster together; ``perm``
    # records the reordering so it can be applied to ``src``.
    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
rusty1s's avatar
rusty1s committed
37
38


rusty1s's avatar
rusty1s committed
39
# Attach ``partition`` to SparseTensor so it can be called as ``src.partition(...)``.
setattr(SparseTensor, 'partition', partition)