# pagerank.py: PageRank on a random graph, computed with DGL message passing.
import dgl
import dgl.function as fn
import networkx as nx
import torch

# Number of nodes in the random graph.
N = 100
# Random undirected Erdős–Rényi graph (edge probability 0.05), converted to a
# DGL graph. dgl.from_networkx represents each undirected edge as a pair of
# directed edges. It replaces the deprecated dgl.DGLGraph(nx_graph) constructor.
g = dgl.from_networkx(nx.erdos_renyi_graph(N, 0.05))

DAMP = 0.85  # PageRank damping factor
K = 10  # number of PageRank iterations

def compute_pagerank(g):
    """Run ``K`` iterations of damped PageRank on ``g`` using DGL message passing.

    Every node starts with score ``1 / num_nodes``. In each iteration:

    - each node divides its score evenly over its out-edges;
    - each node sums the shares it receives;
    - the damping update ``(1 - DAMP) / num_nodes + DAMP * received`` is applied.

    Nodes without out-edges send nothing, so their score mass is not
    redistributed. This variant does not treat dangling nodes specially.

    Args:
        g: DGL graph. Its ``ndata["pv"]`` field is overwritten and holds the
            final scores on return.

    Returns:
        Tensor of PageRank scores, one per node. This is the same tensor that
        is stored in ``g.ndata["pv"]``.
    """
    # Use the graph's own size rather than the global N, so that any graph
    # works, not just the module-level one.
    num_nodes = g.number_of_nodes()
    g.ndata["pv"] = torch.ones(num_nodes) / num_nodes
    # Clamp to 1 so nodes with no out-edges divide by 1 instead of 0. Such a
    # node has no out-edges, so its divided score is never sent to a
    # neighbour. The results are unchanged, but no inf is ever stored in ndata.
    degrees = g.out_degrees().float().clamp(min=1)
    for _ in range(K):
        g.ndata["pv"] = g.ndata["pv"] / degrees
        g.update_all(
            message_func=fn.copy_u(u="pv", out="m"),
            reduce_func=fn.sum(msg="m", out="pv"),
        )
        g.ndata["pv"] = (1 - DAMP) / num_nodes + DAMP * g.ndata["pv"]
    return g.ndata["pv"]


# Compute PageRank scores for the module-level graph and print them.
pv = compute_pagerank(g)
print(pv)