"""
Writing GNN Modules for Stochastic GNN Training
===============================================

All GNN modules that DGL provides support stochastic GNN training. This
tutorial teaches you how to write your own graph neural network module
for stochastic GNN training. It assumes that

1. You know :doc:`how to write GNN modules for full graph
   training <../blitz/3_message_passing>`.
2. You know :doc:`how the stochastic GNN training pipeline
   works <L1_large_node_classification>`.

"""

import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-arxiv')
device = 'cpu'      # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
# node_labels is indexed as a 2-D tensor; keep its first column as the label.
graph.ndata['label'] = node_labels[:, 0]

idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
node_features = graph.ndata['feat']

# Two-layer neighbor sampling with a fanout of 4 per layer.
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

# Fetch a single minibatch so its MFGs can be inspected below.
input_nodes, output_nodes, mfgs = next(iter(train_dataloader))


######################################################################
# DGL Bipartite Graph Introduction
# --------------------------------
#
# In the previous tutorials, you have seen the concept of a *message flow
# graph* (MFG), where nodes are divided into two parts.  It is a kind of
# (directional) bipartite graph.
# This section introduces how you can manipulate (directional) bipartite
# graphs.
#
# You can access the source node features and destination node features via
# ``srcdata`` and ``dstdata`` attributes:
#

# Inspect the first MFG of the sampled minibatch.
mfg = mfgs[0]
print(mfg.srcdata)
print(mfg.dstdata)


######################################################################
# It also has ``num_src_nodes`` and ``num_dst_nodes`` functions to query
# how many source nodes and destination nodes exist in the bipartite graph:
#

# Number of source nodes and destination nodes in this MFG.
print(mfg.num_src_nodes(), mfg.num_dst_nodes())


######################################################################
# You can assign features to ``srcdata`` and ``dstdata`` just as you
# would with ``ndata`` on the graphs you have seen earlier:
#

# The feature tensor has one row per source node; its width here is only
# for demonstration.
mfg.srcdata['x'] = torch.zeros(mfg.num_src_nodes(), mfg.num_dst_nodes())
dst_feat = mfg.dstdata['feat']


######################################################################
# Also, since the bipartite graphs are constructed by DGL, you can
# retrieve the source node IDs (i.e. those that are required to compute the
# output) and destination node IDs (i.e. those whose representations the
# current GNN layer should compute) as follows.
#

# Original (parent-graph) IDs of the source and destination nodes.
mfg.srcdata[dgl.NID], mfg.dstdata[dgl.NID]


######################################################################
# Writing GNN Modules for Bipartite Graphs for Stochastic Training
# ----------------------------------------------------------------
#


######################################################################
# Recall that the MFGs yielded by the ``DataLoader``
# have the property that the first few source nodes are
# always identical to the destination nodes:
#
# |image1|
#
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#

# Expected to print True: the leading num_dst_nodes() source node IDs are
# the destination node IDs.
print(torch.equal(mfg.srcdata[dgl.NID][:mfg.num_dst_nodes()], mfg.dstdata[dgl.NID]))


######################################################################
# Suppose you have obtained the source node representations
# :math:`h_u^{(l-1)}`:
#

# Random 10-dimensional representations standing in for h_u^{(l-1)}.
mfg.srcdata['h'] = torch.randn(mfg.num_src_nodes(), 10)


######################################################################
# Recall that DGL provides the ``update_all`` interface for expressing how
# to compute messages and how to aggregate them on the nodes that receive
# them. This concept naturally applies to bipartite graphs like MFGs -- message
# computation happens on the edges between source and destination nodes,
# and message aggregation happens on the destination nodes.
#
# For example, suppose the message function copies the source feature
# (i.e. :math:`M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right) = h_u^{(l-1)}`),
# and the reduce function averages the received messages.  Performing
# such message passing computation on a bipartite graph is no different than
# on a full graph:
#

import dgl.function as fn

# Copy each source feature 'h' as message 'm', then average the messages on
# every destination node into dstdata['h'].
mfg.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h'))
m_v = mfg.dstdata['h']
m_v


######################################################################
# Putting them together, you can implement a GraphSAGE convolution for
# training with neighbor sampling as follows (the differences to the :doc:`full graph
# counterpart <../blitz/3_message_passing>` are highlighted with arrows ``<---``)
#

import torch.nn as nn
import torch.nn.functional as F
import tqdm

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model, operating on MFGs.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        # Its input is the destination feature concatenated with the aggregated
        # neighbor feature, hence the ``in_feat * 2`` width.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input MFG.
        h : (Tensor, Tensor)
            The feature of source nodes and destination nodes as a pair of Tensors.

        Returns
        -------
        Tensor
            The output representations of the destination nodes of ``g``.
        """
        # local_scope() keeps the temporary 'h' / 'h_N' fields from leaking into g.
        with g.local_scope():
            h_src, h_dst = h
            g.srcdata['h'] = h_src                        # <---
            g.dstdata['h'] = h_dst                        # <---
            # update_all is a message passing API.
            # Copy each source feature as a message and average the messages
            # on every destination node into 'h_N'.
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_N'))
            h_N = g.dstdata['h_N']
            h_total = torch.cat([h_dst, h_N], dim=1)      # <---
            return self.linear(h_total)

class Model(nn.Module):
    """Two-layer GraphSAGE model that runs on the MFGs of a minibatch.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    h_feats : int
        Hidden feature size.
    num_classes : int
        Number of output classes.
    """
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, mfgs, x):
        """Forward computation

        Parameters
        ----------
        mfgs : list of MFGs
            The minibatch MFGs; ``mfgs[0]`` feeds ``conv1`` and ``mfgs[1]``
            feeds ``conv2``.
        x : Tensor
            Input features of the source nodes of ``mfgs[0]``.

        Returns
        -------
        Tensor
            Output of ``conv2`` for the destination nodes of ``mfgs[1]``.
        """
        # Destination nodes are the leading source nodes of each MFG, so their
        # features are a prefix slice of the source features.
        h_dst = x[:mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    # Generated MFGs are placed on ``device``.
    device=device,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)
model = Model(graph.ndata['feat'].shape[1], 128, dataset.num_classes).to(device)

with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
        # Inputs are the features of the first MFG's source nodes; labels are
        # read from the last MFG's destination nodes.
        inputs = mfgs[0].srcdata['feat']
        labels = mfgs[-1].dstdata['label']
        # Only the forward pass is demonstrated here; a real training loop
        # would compute a loss from predictions and labels.
        predictions = model(mfgs, inputs)


######################################################################
# Both ``update_all`` and the functions in the ``nn.functional`` namespace
# support MFGs, so you can migrate code that works for small graphs to
# large-graph training with only the minimal changes introduced above.
#


######################################################################
# Writing GNN Modules for Both Full-graph Training and Stochastic Training
# ------------------------------------------------------------------------
#
# Here is a step-by-step tutorial for writing a GNN module for both
# :doc:`full-graph training <../blitz/1_introduction>` *and* :doc:`stochastic
# training <L1_large_node_classification>`.
#
# Say you start with a GNN module that works for full-graph training only:
#

class SAGEConv(nn.Module):
    """GraphSAGE convolution layer that works on full graphs only.

    Each node's new representation is a linear projection of its own
    feature concatenated with the mean of the features sent to it along
    its incoming edges.

    Parameters
    ----------
    in_feat : int
        Size of each input node feature.
    out_feat : int
        Size of each output node feature.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # The projection consumes [own feature || neighbor mean], so its
        # input width is twice the node feature size.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        """Compute updated representations for every node of the graph.

        Parameters
        ----------
        g : Graph
            The full input graph.
        h : Tensor
            Feature of every node in ``g``.

        Returns
        -------
        Tensor
            The projected representation of every node.
        """
        # local_scope() keeps the temporary 'h' / 'h_N' fields off ``g``.
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a message passing API: copy each source feature
            # as a message, then average the messages on each receiving node.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


######################################################################
# **First step**: Check whether the input feature is a single tensor or a
# pair of tensors:
#
# .. code:: python
#
#    if isinstance(h, tuple):
#        h_src, h_dst = h
#    else:
#        h_src = h_dst = h
#
# **Second step**: Replace node features ``h`` with ``h_src`` or
# ``h_dst``, and assign the node features to ``srcdata`` or ``dstdata``,
# instead of ``ndata``.
#
# Whether to assign to ``srcdata`` or ``dstdata`` depends on whether the
# feature acts as the feature of the source nodes or of the destination
# nodes of the edges in the message functions (in ``update_all`` or
# ``apply_edges``).
#
# *Example 1*: For the following ``update_all`` statement:
#
# .. code:: python
#
#    g.ndata['h'] = h
#    g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
#
# The node feature ``h`` acts as the source node feature because
# ``fn.copy_u('h', 'm')`` reads ``'h'`` from the source nodes. So you will
# need to replace ``h`` with the source feature ``h_src`` and assign it to
# ``srcdata`` for the version that works with both cases:
#
# .. code:: python
#
#    g.srcdata['h'] = h_src
#    g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
#
# *Example 2*: For the following ``apply_edges`` statement:
#
# .. code:: python
#
#    g.ndata['h'] = h
#    g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
#
# The node feature ``h`` acts as both source node feature and destination
# node feature. So you will assign ``h_src`` to ``srcdata`` and ``h_dst``
# to ``dstdata``:
#
# .. code:: python
#
#    g.srcdata['h'] = h_src
#    g.dstdata['h'] = h_dst
#    # The first 'h' corresponds to source feature (u) while the second 'h' corresponds to destination feature (v).
#    g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
#
# .. note::
#
#    For homogeneous graphs (i.e. graphs with only one node type
#    and one edge type), ``srcdata`` and ``dstdata`` are aliases of
#    ``ndata``. So you can safely replace ``ndata`` with ``srcdata`` and
#    ``dstdata`` even for full-graph training.
#
# **Third step**: Replace the ``ndata`` for outputs with ``dstdata``.
#
# For example, the following code
#
# .. code:: python
#
#    # Assume that update_all() function has been called with output node features in `h_N`.
#    h_N = g.ndata['h_N']
#    h_total = torch.cat([h, h_N], dim=1)
#
# will change to
#
# .. code:: python
#
#    h_N = g.dstdata['h_N']
#    h_total = torch.cat([h_dst, h_N], dim=1)
#


######################################################################
# Putting it all together, you will change the ``SAGEConv`` module above
# into something like the following ``SAGEConvForBoth`` module:
#

class SAGEConvForBoth(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Works for full-graph training (``h`` is a single Tensor) as well as
    stochastic training on MFGs (``h`` is a ``(h_src, h_dst)`` pair).

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph: either a full graph or an MFG.
        h : Tensor or tuple[Tensor, Tensor]
            The input node feature. A single Tensor is used as both the source
            and the destination feature; a tuple is unpacked as
            ``(source feature, destination feature)``.

        Returns
        -------
        Tensor
            The output representations of the destination nodes of ``g``.
        """
        with g.local_scope():
            if isinstance(h, tuple):
                h_src, h_dst = h
            else:
                h_src = h_dst = h

            g.srcdata['h'] = h_src
            # update_all is a message passing API.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            # Read the aggregated output from ``dstdata`` rather than ``ndata``.
            # On an MFG the result lives on the destination nodes. On a
            # homogeneous full graph ``dstdata`` is an alias of ``ndata``.
            h_N = g.dstdata['h_N']
            h_total = torch.cat([h_dst, h_N], dim=1)
            return self.linear(h_total)

# Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
# sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'