graclus.py 1.18 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .utils.loop import remove_self_loops
rusty1s's avatar
rusty1s committed
2
3
4
5
6
from .utils.perm import randperm, sort_row, randperm_sort_row
from .utils.ffi import graclus


def graclus_cluster(row, col, weight=None, num_nodes=None):
    """A greedy clustering algorithm of picking an unmarked vertex and matching
    it with one of its unmarked neighbors (that maximizes its edge weight).

    Before the native ``graclus`` routine is called, edges are reordered
    (grouped by source node; additionally shuffled on CPU) and self-loops are
    dropped. When :obj:`weight` is given, it is filtered and permuted with
    exactly the same mask/permutation as :obj:`row` and :obj:`col`, so every
    weight stays attached to its own edge.

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        weight (Tensor, optional): Edge weights, expected to hold one entry
            per ``(row, col)`` pair. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as the largest index found in :obj:`row` or :obj:`col`
            plus one. (default: :obj:`None`)

    Returns:
        A tensor of the same type as :obj:`row` with :obj:`num_nodes` entries,
        filled in by ``graclus`` (presumably the cluster id of every node).

    Examples::

        >>> row = torch.LongTensor([0, 1, 1, 2])
        >>> col = torch.LongTensor([1, 0, 2, 1])
        >>> weight = torch.Tensor([1, 1, 1, 1])
        >>> cluster = graclus_cluster(row, col, weight)
    """
    if num_nodes is None:
        # Consider both endpoints and convert to a plain Python int: on
        # PyTorch >= 0.4 ``row.max()`` is a 0-dim tensor, and
        # ``row.new(<tensor>)`` would wrap that value instead of allocating
        # ``num_nodes`` entries.
        num_nodes = max(int(row.max()), int(col.max())) + 1

    if weight is None:
        if row.is_cuda:  # pragma: no cover
            row, col = sort_row(row, col)
        else:
            row, col = randperm(row, col)
            row, col = randperm_sort_row(row, col, num_nodes)
        row, col = remove_self_loops(row, col)
    else:
        # The ``utils.perm`` helpers only hand back the permuted row/col, so
        # the weighted case is reordered here, where the same permutation can
        # be applied to ``weight`` as well.
        row, col, weight = _reorder_weighted_edges(row, col, weight)

    cluster = row.new(num_nodes)
    graclus(cluster, row, col, weight)

    return cluster


def _reorder_weighted_edges(row, col, weight):
    """Drop self-loops and sort edges by source node, keeping ``weight``
    aligned with ``row``/``col`` throughout."""
    import torch

    # Remove self-loops (row == col) from all three tensors at once.
    mask = row != col
    row, col, weight = row[mask], col[mask], weight[mask]

    if not row.is_cuda:
        # Mirror the unweighted CPU path: shuffle edges before sorting so the
        # order of neighbors within each source node is randomized.
        perm = torch.randperm(row.size(0))
        row, col, weight = row[perm], col[perm], weight[perm]

    # Group edges by source node; apply the same permutation to all tensors.
    _, perm = row.sort()
    return row[perm], col[perm], weight[perm]