graclus.py 1.7 KB
Newer Older
quyuanhao123's avatar
quyuanhao123 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Optional

import torch


@torch.jit.script
def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
                    weight: Optional[torch.Tensor] = None,
                    num_nodes: Optional[int] = None) -> torch.Tensor:
    """A greedy clustering algorithm of picking an unmarked vertex and matching
    it with one of its unmarked neighbors (that maximizes its edge weight).

    Self-loops are dropped, the remaining edges are converted to CSR form
    (row pointer + column indices) and handed to the compiled
    ``torch_cluster.graclus`` operator.

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        weight (Tensor, optional): Edge weights. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as the largest index in :obj:`row`/:obj:`col` plus one.
            (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    If :obj:`num_nodes` is not given and the edge list is empty, there is
    nothing to infer the node count from, so an empty tensor (same dtype and
    device as :obj:`row`) is returned.

    .. code-block:: python

        import torch
        from torch_cluster import graclus_cluster

        row = torch.tensor([0, 1, 1, 2])
        col = torch.tensor([1, 0, 2, 1])
        weight = torch.Tensor([1, 1, 1, 1])
        cluster = graclus_cluster(row, col, weight)
    """

    if num_nodes is None:
        # `Tensor.max()` raises on a zero-element tensor, so an empty edge
        # list must be handled before inferring the node count.
        if row.numel() == 0:
            return row.new_zeros(0)
        num_nodes = max(int(row.max()), int(col.max())) + 1

    # Remove self-loops.
    mask = row != col
    row, col = row[mask], col[mask]

    if weight is not None:
        weight = weight[mask]

    # Without weights, randomly permute the edge order before sorting.
    # NOTE(review): presumably this keeps the neighbor picked by the operator
    # from depending on the caller's edge ordering -- confirm against the
    # operator implementation.
    if weight is None:
        perm = torch.randperm(row.size(0), dtype=torch.long, device=row.device)
        row, col = row[perm], col[perm]

    # To CSR: sort the edges by source node ...
    perm = torch.argsort(row)
    row, col = row[perm], col[perm]

    if weight is not None:
        weight = weight[perm]

    # ... count the outgoing edges of every node ...
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    # ... and accumulate the counts into the row pointer; `rowptr[0]` stays 0
    # and `rowptr[i + 1]` marks the end of node `i`'s slice of `col`.
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    return torch.ops.torch_cluster.graclus(rowptr, col, weight)