3_pagerank.py 10.3 KB
Newer Older
1
2
3
"""
.. currentmodule:: dgl

4
PageRank with DGL message passing
5
6
7
8
9
=================================

**Author**: `Minjie Wang <https://jermainewang.github.io/>`_, Quan Gan, Yu Gai,
Zheng Zhang

10
In this tutorial, you learn how to use different levels of the message
11
passing API with PageRank on a small graph. In DGL, the message passing and
12
feature transformations are **user-defined functions** (UDFs).
13
14
15
16

"""

###############################################################################
17
# The PageRank algorithm
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# ----------------------
# In each iteration of PageRank, every node (web page) first scatters its
# PageRank value uniformly to its downstream nodes. The new PageRank value of
# each node is computed by aggregating the received PageRank values from its
# neighbors, which is then adjusted by the damping factor:
#
# .. math::
#
#    PV(u) = \frac{1-d}{N} + d \times \sum_{v \in \mathcal{N}(u)}
#    \frac{PV(v)}{D(v)}
#
# where :math:`N` is the number of nodes in the graph; :math:`D(v)` is the
# out-degree of a node :math:`v`; and :math:`\mathcal{N}(u)` is the neighbor
# nodes.


###############################################################################
# A naive implementation
# ----------------------
37
38
# Create a graph with 100 nodes by using ``networkx`` and then convert it to a
# :class:`DGLGraph`.
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl

N = 100  # number of nodes
DAMP = 0.85  # damping factor
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1)
g = dgl.DGLGraph(g)
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])
plt.show()


###############################################################################
# According to the algorithm, PageRank consists of two phases in a typical
56
57
# scatter-gather pattern. Initialize the PageRank value of each node
# to :math:`\frac{1}{N}` and then store each node's out-degree as a node feature.
58
59
60
61
62
63

g.ndata['pv'] = torch.ones(N) / N
g.ndata['deg'] = g.out_degrees(g.nodes()).float()


###############################################################################
64
65
# Define the message function, which divides every node's PageRank
# value by its out-degree and passes the result as message to its neighbors.
66
67
68
69
70
71
72
73
74

def pagerank_message_func(edges):
    return {'pv' : edges.src['pv'] / edges.src['deg']}


###############################################################################
# In DGL, the message functions are expressed as **Edge UDFs**.  Edge UDFs
# take in a single argument ``edges``.  It has three members ``src``, ``dst``,
# and ``data`` for accessing source node features, destination node features,
75
# and edge features.  Here, the function computes messages only
76
77
# from source node features.
#
78
79
# Define the reduce function, which removes and aggregates the
# messages from its ``mailbox``, and computes its new PageRank value.
80
81
82
83
84
85
86
87
88
89

def pagerank_reduce_func(nodes):
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {'pv' : pv}


###############################################################################
# The reduce functions are **Node UDFs**.  Node UDFs have a single argument
# ``nodes``, which has two members ``data`` and ``mailbox``.  ``data``
90
# contains the node features and ``mailbox`` contains all incoming message
91
92
93
94
95
96
97
98
# features, stacked along the second dimension (hence the ``dim=1`` argument).
#
# The message UDF works on a batch of edges, whereas the reduce UDF works on
# a batch of edges but outputs a batch of nodes. Their relationships are as
# follows:
#
# .. image:: https://i.imgur.com/kIMiuFb.png
#
99
# Register the message function and reduce function, which will be called
100
101
102
103
104
105
106
# later by DGL.

g.register_message_func(pagerank_message_func)
g.register_reduce_func(pagerank_reduce_func)


###############################################################################
107
108
# The algorithm is straightforward. Here is the code for one
# PageRank iteration.
109
110
111
112
113
114
115
116
117
118
119

def pagerank_naive(g):
    # Phase #1: send out messages along all edges.
    for u, v in zip(*g.edges()):
        g.send((u, v))
    # Phase #2: receive messages to compute new PageRank values.
    for v in g.nodes():
        g.recv(v)


###############################################################################
120
# Batching semantics for a large graph
121
# ------------------------------------
122
123
# The above code does not scale to a large graph because it iterates over all
# the nodes. DGL solves this by allowing you to compute on a *batch* of nodes or
124
# edges. For example, the following codes trigger message and reduce functions
125
# on multiple nodes and edges at one time.
126
127
128
129
130
131
132

def pagerank_batch(g):
    g.send(g.edges())
    g.recv(g.nodes())


###############################################################################
133
# You are still using the same reduce function ``pagerank_reduce_func``,
134
135
136
# where ``nodes.mailbox['pv']`` is a *single* tensor, stacking the incoming
# messages along the second dimension.
#
137
# You might wonder if this is even possible to perform reduce on all
138
# nodes in parallel, since each node may have different number of incoming
139
# messages and you cannot really "stack" tensors of different lengths together.
140
141
142
143
144
# In general, DGL solves the problem by grouping the nodes by the number of
# incoming messages, and calling the reduce function for each group.


###############################################################################
145
# Use higher-level APIs for efficiency
146
# ---------------------------------------
147
148
149
# DGL provides many routines that combine basic ``send`` and ``recv`` in
# various ways. These routines are called **level-2 APIs**. For example, the next code example
# shows how to further simplify the PageRank example with such an API.
150
151
152
153
154
155

def pagerank_level2(g):
    g.update_all()


###############################################################################
156
157
# In addition to ``update_all``, you can use ``pull``, ``push``, and ``send_and_recv``
# in this level-2 category. For more information, see :doc:`API reference <../../api/python/graph>`.
158
159
160


###############################################################################
161
# Use DGL ``builtin`` functions for efficiency
162
# ------------------------------------------------
163
164
# Some of the message and reduce functions are used frequently. For this reason, DGL also
# provides ``builtin`` functions. For example, two ``builtin`` functions can be
165
166
# used in the PageRank example.
#
167
168
169
# * :func:`dgl.function.copy_src(src, out) <function.copy_src>` - This
#   code example is an edge UDF that computes the
#   output using the source node feature data. To use this, specify the name of
170
171
#   the source feature data (``src``) and the output name (``out``).
# 
172
# * :func:`dgl.function.sum(msg, out) <function.sum>` - This code example is a node UDF
173
#   that sums the messages in
174
#   the node's mailbox. To use this, specify the message name (``msg``) and the
175
176
#   output name (``out``).
#
177
# The following PageRank example shows such functions.
178
179
180
181
182
183
184
185
186
187
188

import dgl.function as fn

def pagerank_builtin(g):
    g.ndata['pv'] = g.ndata['pv'] / g.ndata['deg']
    g.update_all(message_func=fn.copy_src(src='pv', out='m'),
                 reduce_func=fn.sum(msg='m',out='m_sum'))
    g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['m_sum']


###############################################################################
189
# In the previous example code, you directly provide the UDFs to the :func:`update_all <DGLGraph.update_all>`
190
191
192
# as its arguments.
# This will override the previously registered UDFs.
#
193
194
# In addition to cleaner code, using ``builtin`` functions also gives DGL the
# opportunity to fuse operations together. This results in faster execution.  For
195
196
197
# example, DGL will fuse the ``copy_src`` message function and ``sum`` reduce
# function into one sparse matrix-vector (spMV) multiplication.
#
198
199
200
# `The following section <spmv_>`_ describes why spMV can speed up the scatter-gather
# phase in PageRank.  For more details about the ``builtin`` functions in DGL,
# see :doc:`API reference <../../api/python/function>`.
201
#
202
# You can also download and run the different code examples to see the differences.
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217

for k in range(K):
    # Uncomment the corresponding line to select different version.
    # pagerank_naive(g)
    # pagerank_batch(g)
    # pagerank_level2(g)
    pagerank_builtin(g)
print(g.ndata['pv'])


###############################################################################
# .. _spmv:
#
# Using spMV for PageRank
# -----------------------
218
219
220
# Using ``builtin`` functions allows DGL to understand the semantics of UDFs.
# This allows you to create an efficient implementation. For example, in the case
# of PageRank, one common method to accelerate it is by using its linear algebra
221
222
223
224
225
226
227
228
229
# form.
#
# .. math::
#
#    \mathbf{R}^{k} = \frac{1-d}{N} \mathbf{1} + d \mathbf{A}*\mathbf{R}^{k-1}
#
# Here, :math:`\mathbf{R}^k` is the vector of the PageRank values of all nodes
# at iteration :math:`k`; :math:`\mathbf{A}` is the sparse adjacency matrix
# of the graph.
230
231
232
233
234
235
# Computing this equation is quite efficient because there is an efficient
# GPU kernel for the sparse matrix-vector multiplication (spMV). DGL
# detects whether such optimization is available through the ``builtin``
# functions. If a certain combination of ``builtin`` can be mapped to an spMV
# kernel (e.g., the PageRank example), DGL uses it automatically. We recommend 
# using ``builtin`` functions whenever possible.
236
237
238
239
240


###############################################################################
# Next steps
# ----------
241
# 
242
243
244
245
246
247
248
249
# * Learn how to use DGL (:doc:`builtin functions<../../features/builtin>`) to write 
#   more efficient message passing.
# * To see model tutorials, see the :doc:`overview page<../models/index>`.
# * To learn about Graph Neural Networks, see :doc:`GCN tutorial<../models/1_gnn/1_gcn>`.
# * To see how DGL batches multiple graphs, see :doc:`TreeLSTM tutorial<../models/2_small_graph/3_tree-lstm>`.
# * Play with some graph generative models by following tutorial for :doc:`Deep Generative Model of Graphs<../models/3_generative_model/5_dgmg>`.
# * To learn how traditional models are interpreted in a view of graph, see 
#   the tutorials on :doc:`CapsuleNet<../models/4_old_wines/2_capsule>` and
250
#   :doc:`Transformer<../models/4_old_wines/7_transformer>`.