# metis.py
from typing import Tuple

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
6
7


rusty1s's avatar
update  
rusty1s committed
8
9
10
11
12
13
def cartesian1d(x, y):
    """Return the Cartesian product of two 1-D tensors as two columns.

    Pairs are ordered with ``y`` varying slowest: row ``j * len(x) + i``
    holds ``(x[i], y[j])``.

    Args:
        x: 1-D tensor supplying the first component of every pair.
        y: 1-D tensor supplying the second component of every pair.

    Returns:
        A tuple ``(first, second)`` of tensors, each of shape
        ``(len(x) * len(y), 1)``.
    """
    # Built with repeat / repeat_interleave rather than
    # ``stack(meshgrid(...)).T``: ``Tensor.T`` on a 3-D tensor and
    # ``meshgrid`` without ``indexing`` are both deprecated in PyTorch.
    first = x.repeat(y.numel())
    second = y.repeat_interleave(x.numel())
    coos = torch.stack([first, second], dim=1)
    return coos.split(1, dim=1)


14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def metis_weight1(x):
    """Convert weights to integer weights using a sort-based gap search.

    The smallest non-zero gap between sorted values, relative to the full
    value range, is turned into an integer ratio ``tick / arange``. Each
    weight is then mapped to ``(x - min) / range * arange + tick`` and
    truncated to ``long``.

    Args:
        x: 1-D tensor of weights.

    Returns:
        A ``long`` tensor with the same shape as ``x``, or ``None`` when
        ``x`` contains fewer than two distinct values.
    """
    sorted_x = x.sort()[0]
    diff = sorted_x[1:] - sorted_x[:-1]
    # Duplicate values yield zero gaps. Keeping them would make the minimum
    # gap 0 and collapse the scale to 0/1, so only non-zero gaps count
    # (this matches what ``metis_weight2`` does).
    diff = diff[diff != 0]
    if diff.numel() == 0:
        return None
    xmin, xmax = sorted_x[0], sorted_x[-1]
    srange = xmax - xmin
    min_diff = diff.min()
    scale = (min_diff / srange).item()
    # Express the relative minimum gap as an exact integer fraction.
    tick, arange = scale.as_integer_ratio()
    x_ratio = (x - xmin) / srange
    return (x_ratio * arange + tick).long()


def metis_weight2(x):
    """Convert weights to integer weights using all pairwise differences.

    The smallest non-zero absolute difference between any two values,
    relative to the full value range, is turned into an integer ratio
    ``tick / arange``. Each weight is then mapped to
    ``(x - min) / range * arange + tick`` and truncated to ``long``.

    Args:
        x: 1-D tensor of weights.

    Returns:
        A ``long`` tensor with the same shape as ``x``, or ``None`` when
        no two values of ``x`` differ.
    """
    # Pairwise differences via broadcasting. Zero entries (the diagonal and
    # any duplicate values) say nothing about spacing and are discarded.
    diff = x.unsqueeze(1) - x.unsqueeze(0)
    diff = diff[diff != 0]
    if diff.numel() == 0:
        return None
    xmin, xmax = x.min(), x.max()
    srange = xmax - xmin
    min_diff = diff.abs().min()
    scale = (min_diff / srange).item()
    # Express the relative minimum gap as an exact integer fraction.
    tick, arange = scale.as_integer_ratio()
    x_ratio = (x - xmin) / srange
    return (x_ratio * arange + tick).long()
rusty1s's avatar
rusty1s committed
41
42


43
44
45
46
47
def metis_weight(x, sort_strategy=True):
    """Convert weights to integer weights with the selected strategy.

    Dispatches to ``metis_weight1`` when ``sort_strategy`` is truthy and to
    ``metis_weight2`` otherwise, returning whatever the chosen function
    returns.
    """
    if sort_strategy:
        return metis_weight1(x)
    return metis_weight2(x)


def partition(
    src: SparseTensor,
    num_parts: int,
    recursive: bool = False,
    sort_strategy: bool = True,
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Partition ``src`` into ``num_parts`` clusters and reorder it by cluster.

    The CSR representation of ``src`` is handed to
    ``torch.ops.torch_sparse.partition`` to obtain a cluster id per node.
    The ids are then sorted and ``src`` is permuted with the sorting
    permutation.

    Args:
        src: Sparse tensor to partition.
        num_parts: Number of partitions, forwarded to the partition op and
            to ``ind2ptr``.
        recursive: Forwarded unchanged to the partition op.
        sort_strategy: Selects how 1-D values are converted to integer
            weights; forwarded to ``metis_weight``.

    Returns:
        A tuple ``(out, partptr, perm)``: ``src`` permuted by ``perm``, the
        result of ``ind2ptr`` on the sorted cluster ids, and the permutation
        that sorts the cluster ids.
    """
    rowptr, col, value = src.csr()
    # The index tensors are moved to CPU before the partition op is called.
    rowptr, col = rowptr.cpu(), col.cpu()
    if value is not None and value.dim() == 1:
        value = value.detach().cpu()
        value = metis_weight(value, sort_strategy)
    # NOTE(review): a value tensor with dim() != 1 is forwarded unchanged
    # (no detach / CPU move) -- confirm the op accepts that.
    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())

    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
rusty1s's avatar
rusty1s committed
63
64


rusty1s's avatar
rusty1s committed
65
# Attach ``partition`` to ``SparseTensor`` so it can also be called as a
# method, i.e. ``src.partition(num_parts, ...)``.
SparseTensor.partition = partition