metis.py 1.8 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Tuple

rusty1s's avatar
rusty1s committed
3
4
5
6
7
8
9
10
import time
import copy

import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data

rusty1s's avatar
rusty1s committed
11
12
# Handle to the METIS partitioning op exposed under ``torch.ops.torch_sparse``
# (presumably registered by the ``torch_sparse`` import; confirm). It is looked
# up once at import time and invoked by ``metis`` with
# ``(rowptr, col, None, num_parts, recursive)``.
partition_fn = torch.ops.torch_sparse.partition

rusty1s's avatar
rusty1s committed
13
14
15

def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
          log: bool = True) -> Tuple[Tensor, Tensor]:
    r"""Partitions the graph given by the sparse adjacency matrix
    :obj:`adj_t` with METIS and returns a cluster-grouped node ordering.

    Args:
        adj_t (SparseTensor): Sparse adjacency matrix. Only its CSR structure
            (``rowptr`` and ``col``) is passed to the partitioner; edge values
            are not used.
        num_parts (int): Number of clusters requested. For values ``<= 1``
            METIS is not called and the identity ordering is returned as a
            single cluster.
        recursive (bool, optional): Forwarded unchanged to the partitioning
            op (presumably selects recursive bisection; confirm against
            torch_sparse). (default: :obj:`False`)
        log (bool, optional): If set, prints progress and elapsed wall-clock
            time to stdout. (default: :obj:`True`)

    Returns:
        (Tensor, Tensor): ``perm``, a node ordering in which members of the
        same cluster are contiguous, and ``ptr``, the cluster slice boundaries
        into ``perm``.
    """
    if log:
        started_at = time.perf_counter()
        print(f'Computing METIS partitioning with {num_parts} parts...',
              end=' ', flush=True)

    n = adj_t.size(0)

    if num_parts <= 1:
        # Trivial partition: original node order, one cluster spanning all.
        perm = torch.arange(n)
        ptr = torch.tensor([0, n])
    else:
        rowptr, col, _ = adj_t.csr()
        assignment = partition_fn(rowptr, col, None, num_parts, recursive)
        # Sorting the per-node cluster ids groups equal ids together; the
        # returned sort indices are therefore the node permutation.
        sorted_assignment, perm = assignment.sort()
        ptr = torch.ops.torch_sparse.ind2ptr(sorted_assignment, num_parts)

    if log:
        elapsed = time.perf_counter() - started_at
        print(f'Done! [{elapsed:.2f}s]')

    return perm, ptr


rusty1s's avatar
rusty1s committed
41
42
43
def permute(data: Data, perm: Tensor, log: bool = True) -> Data:
    r"""Permutes a :obj:`data` object according to a given permutation
    :obj:`perm`.

    Each attribute of :obj:`data` is handled as follows:

    * Tensors with at least one dimension whose first dimension equals
      ``data.num_nodes`` are re-indexed as ``value[perm]``.
    * :class:`SparseTensor` attributes are replaced by
      ``value.permute(perm)``.
    * Everything else is carried over unchanged. This includes
      0-dimensional (scalar) tensors such as graph-level labels.

    The object is copied with :func:`copy.copy` before attributes are
    reassigned. Whether this leaves the caller's object untouched depends on
    ``Data.__copy__`` (assumed to copy the attribute store; confirm).

    Args:
        data (Data): The graph object to permute.
        perm (Tensor): Node permutation, e.g. the ``perm`` returned by
            :func:`metis`.
        log (bool, optional): If set, prints progress and elapsed wall-clock
            time to stdout. (default: :obj:`True`)

    Returns:
        Data: The permuted copy.

    Raises:
        NotImplementedError: If :obj:`data` holds a dense tensor whose first
            dimension equals ``data.num_edges`` (and not ``data.num_nodes``),
            because permuting edge-level tensors is not supported.
    """
    if log:
        t = time.perf_counter()
        print('Permuting data...', end=' ', flush=True)

    data = copy.copy(data)

    # Re-indexing rows with a permutation does not change any attribute's
    # size, so these counts stay valid for the whole loop. Query them once.
    num_nodes = data.num_nodes
    num_edges = data.num_edges

    for key, value in data:
        # Scalar tensors have no first dimension; ``value.size(0)`` would
        # raise IndexError on them, so they are skipped.
        if isinstance(value, Tensor) and value.dim() > 0:
            # NOTE(review): if num_nodes == num_edges, the node-level branch
            # wins and an edge-level tensor would be permuted as node-level.
            if value.size(0) == num_nodes:
                data[key] = value[perm]
            elif value.size(0) == num_edges:
                raise NotImplementedError(
                    f"Permuting edge-level attribute '{key}' is not "
                    f"supported")
        elif isinstance(value, SparseTensor):
            data[key] = value.permute(perm)

    if log:
        print(f'Done! [{time.perf_counter() - t:.2f}s]')

    return data