metis.py 1.49 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from typing import Tuple, Optional
rusty1s's avatar
rusty1s committed
2
3
4
5

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
6
7


rusty1s's avatar
rusty1s committed
8
9
10
def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
    """Rescale floating-point weights to integer (``torch.long``) weights.

    The weights are shifted by their minimum, normalised by their range and
    then scaled so that the smallest gap between two *distinct* values maps
    to one integer step.  The offset added at the end keeps the smallest
    weight strictly positive.

    Args:
        weight: Tensor of floating-point weights, expected to be
            one-dimensional.  It is not modified.

    Returns:
        A ``torch.long`` tensor with the same shape as ``weight``, or
        ``None`` when all weights are equal (this includes empty and
        single-element inputs), i.e. when the weights carry no information.
    """
    sorted_weight = weight.sort()[0]
    # Gaps between neighbouring values in sorted order (all non-negative).
    diff = sorted_weight[1:] - sorted_weight[:-1]
    if diff.sum() == 0:
        return None
    weight_min, weight_max = sorted_weight[0], sorted_weight[-1]
    srange = weight_max - weight_min
    # Only gaps between distinct values matter.  Duplicated weights produce
    # zero gaps; including them would give ``scale == 0`` and collapse every
    # weight except the maximum to 0.  A positive gap is guaranteed to exist
    # here because the sum of the (non-negative) gaps is non-zero.
    min_diff = diff[diff > 0].min()
    scale = (min_diff / srange).item()
    # Express the relative step as an exact fraction ``tick / arange``.
    tick, arange = scale.as_integer_ratio()
    # ``weight - weight_min`` allocates a new tensor, so the in-place ops
    # below never touch the caller's data.
    weight_ratio = (weight - weight_min).div_(srange).mul_(arange).add_(tick)
    return weight_ratio.to(torch.long)
20

rusty1s's avatar
rusty1s committed
21

rusty1s's avatar
rusty1s committed
22
23
def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
              weighted: bool = False
              ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Partition ``src`` into ``num_parts`` clusters and reorder it by cluster.

    The CSR representation of ``src`` is handed to the
    ``torch_sparse.partition`` custom op (presumably METIS-backed -- confirm
    against the extension), which assigns a cluster index to every node.
    Nodes are then sorted by cluster index and ``src`` is permuted
    accordingly.

    Args:
        src: The sparse matrix to partition.
        num_parts: Number of partitions; forwarded to the op and used as the
            size argument of ``ind2ptr``.
        recursive: Forwarded unchanged to the partition op (presumably
            selects a recursive partitioning scheme -- confirm).
        weighted: If ``True`` and ``src`` has values, they are passed to the
            op as weights.  Floating-point values are first converted to
            integers via ``weight2metis``.  Otherwise no weights are used.

    Returns:
        A tuple ``(out, partptr, perm)`` where ``out`` is ``src`` permuted by
        ``perm``, ``partptr`` is the result of ``ind2ptr`` on the sorted
        cluster indices, and ``perm`` is the permutation that sorts nodes by
        cluster index.
    """
    rowptr, col, value = src.csr()
    # The partition op is fed CPU tensors.
    rowptr, col = rowptr.cpu(), col.cpu()

    if value is not None and weighted:
        assert value.numel() == col.numel()
        value = value.view(-1).detach().cpu()
        if value.is_floating_point():
            # May yield ``None`` when all weights are equal.
            value = weight2metis(value)
    else:
        value = None

    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    # Move the assignment back to the device of the input matrix.
    cluster = cluster.to(src.device())

    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
rusty1s's avatar
rusty1s committed
45
46


rusty1s's avatar
rusty1s committed
47
# Attach `partition` to `SparseTensor` so it can be called as a method,
# i.e. `src.partition(num_parts, ...)`.
SparseTensor.partition = partition