serial.py 659 Bytes
Newer Older
rusty1s's avatar
tests  
rusty1s committed
1
from .utils.permute import permute
rusty1s's avatar
rusty1s committed
2
from .utils.degree import node_degree
rusty1s's avatar
rusty1s committed
3
from .utils.ffi import _get_func
rusty1s's avatar
rusty1s committed
4
from .utils.consecutive import consecutive
5
6


rusty1s's avatar
rusty1s committed
7
def serial_cluster(edge_index, batch=None, num_nodes=None):
    """Compute a cluster assignment for the nodes referenced by ``edge_index``.

    The edges are passed through ``permute``, the per-node degree is computed
    with ``node_degree``, and a backend routine looked up via ``_get_func``
    fills the ``cluster`` buffer in place. The result is then passed through
    ``consecutive`` before being returned.

    Args:
        edge_index: Index tensor describing the edges. It must support
            ``max()`` and ``new()``. Presumably shaped ``(2, num_edges)``
            -- TODO confirm against ``permute``.
        batch: Optional value that is returned alongside the clustering,
            currently unchanged (see the TODO below).
        num_nodes: Number of nodes. When ``None`` it is inferred as the
            largest index in ``edge_index`` plus one.

    Returns:
        ``cluster`` when ``batch`` is ``None``, otherwise the tuple
        ``(cluster, batch)``.
    """
    if num_nodes is None:
        # ``max()`` may hand back a 0-dim tensor rather than a Python number
        # depending on the PyTorch version; ``int`` normalises both so the
        # value can safely be used as a size below.
        num_nodes = int(edge_index.max()) + 1

    row, col = permute(edge_index, num_nodes)
    degree = node_degree(row, num_nodes, out=row.new())

    # One entry per node, pre-filled with -1 (presumably "not yet assigned"
    # -- confirm against the backend routine).
    cluster = edge_index.new(num_nodes).fill_(-1)

    # NOTE(review): the backend routine is looked up under the name 'random'
    # although this function is called ``serial_cluster`` -- confirm intended.
    func = _get_func('random', cluster)
    func(cluster, row, col, degree)

    # ``consecutive`` returns a pair; only the first element is needed here.
    cluster, _ = consecutive(cluster)

    if batch is None:
        return cluster
    else:
        # TODO: Fix
        return cluster, batch