"""
Write your own GNN module
=========================

Sometimes your model goes beyond simply stacking existing GNN modules.
For example, you may want to invent a new way of aggregating neighbor
information that takes node importance or edge weights into account.

By the end of this tutorial you will be able to

-  Understand DGL’s message passing APIs.
-  Implement a GraphSAGE convolution module on your own.

This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>`.

(Time estimate: 10 minutes)

"""

import os

os.environ["DGLBACKEND"] = "pytorch"  # must be set before importing dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn

######################################################################
# Message passing and GNNs
# ------------------------
#
# DGL follows the *message passing paradigm* inspired by the Message
# Passing Neural Network proposed by `Gilmer et
# al. <https://arxiv.org/abs/1704.01212>`__ Essentially, they found that
# many GNN models can fit into the following framework:
#
# .. math::
#
#    m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)
#
# .. math::
#
#    m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}
#
# .. math::
#
#    h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
#
# where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the
# *reduce function* and :math:`U^{(l)}` the *update function*. Note that
# :math:`\sum` here can represent any function and is not necessarily a
# summation.
#
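

######################################################################
# To make the framework concrete, here is a minimal sketch that runs one
# round of message passing on a toy graph with the builtin ``fn.copy_u``
# message function and ``fn.sum`` reduce function. The toy graph and
# feature values below are arbitrary, chosen only for illustration.

toy_g = dgl.graph(([0, 0, 1], [1, 2, 2]))  # edges: 0->1, 0->2, 1->2
toy_g.ndata["h"] = torch.ones(3, 2)
# Copy each source node's 'h' onto its out-edges as message 'm', then sum
# the messages arriving at each node into a new feature 'h_sum'.
toy_g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_sum"))
print(toy_g.ndata["h_sum"])  # node 0 receives 0 messages, node 1 one, node 2 two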


######################################################################
# For example, the `GraphSAGE convolution (Hamilton et al.,
# 2017) <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__
# takes the following mathematical form:
#
# .. math::
#
#    h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}
#
# .. math::
#
#    h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)
#
# You can see that message passing is directional: the message sent from
# a node :math:`u` to a node :math:`v` is not necessarily the same
# as the message sent from node :math:`v` to node :math:`u` in the
# opposite direction.
#
# Although DGL has built-in support for GraphSAGE via
# :class:`dgl.nn.SAGEConv <dgl.nn.pytorch.SAGEConv>`,
# here is how you can implement GraphSAGE convolution in DGL on your own.
#


class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata["h"] = h
            # update_all is a message passing API.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            # Concatenate the original feature with the aggregated neighbor
            # feature, then project to the output size.
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


######################################################################
# The central piece in this code is the
# :func:`g.update_all <dgl.DGLGraph.update_all>`
# function, which gathers and averages the neighbor features. There are
# three concepts here:
#
# * Message function ``fn.copy_u('h', 'm')`` that
#   copies the node feature under name ``'h'`` as *messages* with name
#   ``'m'`` sent to neighbors.
#
# * Reduce function ``fn.mean('m', 'h_N')`` that averages
#   all the received messages under name ``'m'`` and saves the result as a
#   new node feature ``'h_N'``.
#
# * ``update_all`` tells DGL to trigger the
#   message and reduce functions for all the nodes and edges.
#
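

######################################################################
# As a quick sanity check, the module can be applied to a small random
# graph (a sketch; the graph size and feature dimensions below are
# arbitrary choices for illustration, not part of the GraphSAGE recipe):

demo_g = dgl.rand_graph(4, 8)  # 4 nodes, 8 random edges
demo_conv = SAGEConv(in_feat=5, out_feat=3)
demo_out = demo_conv(demo_g, torch.randn(4, 5))
print(demo_out.shape)  # torch.Size([4, 3])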


######################################################################
# Afterwards, you can stack your own GraphSAGE convolution layers to form
# a multi-layer GraphSAGE network.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


######################################################################
# Training loop
# ~~~~~~~~~~~~~
# The following data loading and training loop code is copied directly
# from the introduction tutorial.
#

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the
        # training set, i.e. those where train_mask is True.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)


######################################################################
# More customization
# ------------------
#
# In DGL, we provide many built-in message and reduce functions under the
# ``dgl.function`` package. You can find more details in :ref:`the API
# doc <apifunction>`.
#
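

######################################################################
# For example, other builtin combinations can be swapped in directly. The
# following sketch runs on the Cora graph loaded above; the feature names
# ``'x'`` and ``'x_max'`` are arbitrary labels introduced here:

with g.local_scope():
    g.ndata["x"] = g.ndata["feat"]
    # Message = source feature + destination feature; reduce = elementwise max.
    g.update_all(fn.u_add_v("x", "x", "m"), fn.max("m", "x_max"))
    print(g.ndata["x_max"].shape)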


######################################################################
# These APIs allow one to quickly implement new graph convolution modules.
# For example, the following implements a new ``SAGEConv`` that aggregates
# neighbor representations using a weighted average. Note that the
# ``edata`` member can hold edge features, which can also take part in
# message passing.
#


class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata["h"] = h
            g.edata["w"] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
282
283
284
285
286
287
288
289
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


######################################################################
# Because the graph in this dataset does not have edge weights, we
# manually assign all edge weights to one in the ``forward()`` function of
# the model. You can replace it with your own edge weights.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        # The dataset has no edge weights, so all-one weights are used here;
        # replace them with your own edge weights if available.
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
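

######################################################################
# As a quick illustration of plugging in non-trivial weights, here is a
# direct call to ``WeightedSAGEConv`` with random edge weights (a sketch;
# ``rand_w`` is a placeholder introduced here, and in practice the weights
# would come from your data):

wconv = WeightedSAGEConv(g.ndata["feat"].shape[1], 16)
rand_w = torch.rand(g.num_edges(), 1)  # hypothetical per-edge weights in [0, 1)
print(wconv(g, g.ndata["feat"], rand_w).shape)  # (num_nodes, 16)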


######################################################################
# Even more customization by user-defined function
# ------------------------------------------------
#
# DGL allows user-defined message and reduce functions for maximal
# expressiveness. Here is a user-defined message function that is
# equivalent to ``fn.u_mul_e('h', 'w', 'm')``.
#


def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}


######################################################################
# ``edges`` has three members: ``src``, ``data`` and ``dst``, representing
# the source node features, edge features, and destination node features
# for all edges.
#
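

######################################################################
# A user-defined message function plugs into ``update_all`` exactly like a
# builtin one. A sketch on the Cora graph, with all-one edge weights as a
# stand-in for real weights:

with g.local_scope():
    g.ndata["h"] = g.ndata["feat"]
    g.edata["w"] = torch.ones(g.num_edges(), 1)
    g.update_all(u_mul_e_udf, fn.mean("m", "h_N"))
    print(g.ndata["h_N"].shape)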


######################################################################
# You can also write your own reduce function. For example, the following
# is equivalent to the builtin ``fn.mean('m', 'h_N')`` function that averages
# the incoming messages:
#

def mean_udf(nodes):
    return {"h_N": nodes.mailbox["m"].mean(1)}


######################################################################
# In short, DGL will group the nodes by their in-degrees, and for each
# group DGL stacks the incoming messages along the second dimension. You
# can then perform a reduction along the second dimension to aggregate
# messages.
#
# For more details on customizing message and reduce functions with
# user-defined functions, please refer to the :ref:`API
# reference <apiudf>`.
#
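

######################################################################
# Because the mailbox stacks messages along the second dimension, other
# reductions follow the same pattern. For instance, a max-style reduce
# (a sketch; ``max_udf`` is a name introduced here for illustration and
# mirrors the builtin ``fn.max('m', 'h_max')``):


def max_udf(nodes):
    # nodes.mailbox['m'] has shape (num_nodes_in_bucket, num_messages, feat_dim);
    # reduce over the message dimension and keep only the max values.
    return {"h_max": nodes.mailbox["m"].max(1)[0]}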


######################################################################
# Best practice of writing custom GNN modules
# -------------------------------------------
#
# DGL recommends the following practices, ranked by preference:
#
# -  Use ``dgl.nn`` modules.
# -  Use ``dgl.nn.functional`` functions, which contain lower-level complex
#    operations such as computing a softmax for each node over incoming
#    edges (see the sketch after this list).
# -  Use ``update_all`` with builtin message and reduce functions.
# -  Use user-defined message or reduce functions.
#
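

######################################################################
# For instance, the second option above covers functions such as
# ``edge_softmax``, which normalizes a score over each node's incoming
# edges (a sketch; the random scores here are arbitrary placeholders):

from dgl.nn.functional import edge_softmax

scores = torch.randn(g.num_edges(), 1)
attn = edge_softmax(g, scores)  # softmax over each destination node's in-edges
print(attn.shape)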


######################################################################
# What’s next?
# ------------
#
# -  :ref:`Writing Efficient Message Passing
#    Code <guide-message-passing-efficient>`.
#


# Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
# sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'