# pagerank.py — PageRank on a random graph via DGL message passing.
import networkx as nx
import torch
import dgl
import dgl.function as fn

# Demo graph: a G(N, p) Erdos-Renyi random graph with N nodes, where each
# possible edge is present independently with probability 0.05.
N = 100
nx_g = nx.erdos_renyi_graph(N, 0.05)
# dgl.from_networkx replaces the deprecated dgl.DGLGraph(nx_graph)
# constructor; an undirected networkx graph becomes a DGL graph with each
# edge stored in both directions.
g = dgl.from_networkx(nx_g)

DAMP = 0.85  # damping: weight on rank received via edges (1 - DAMP is spread uniformly)
K = 10       # number of power-iteration rounds

def compute_pagerank(g, damp=None, num_iters=None):
    """Run damped PageRank power iteration on ``g`` using DGL message passing.

    Every node starts with rank ``1 / num_nodes``. In each round, a node
    divides its current rank by its out-degree and sends that share along
    each out-edge. Every node sums what it receives, and the update
    ``(1 - damp) / num_nodes + damp * received`` is applied.

    Nodes with no out-edges send nothing, so their rank is not
    redistributed (there is no dangling-node correction). If such nodes
    exist, the returned scores may sum to less than 1.

    Args:
        g: DGL graph to rank. Its ``ndata['pv']`` field is overwritten and,
            on return, holds the final scores.
        damp: Damping factor. ``None`` (default) uses the module-level
            ``DAMP``, read at call time.
        num_iters: Number of iterations. ``None`` (default) uses the
            module-level ``K``, read at call time.

    Returns:
        1-D tensor with one score per node; the same tensor is stored in
        ``g.ndata['pv']``.
    """
    if damp is None:
        damp = DAMP
    if num_iters is None:
        num_iters = K

    # Size everything from the graph itself rather than the global N, so
    # the function is correct for any input graph.
    num_nodes = g.number_of_nodes()
    g.ndata['pv'] = torch.ones(num_nodes) / num_nodes
    # Clamp to 1 so that nodes without out-edges don't yield inf/nan. Such
    # nodes send no messages, so the clamped value never reaches a neighbour.
    degrees = g.out_degrees().float().clamp(min=1)
    for _ in range(num_iters):
        g.ndata['pv'] = g.ndata['pv'] / degrees
        g.update_all(message_func=fn.copy_u(u='pv', out='m'),
                     reduce_func=fn.sum(msg='m', out='pv'))
        g.ndata['pv'] = (1 - damp) / num_nodes + damp * g.ndata['pv']
    return g.ndata['pv']

# Rank every node of the demo graph and print the resulting score vector.
pv = compute_pagerank(g=g)
print(pv)