random.py 731 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
from .utils import get_func, consecutive
2
3
4
5
from .degree import node_degree
from .permute import permute


rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
def random_cluster(edge_index,
                   batch=None,
                   rid=None,
                   perm_edges=True,
                   num_nodes=None):
    """Assign every node to a cluster using the backend ``'random'`` kernel.

    The edges are first reordered by :func:`permute`, node degrees are
    computed from the permuted ``row`` indices, and the backend function
    returned by ``get_func('random', cluster)`` fills a per-node cluster
    vector in place. The resulting cluster ids are then relabelled by
    :func:`consecutive`.

    Args:
        edge_index: Edge index tensor. Its largest entry is used to infer
            ``num_nodes``. Presumably a ``2 x E`` LongTensor that
            :func:`permute` splits into ``row``/``col`` -- TODO confirm.
        batch (optional): Batch vector. It is currently returned unchanged;
            see the TODO at the end of the function.
        rid (optional): Forwarded to :func:`permute`. Presumably a node
            ordering / random ids -- verify against ``permute``.
        perm_edges (bool): Forwarded to :func:`permute`.
        num_nodes (int, optional): Number of nodes. When omitted it is
            inferred as ``max(edge_index) + 1``, or ``0`` for an empty
            ``edge_index``.

    Returns:
        ``cluster`` if ``batch`` is ``None``, otherwise ``(cluster, batch)``.
    """
    if num_nodes is None:
        # Coerce to a Python int. On PyTorch >= 0.4 ``Tensor.max()`` returns a
        # 0-dim tensor, and ``Tensor.new(t)`` given a tensor copies its data
        # instead of allocating ``t`` elements. ``max()`` of an empty tensor
        # raises, hence the emptiness guard.
        if edge_index.numel() > 0:
            num_nodes = int(edge_index.max()) + 1
        else:
            num_nodes = 0

    row, col = permute(edge_index, num_nodes, rid, perm_edges)
    degree = node_degree(row, num_nodes, out=row.new())

    # Start every node at -1, presumably meaning "not yet assigned", before
    # the backend kernel fills the vector in place.
    cluster = edge_index.new(num_nodes).fill_(-1)
    func = get_func('random', cluster)
    func(cluster, row, col, degree)

    # Relabel the cluster ids. The second return value of ``consecutive`` is
    # not needed here.
    cluster, _ = consecutive(cluster)

    if batch is None:
        return cluster

    # TODO: Fix -- ``batch`` is returned exactly as passed in. It is not
    # mapped onto the relabelled clusters, which is presumably what callers
    # expect.
    return cluster, batch