graclus.py 1.66 KB
Newer Older
rusty1s's avatar
update  
rusty1s committed
1
from typing import Optional
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
update  
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
5


6
7
8
9
10
11
def graclus_cluster(
    row: torch.Tensor,
    col: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    num_nodes: Optional[int] = None,
) -> torch.Tensor:
    """A greedy clustering algorithm of picking an unmarked vertex and matching
    it with one of its unmarked neighbors (that maximizes its edge weight).

    Self-loops (edges with :obj:`row == col`) are removed before clustering.
    If :obj:`weight` is :obj:`None`, the edges are randomly permuted first.

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        weight (Tensor, optional): Edge weights, one per edge.
            (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as the largest index in :obj:`row` / :obj:`col` plus one,
            or :obj:`0` if there are no edges. (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import graclus_cluster

        row = torch.tensor([0, 1, 1, 2])
        col = torch.tensor([1, 0, 2, 1])
        weight = torch.tensor([1., 1., 1., 1.])
        cluster = graclus_cluster(row, col, weight)
    """
    if num_nodes is None:
        # `Tensor.max()` raises on an empty tensor. With no edges there is no
        # index to infer a node count from, so assume an empty graph.
        if row.numel() == 0:
            num_nodes = 0
        else:
            num_nodes = max(int(row.max()), int(col.max())) + 1

    if num_nodes == 0:
        # Nothing to cluster; skip the compiled kernel entirely.
        return row.new_zeros(0)

    # Remove self-loops.
    mask = row != col
    row, col = row[mask], col[mask]

    if weight is not None:
        weight = weight[mask]

    # Without weights, randomly permute the edges (not the nodes) before the
    # sort below, presumably so the neighbour order seen by the kernel does
    # not simply follow the caller's input order.
    if weight is None:
        perm = torch.randperm(row.size(0), dtype=torch.long, device=row.device)
        row, col = row[perm], col[perm]

    # To CSR: group edges by source node, keeping `weight` aligned with `col`.
    perm = torch.argsort(row)
    row, col = row[perm], col[perm]

    if weight is not None:
        weight = weight[perm]

    # rowptr[i]:rowptr[i + 1] is the slice of `col` holding node i's neighbours.
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    return torch.ops.torch_cluster.graclus(rowptr, col, weight)