"""
Write your own GNN module
=========================

Sometimes, your model goes beyond simply stacking existing GNN modules.
For example, you would like to invent a new way of aggregating neighbor
information by considering node importance or edge weights.

By the end of this tutorial you will be able to

-  Understand DGL’s message passing APIs.
-  Implement a GraphSAGE convolution module on your own.

This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>`.

(Time estimate: 10 minutes)

"""

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn

######################################################################
# Message passing and GNNs
# ------------------------
#
# DGL follows the *message passing paradigm* inspired by the Message
# Passing Neural Network proposed by `Gilmer et
# al. <https://arxiv.org/abs/1704.01212>`__ Essentially, they found that many
# GNN models can fit into the following framework:
#
# .. math::
#
#
#    m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)
#
# .. math::
#
#
#    m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}
#
# .. math::
#
#
#    h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
#
# where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the
# *reduce function* and :math:`U^{(l)}` the *update function*. Note that
# :math:`\sum` here can represent any function and is not necessarily a
# summation.
#


######################################################################
# For example, the `GraphSAGE convolution (Hamilton et al.,
# 2017) <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__
# takes the following mathematical form:
#
# .. math::
#
#
#    h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}
#
# .. math::
#
#
#    h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)
#
# You can see that message passing is directional: the message sent from
# node :math:`u` to node :math:`v` is not necessarily the same as the
# message sent from node :math:`v` to node :math:`u` in the opposite
# direction.
#
# Although DGL has builtin support of GraphSAGE via
# :class:`dgl.nn.SAGEConv <dgl.nn.pytorch.SAGEConv>`,
# here is how you can implement GraphSAGE convolution in DGL on your own.
#


class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata["h"] = h
            # update_all is a message passing API.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


######################################################################
# The central piece in this code is the
# :func:`g.update_all <dgl.DGLGraph.update_all>`
# function, which gathers and averages the neighbor features. There are
# three concepts here:
#
# * Message function ``fn.copy_u('h', 'm')`` that
#   copies the node feature under name ``'h'`` as *messages* with name
#   ``'m'`` sent to neighbors.
#
# * Reduce function ``fn.mean('m', 'h_N')`` that averages
#   all the received messages under name ``'m'`` and saves the result as a
#   new node feature ``'h_N'``.
#
# * ``update_all`` tells DGL to trigger the
#   message and reduce functions for all the nodes and edges.
#
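
# As a quick sanity check, you can run the module on a tiny toy graph;
# the 3-node cycle below is purely illustrative. The output should have
# one out_feat-dimensional vector per node.
conv = SAGEConv(10, 2)
toy_g = dgl.graph(([0, 1, 2], [1, 2, 0]))
print(conv(toy_g, torch.randn(3, 10)).shape)  # torch.Size([3, 2])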


######################################################################
# Afterwards, you can stack your own GraphSAGE convolution layers to form
# a multi-layer GraphSAGE network.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


######################################################################
# Training loop
# ~~~~~~~~~~~~~
# The following data loading and training loop code is copied directly
# from the introduction tutorial.
#

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. the nodes whose train_mask is True.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)


######################################################################
# More customization
# ------------------
#
# In DGL, we provide many built-in message and reduce functions under the
# ``dgl.function`` package. You can find more details in :ref:`the API
# doc <apifunction>`.
#
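
# For instance, swapping the reduce function changes the aggregation.
# Below is a minimal sketch that sums (rather than averages) neighbor
# features; the field name "h_sum" is illustrative.
with g.local_scope():
    g.ndata["h"] = g.ndata["feat"]
    g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h_sum"))
    print(g.ndata["h_sum"].shape)  # same shape as the input features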


######################################################################
# These APIs allow one to quickly implement new graph convolution modules.
# For example, the following implements a new ``SAGEConv`` that aggregates
# neighbor representations using a weighted average. Note that the
# ``edata`` member can hold edge features, which can also take part in
# message passing.
#


class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata["h"] = h
            g.edata["w"] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


######################################################################
# Because the graph in this dataset does not have edge weights, we
# manually assign all edge weights to one in the ``forward()`` function of
# the model. You can replace it with your own edge weights.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)


######################################################################
# Even more customization by user-defined function
# ------------------------------------------------
#
# DGL allows user-defined message and reduce functions for maximal
# expressiveness. Here is a user-defined message function that is
# equivalent to ``fn.u_mul_e('h', 'w', 'm')``.
#


def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}


######################################################################
# ``edges`` has three members: ``src``, ``data`` and ``dst``, representing
# the source node feature, edge feature, and destination node feature for
# all edges.
#
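
# For illustration, here is another user-defined message function (the
# name is ours) that also uses ``dst``; it is equivalent to the builtin
# ``fn.u_add_v('h', 'h', 'm')``:
def u_add_v_udf(edges):
    return {"m": edges.src["h"] + edges.dst["h"]}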


######################################################################
# You can also write your own reduce function. For example, the following
# is equivalent to the builtin ``fn.mean('m', 'h_N')`` function that averages
# the incoming messages:
#


def mean_udf(nodes):
    return {"h_N": nodes.mailbox["m"].mean(1)}


######################################################################
# In short, DGL will group the nodes by their in-degrees, and for each
# group DGL stacks the incoming messages along the second dimension. You
# can then perform a reduction along the second dimension to aggregate
# messages.
#
# For more details on customizing message and reduce functions with
# user-defined functions, please refer to the :ref:`API
# reference <apiudf>`.
#


######################################################################
# Best practices for writing custom GNN modules
# ----------------------------------------------
#
# DGL recommends the following practices, ranked by preference:
#
# -  Use ``dgl.nn`` modules.
# -  Use ``dgl.nn.functional`` functions which contain lower-level complex
#    operations such as computing a softmax for each node over incoming
#    edges (see the sketch after this list).
# -  Use ``update_all`` with builtin message and reduce functions.
# -  Use user-defined message or reduce functions.
#
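
# For example, edge_softmax from ``dgl.nn.functional`` normalizes one
# score per edge over each node's incoming edges; the random scores
# below are purely illustrative.
from dgl.nn.functional import edge_softmax

scores = torch.randn(g.num_edges(), 1)
attn = edge_softmax(g, scores)
print(attn.shape)  # (num_edges, 1); sums to 1 over each node's in-edges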


######################################################################
# What’s next?
# ------------
#
# -  :ref:`Writing Efficient Message Passing
#    Code <guide-message-passing-efficient>`.
#


# Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
# sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'