graclus.py 1.2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .utils.loop import remove_self_loops
rusty1s's avatar
rusty1s committed
2
3
4
5
6
from .utils.perm import randperm, sort_row, randperm_sort_row
from .utils.ffi import graclus


def graclus_cluster(row, col, weight=None, num_nodes=None):
    """A greedy clustering algorithm of picking an unmarked vertex and matching
    it with one of its unmarked neighbors (that maximizes its edge weight).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        weight (Tensor, optional): Edge weights. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as one more than the largest node index found in either
            :obj:`row` or :obj:`col`. (default: :obj:`None`)

    Returns:
        A tensor created via ``row.new_empty`` with one entry per node, filled
        in by the ``graclus`` kernel with each node's cluster assignment.

    Examples::

        >>> row = torch.LongTensor([0, 1, 1, 2])
        >>> col = torch.LongTensor([1, 0, 2, 1])
        >>> weight = torch.Tensor([1, 1, 1, 1])
        >>> cluster = graclus_cluster(row, col, weight)
    """

    if num_nodes is None:
        # Look at both endpoints. A node index that only occurs in ``col``
        # would otherwise lie outside the ``cluster`` buffer allocated below.
        num_nodes = max(row.max().item(), col.max().item()) + 1

    # Reorder the edges before matching. On CUDA they go through ``sort_row``
    # only. On CPU they are first randomly permuted and then passed to
    # ``randperm_sort_row``. Presumably this groups edges by source node while
    # randomizing the neighbor visiting order (TODO: confirm against
    # ``utils.perm``).
    if row.is_cuda:  # pragma: no cover
        row, col = sort_row(row, col)
    else:
        row, col = randperm(row, col)
        row, col = randperm_sort_row(row, col, num_nodes)

    # NOTE(review): ``weight`` is not passed through the reordering helpers
    # above or through ``remove_self_loops`` below. Only ``row``/``col`` are.
    # If those helpers permute or drop edges, as their names suggest, a
    # non-uniform ``weight`` will no longer line up with its edges when it
    # reaches ``graclus``. TODO: confirm, and let the helpers carry weights.
    row, col = remove_self_loops(row, col)

    # One output slot per node; the native ``graclus`` kernel fills it in.
    cluster = row.new_empty((num_nodes, ))
    graclus(cluster, row, col, weight)

    return cluster