graclus.py 1.2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from __future__ import division

import torch

rusty1s's avatar
rusty1s committed
5
# from .utils import get_func
rusty1s's avatar
rusty1s committed
6
7
8
from .degree import node_degree


rusty1s's avatar
rusty1s committed
9
def graclus_cluster(edge_index, num_nodes=None, edge_attr=None, rid=None):
    """Prepare (and eventually run) a Graclus-style greedy node clustering.

    Every edge is scored with :func:`normalized_cut`. Edges are then
    reordered by the rank ``rid`` assigns to their source node, which is the
    visit order the matching kernel is meant to consume.

    NOTE(review): the kernel call is still commented out (see TODO below), so
    the returned tensor is currently ``-1`` for every node.

    Args:
        edge_index: pair ``(row, col)`` of node-index tensors (presumably a
            ``2 x E`` integer tensor -- confirm against callers).
        num_nodes: number of nodes. If omitted, it is inferred as
            ``max(edge_index) + 1``, or ``0`` when ``edge_index`` is empty.
        edge_attr: optional edge attributes, forwarded to
            :func:`normalized_cut`.
        rid: per-node ranking used to order the edges. Defaults to a random
            permutation of ``range(num_nodes)``.

    Returns:
        A tensor of length ``num_nodes`` created from ``edge_index`` (same
        dtype), holding a cluster id per node (currently all ``-1``).
    """
    if num_nodes is None:
        # ``Tensor.max()`` returns a 0-dim tensor on PyTorch >= 0.4. Convert it
        # to a plain int so ``randperm``/``new`` receive a size, not tensor
        # data. Guard the empty case, where ``max()`` would raise.
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        num_nodes = int(num_nodes)
    rid = torch.randperm(num_nodes) if rid is None else rid
    row, col = edge_index

    # Compute edge-wise normalized cut.
    cut = normalized_cut(edge_index, num_nodes, edge_attr)

    # Sort row and col indices based on the (possibly random) `rid`.
    _, perm = rid[row].sort()
    row, col, cut = row[perm], col[perm], cut[perm]

    cluster = edge_index.new(num_nodes).fill_(-1)
    # TODO: dispatch the Graclus matching kernel once `get_func` is available:
    # func = get_func('graclus', cluster)
    # func(cluster, row, col, cut)
    return cluster
rusty1s's avatar
rusty1s committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41


def normalized_cut(edge_index, num_nodes, edge_attr=None):
    """Return a per-edge normalized-cut score.

    For an edge ``(i, j)`` the score is ``1 / deg[i] + 1 / deg[j]``, where
    ``deg`` is whatever :func:`node_degree` returns for ``edge_index``.
    When ``edge_attr`` is supplied, the score is multiplied by it. Attributes
    with more than one column are first reduced to their row-wise L2 norm.

    Args:
        edge_index: pair ``(row, col)`` of node-index tensors (presumably a
            ``2 x E`` tensor -- confirm against callers).
        num_nodes: number of nodes, forwarded to :func:`node_degree`.
        edge_attr: optional edge weights/features. The degree buffer is
            allocated from it so the result shares its tensor type.

    Returns:
        A tensor with one score per edge.

    NOTE(review): if ``node_degree`` can report 0 for an endpoint, the
    reciprocal below yields ``inf`` -- confirm the helper's semantics.
    """
    source, target = edge_index

    # Allocate the degree output with edge_attr's tensor type when available.
    degree_buffer = torch.Tensor() if edge_attr is None else edge_attr.new()
    inv_degree = 1 / node_degree(edge_index, num_nodes, out=degree_buffer)
    per_edge = inv_degree[source] + inv_degree[target]

    if edge_attr is None:
        return per_edge

    weights = edge_attr
    if weights.dim() > 1 and weights.size(1) > 1:
        # Collapse multi-dimensional edge features to a single magnitude.
        weights = torch.norm(weights, 2, 1)
    return weights.squeeze() * per_edge