# graclus.py
# from .utils.loop import remove_self_loops
# from .utils.perm import randperm, sort_row, randperm_sort_row
# from .utils.ffi import graclus

import torch
import graclus_cpu

if torch.cuda.is_available():
    import graclus_cuda


def graclus_cluster(row, col, weight=None, num_nodes=None):
    """A greedy clustering algorithm of picking an unmarked vertex and matching
    it with one of its unmarked neighbors (that maximizes its edge weight).

    The work is delegated to a compiled backend: ``graclus_cuda`` when ``row``
    lives on a CUDA device, ``graclus_cpu`` otherwise. The unweighted
    ``graclus`` op is used when ``weight`` is :obj:`None`, and
    ``weighted_graclus`` is used otherwise.

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        weight (Tensor, optional): Edge weights. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as ``max(row.max(), col.max()) + 1``, or ``0`` when both
            ``row`` and ``col`` are empty. (default: :obj:`None`)

    Returns:
        The cluster assignment produced by the backend op. Presumably this is
        one cluster id per node; confirm against the compiled extension.

    Examples::

        >>> row = torch.tensor([0, 1, 1, 2])
        >>> col = torch.tensor([1, 0, 2, 1])
        >>> weight = torch.Tensor([1, 1, 1, 1])
        >>> cluster = graclus_cluster(row, col, weight)
    """
    if num_nodes is None:
        if row.numel() == 0 and col.numel() == 0:
            # ``Tensor.max()`` raises on an empty tensor. Treat an empty edge
            # list as a graph with zero nodes instead of crashing here.
            num_nodes = 0
        else:
            num_nodes = max(row.max().item(), col.max().item()) + 1

    # ``graclus_cuda`` is only imported when CUDA is available. A CUDA tensor
    # can only exist in that case, so the name is always bound on this branch.
    op = graclus_cuda if row.is_cuda else graclus_cpu

    if weight is None:
        cluster = op.graclus(row, col, num_nodes)
    else:
        cluster = op.weighted_graclus(row, col, weight, num_nodes)

    return cluster

    # if row.is_cuda:
    #     row, col = sort_row(row, col)
    # else:
    #     row, col = randperm(row, col)
    #     row, col = randperm_sort_row(row, col, num_nodes)