random.py 611 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
from .utils import get_func, consecutive
2
3
4
5
from .degree import node_degree
from .permute import permute


rusty1s's avatar
rusty1s committed
6
def random_cluster(edge_index, batch=None, num_nodes=None):
    """Assign every node to a cluster using the backend ``'random'`` kernel.

    Args:
        edge_index: Integer tensor of edge indices, forwarded to ``permute``
            (presumably shape ``[2, num_edges]`` -- TODO confirm).
        batch: Optional per-node batch vector. NOTE(review): it is currently
            returned unchanged; see the TODO at the end of the function.
        num_nodes: Number of nodes. When ``None`` it is inferred as
            ``max(edge_index) + 1``.

    Returns:
        ``cluster`` when ``batch`` is ``None``, otherwise the tuple
        ``(cluster, batch)``. ``cluster`` is the kernel output after being
        passed through ``consecutive`` (presumably renumbered to contiguous
        ids -- verify against ``utils.consecutive``).
    """
    if num_nodes is None:
        # Force a Python int: on PyTorch >= 0.4 ``max()`` returns a 0-dim
        # tensor, and ``new(<tensor>)`` below would then build a tensor from
        # that value instead of allocating ``num_nodes`` elements.
        num_nodes = int(edge_index.max()) + 1

    row, col = permute(edge_index, num_nodes)
    degree = node_degree(row, num_nodes, out=row.new())

    # One entry per node, initialised to -1 (presumably "unassigned"). The
    # kernel's return value is ignored, so it must write into ``cluster``.
    cluster = edge_index.new(num_nodes).fill_(-1)
    func = get_func('random', cluster)
    func(cluster, row, col, degree)

    # The second value returned by ``consecutive`` is not needed here.
    cluster, _ = consecutive(cluster)

    if batch is None:
        return cluster

    # TODO: Fix -- ``batch`` is returned as given (one entry per input node);
    # it is not yet mapped to the clusters produced above.
    return cluster, batch