"""
Training GNN with Neighbor Sampling for Node Classification
===========================================================

This tutorial shows how to train a multi-layer GraphSAGE for node
classification on ``ogbn-arxiv`` provided by `Open Graph
Benchmark (OGB) <https://ogb.stanford.edu/>`__. The dataset contains around
170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for node classification on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>`.

"""


######################################################################
# Loading Dataset
# ---------------
#
# OGB has already prepared the data as a DGL graph.
#

import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-arxiv')
device = 'cpu'      # change to 'cuda' for GPU


######################################################################
# An OGB dataset is a collection of graphs and their labels. The
# ``ogbn-arxiv`` dataset only contains a single graph, so you can
# simply get the graph and its node labels like this:
#

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata['label'] = node_labels[:, 0]
print(graph)
print(node_labels)

node_features = graph.ndata['feat']
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)


######################################################################
# You can get the training-validation-test split of the nodes with
# the ``get_idx_split`` method.
#

idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']
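

######################################################################
# As a quick sanity check, you can print the number of nodes in each
# split (the exact counts depend on the OGB data release you downloaded).
#

print('Train/valid/test sizes:', len(train_nids), len(valid_nids), len(test_nids))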


######################################################################
# How DGL Handles Computation Dependency
# --------------------------------------
#
# In the :doc:`previous tutorial <L0_neighbor_sampling_overview>`, you
# have seen that the computation dependency for message passing of a
# single node can be described as a series of *message flow graphs* (MFGs).
#
# |image1|
#
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#
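#
# As a rough sketch of what happens under the hood (this uses the
# ``dgl.sampling.sample_neighbors`` and ``dgl.to_block`` APIs and is only
# an illustration; nothing later in the tutorial depends on it), a single
# MFG for a handful of seed nodes could be built by hand like this.  The
# data loader introduced below automates this for every layer and every
# minibatch:
#
# .. code:: python
#
#    seeds = torch.arange(3)                                    # hypothetical seed nodes
#    frontier = dgl.sampling.sample_neighbors(graph, seeds, 4)  # sample up to 4 in-neighbors per seed
#    mfg = dgl.to_block(frontier, seeds)                        # compact the frontier into an MFG
#    print(mfg.num_src_nodes(), mfg.num_dst_nodes())
#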


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# DGL provides tools to iterate over the dataset in minibatches
# while generating the computation dependencies to compute their outputs
# with the MFGs above. For node classification, you can use
# ``dgl.dataloading.NodeDataLoader`` for iterating over the dataset.
# It accepts a sampler object to control how to generate the computation
# dependencies in the form of MFGs.  DGL provides
# implementations of common sampling algorithms such as
# ``dgl.dataloading.MultiLayerNeighborSampler`` which randomly picks
# a fixed number of neighbors for each node.
#
# .. note::
#
#    To write your own neighbor sampler, please refer to :ref:`this user
#    guide section <guide-minibatch-customizing-neighborhood-sampler>`.
#
# The syntax of ``dgl.dataloading.NodeDataLoader`` is mostly similar to a
# PyTorch ``DataLoader``, with the addition that it needs a graph to
# generate computation dependencies from, a set of node IDs to iterate on,
# and the neighbor sampler you defined.
#
# Let’s say that each node will gather messages from 4 neighbors on each
# layer. The code defining the data loader and neighbor sampler will look
# like the following.
#

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    # The following arguments are specific to NodeDataLoader.
    graph,              # The graph
    train_nids,         # The node IDs to iterate over in minibatches
    sampler,            # The neighbor sampler
    device=device,      # Put the sampled MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,    # Batch size
    shuffle=True,       # Whether to shuffle the nodes for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)


######################################################################
# .. note::
#
#    Since DGL 0.7, neighborhood sampling on GPU is supported.  Please
#    refer to :ref:`guide-minibatch-gpu-sampling` if you are
#    interested.
#
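#    A rough sketch (an assumption: it requires a CUDA device and that
#    the graph fits in GPU memory) of GPU-based sampling with this
#    tutorial's objects would be:
#
#    .. code:: python
#
#       # Hypothetical: move the graph structure and the seed nodes to GPU.
#       cuda_graph = graph.to('cuda')
#       cuda_train_nids = train_nids.to('cuda')
#       cuda_dataloader = dgl.dataloading.NodeDataLoader(
#           cuda_graph, cuda_train_nids, sampler,
#           device='cuda', batch_size=1024, shuffle=True, drop_last=False,
#           num_workers=0)   # GPU sampling requires num_workers=0
#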


######################################################################
# You can iterate over the data loader and see what it yields.
#

input_nodes, output_nodes, mfgs = example_minibatch = next(iter(train_dataloader))
print(example_minibatch)
print("To compute {} nodes' outputs, we need {} nodes' input features".format(len(output_nodes), len(input_nodes))) 


######################################################################
# ``NodeDataLoader`` gives us three items per iteration.
#
# -  An ID tensor for the input nodes, i.e., nodes whose input features
#    are needed on the first GNN layer for this minibatch.
# -  An ID tensor for the output nodes, i.e. nodes whose representations
#    are to be computed.
# -  A list of MFGs storing the computation dependencies
#    for each GNN layer.
#


######################################################################
# You can get the source and destination node IDs of the MFGs
# and verify that the first few source nodes are always the same as the destination
# nodes.  As we described in the :doc:`overview <L0_neighbor_sampling_overview>`,
# destination nodes' own features from the previous layer may also be necessary in
# the computation of the new features.
#

mfg_0_src = mfgs[0].srcdata[dgl.NID]
mfg_0_dst = mfgs[0].dstdata[dgl.NID]
print(mfg_0_src)
print(mfg_0_dst)
print(torch.equal(mfg_0_src[:mfgs[0].num_dst_nodes()], mfg_0_dst))


######################################################################
# Defining Model
# --------------
#
# Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The
# model can be written as follows:
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # Lines that are changed are marked with an arrow: "<---"

        h_dst = x[:mfgs[0].num_dst_nodes()]  # <---
        h = self.conv1(mfgs[0], (x, h_dst))  # <---
        h = F.relu(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]  # <---
        h = self.conv2(mfgs[1], (h, h_dst))  # <---
        return h

model = Model(num_features, 128, num_classes).to(device)


######################################################################
# If you compare against the code in the
# :doc:`introduction <../blitz/1_introduction>`, you will notice several
# differences:
#
# -  **DGL GNN layers on MFGs**. Instead of computing on the
#    full graph:
#
#    .. code:: python
#
#       h = self.conv1(g, x)
#
#    you only compute on the sampled MFG:
#
#    .. code:: python
#
#       h = self.conv1(mfgs[0], (x, h_dst))
#
#    All DGL’s GNN modules support message passing on MFGs,
#    where you supply a pair of features, one for source nodes and another
#    for destination nodes.
#
# -  **Feature slicing for self-dependency**. There are statements that
#    perform slicing to obtain the previous-layer representation of the
#    nodes:
#
#    .. code:: python
#
#       h_dst = x[:mfgs[0].num_dst_nodes()]
#
#    The ``num_dst_nodes`` method works on MFGs and returns the
#    number of destination nodes.
#
#    Since the first few source nodes of the yielded MFG are
#    always the same as the destination nodes, these statements obtain the
#    representations of the destination nodes on the previous layer. They are
#    then combined with neighbor aggregation in the ``dgl.nn.SAGEConv`` layer.
#
# .. note::
#
#    See the :doc:`custom message passing
#    tutorial <L4_message_passing>` for more details on how to
#    manipulate MFGs produced in this way, such as the usage
#    of ``num_dst_nodes``.
#


######################################################################
# Defining Training Loop
# ----------------------
#
# The following defines the optimizer for the model we initialized above.
#

opt = torch.optim.Adam(model.parameters())


######################################################################
# When computing the validation score for model selection, you can usually
# also use neighbor sampling. To do that, you need to define another data
# loader.
#

valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    device=device
)


######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#

import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
    model.train()

    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata['feat']
            labels = mfgs[-1].dstdata['label']

            predictions = model(mfgs, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())

            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)

    model.eval()

    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, mfgs in tq:
            inputs = mfgs[0].srcdata['feat']
            labels.append(mfgs[-1].dstdata['label'].cpu().numpy())
            predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

        # Note that this tutorial does not train the whole model to the end.
        break
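

######################################################################
# After training, you can reload the checkpoint with the best validation
# accuracy from ``model.pt``. The snippet below is only a rough sketch of
# one way to estimate test accuracy: it reuses the minibatch pipeline with
# ``dgl.dataloading.MultiLayerFullNeighborSampler``, which gathers all
# neighbors and is therefore exact for this 2-layer model, at the cost of
# more memory. See the user guide on exact offline inference, linked in
# the conclusion below, for a more scalable layer-by-layer approach.
#
# .. code:: python
#
#    model.load_state_dict(torch.load(best_model_path))
#    model.eval()
#
#    full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
#    test_dataloader = dgl.dataloading.NodeDataLoader(
#        graph, test_nids, full_sampler,
#        batch_size=1024, shuffle=False, drop_last=False,
#        num_workers=0, device=device)
#
#    test_predictions = []
#    test_labels = []
#    with torch.no_grad():
#        for input_nodes, output_nodes, mfgs in test_dataloader:
#            inputs = mfgs[0].srcdata['feat']
#            test_labels.append(mfgs[-1].dstdata['label'].cpu().numpy())
#            test_predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
#
#    test_accuracy = sklearn.metrics.accuracy_score(
#        np.concatenate(test_labels), np.concatenate(test_predictions))
#    print('Test accuracy:', test_accuracy)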


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# with neighbor sampling.
#
# What’s next?
# ------------
#
# -  :doc:`Stochastic training of GNN for link
#    prediction <L2_large_link_prediction>`.
# -  :doc:`Adapting your custom GNN module for stochastic
#    training <L4_message_passing>`.
# -  During inference you may wish to disable neighbor sampling. If so,
#    please refer to the :ref:`user guide on exact offline
#    inference <guide-minibatch-inference>`.
#


# Thumbnail Courtesy: Stanford CS224W Notes
# sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'