"src/vscode:/vscode.git/clone" did not exist on "67e2f95cc4ff8c25f4d04f8bab46df02216527b2"
L1_large_node_classification.py 11.6 KB
Newer Older
1
2
3
4
5
"""
Training GNN with Neighbor Sampling for Node Classification
===========================================================

This tutorial shows how to train a multi-layer GraphSAGE for node
classification on ``ogbn-arxiv`` provided by `Open Graph
Benchmark (OGB) <https://ogb.stanford.edu/>`__. The dataset contains around
170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for node classification on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>`.

"""


######################################################################
# Loading Dataset
# ---------------
#
# OGB has already prepared the data as a DGL graph.
#
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU


######################################################################
# An OGB dataset is a collection of graphs and their labels. The
# ``ogbn-arxiv`` dataset contains only a single graph, so you can
# simply get the graph and its node labels like this:
#

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata["label"] = node_labels[:, 0]
print(graph)
print(node_labels)

node_features = graph.ndata["feat"]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print("Number of classes:", num_classes)


######################################################################
# You can get the training-validation-test split of the nodes with the
# ``get_idx_split`` method.
#

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]


######################################################################
# How DGL Handles Computation Dependency
# --------------------------------------
#
# In the :doc:`previous tutorial <L0_neighbor_sampling_overview>`, you
# have seen that the computation dependency for message passing of a
# single node can be described as a series of *message flow graphs* (MFG).
#
# |image1|
#
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#

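######################################################################
# As a minimal illustration (a sketch on the ``graph`` loaded above, not part
# of the original pipeline), you can build one such MFG by hand: sample a
# fixed number of in-neighbors for a few seed nodes and convert the resulting
# frontier into a block whose destination nodes are the seeds.
#

example_seeds = torch.tensor([0, 1])  # two arbitrary node IDs used as seeds
# Sample up to 4 incoming neighbors for every seed node.
example_frontier = dgl.sampling.sample_neighbors(graph, example_seeds, 4)
# Convert the sampled frontier into a message flow graph (block).
example_mfg = dgl.to_block(example_frontier, example_seeds)
print(example_mfg)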

######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# DGL provides tools to iterate over the dataset in minibatches
# while generating the computation dependencies to compute their outputs
# with the MFGs above. For node classification, you can use
# ``dgl.dataloading.DataLoader`` for iterating over the dataset.
# It accepts a sampler object to control how to generate the computation
# dependencies in the form of MFGs.  DGL provides
# implementations of common sampling algorithms such as
# ``dgl.dataloading.NeighborSampler`` which randomly picks
# a fixed number of neighbors for each node.
#
# .. note::
#
#    To write your own neighbor sampler, please refer to :ref:`this user
#    guide section <guide-minibatch-customizing-neighborhood-sampler>`.
#
# The syntax of ``dgl.dataloading.DataLoader`` is mostly similar to a
# PyTorch ``DataLoader``, with the addition that it needs a graph to
# generate computation dependencies from, a set of node IDs to iterate
# over, and the neighbor sampler you defined.
#
# Let’s say that each node will gather messages from 4 neighbors on each
# layer. The code defining the data loader and neighbor sampler will look
# like the following.
#

sampler = dgl.dataloading.NeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DGL's DataLoader.
    graph,  # The graph
    train_nids,  # The node IDs to iterate over in minibatches
    sampler,  # The neighbor sampler
    device=device,  # Put the sampled MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,  # Batch size
    shuffle=True,  # Whether to shuffle the nodes for every epoch
    drop_last=False,  # Whether to drop the last incomplete batch
    num_workers=0,  # Number of sampler processes
)


######################################################################
# .. note::
#
#    Since DGL 0.7, neighbor sampling on GPU is supported.  Please
#    refer to :ref:`guide-minibatch-gpu-sampling` if you are
#    interested.
#
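# A minimal sketch of GPU-based sampling (an assumption here, since this
# tutorial runs on CPU): move the graph and the seed node IDs to the GPU and
# ask the ``DataLoader`` to produce the MFGs there.
#

if torch.cuda.is_available():
    gpu_graph = graph.to("cuda")
    gpu_train_dataloader = dgl.dataloading.DataLoader(
        gpu_graph,
        train_nids.to("cuda"),
        sampler,
        device="cuda",  # MFGs are created directly on the GPU
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,  # GPU sampling requires num_workers=0
    )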


######################################################################
# You can iterate over the data loader and see what it yields.
#

input_nodes, output_nodes, mfgs = example_minibatch = next(
    iter(train_dataloader)
)
print(example_minibatch)
print(
    "To compute {} nodes' outputs, we need {} nodes' input features".format(
        len(output_nodes), len(input_nodes)
    )
)


######################################################################
# DGL's ``DataLoader`` gives us three items per iteration.
#
# -  An ID tensor for the input nodes, i.e., nodes whose input features
#    are needed on the first GNN layer for this minibatch.
# -  An ID tensor for the output nodes, i.e. nodes whose representations
#    are to be computed.
# -  A list of MFGs storing the computation dependencies
#    for each GNN layer.
#


######################################################################
# You can get the source and destination node IDs of the MFGs
# and verify that the first few source nodes are always the same as the destination
# nodes.  As we described in the :doc:`overview <L0_neighbor_sampling_overview>`,
# destination nodes' own features from the previous layer may also be necessary in
# the computation of the new features.
#

mfg_0_src = mfgs[0].srcdata[dgl.NID]
mfg_0_dst = mfgs[0].dstdata[dgl.NID]
print(mfg_0_src)
print(mfg_0_dst)
print(torch.equal(mfg_0_src[: mfgs[0].num_dst_nodes()], mfg_0_dst))


######################################################################
# Defining Model
# --------------
#
# Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The
# model can be written as follows:
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # Lines that are changed are marked with an arrow: "<---"

        h_dst = x[: mfgs[0].num_dst_nodes()]  # <---
        h = self.conv1(mfgs[0], (x, h_dst))  # <---
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]  # <---
        h = self.conv2(mfgs[1], (h, h_dst))  # <---
        return h


model = Model(num_features, 128, num_classes).to(device)


######################################################################
# If you compare against the code in the
# :doc:`introduction <../blitz/1_introduction>`, you will notice several
# differences:
#
# -  **DGL GNN layers on MFGs**. Instead of computing on the
#    full graph:
#
#    .. code:: python
#
#       h = self.conv1(g, x)
#
#    you only compute on the sampled MFG:
#
#    .. code:: python
#
#       h = self.conv1(mfgs[0], (x, h_dst))
#
#    All DGL’s GNN modules support message passing on MFGs,
#    where you supply a pair of features, one for source nodes and another
#    for destination nodes.
#
# -  **Feature slicing for self-dependency**. There are statements that
#    perform slicing to obtain the previous-layer representation of the
#    nodes:
#
#    .. code:: python
#
#       h_dst = x[:mfgs[0].num_dst_nodes()]
#
#    The ``num_dst_nodes`` method works on MFGs and returns the number of
#    destination nodes.
#
#    Since the first few source nodes of the yielded MFG are
#    always the same as the destination nodes, these statements obtain the
#    representations of the destination nodes on the previous layer. They are
#    then combined with neighbor aggregation in the ``dgl.nn.SAGEConv`` layer
#    (a small sanity check below verifies this slicing).
#
# .. note::
#
#    See the :doc:`custom message passing
#    tutorial <L4_message_passing>` for more details on how to
#    manipulate MFGs produced in this way, such as the usage
#    of ``num_dst_nodes``.
#
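# You can verify the slicing on the example minibatch obtained earlier (a
# small sanity check, not part of the original tutorial flow): slicing the
# source features by the number of destination nodes yields exactly the
# destination nodes' own features.
#

print(
    torch.equal(
        mfgs[0].srcdata["feat"][: mfgs[0].num_dst_nodes()],
        mfgs[0].dstdata["feat"],
    )
)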


######################################################################
# Defining Training Loop
# ----------------------
#
# The following defines the optimizer for the model we initialized above.
#

opt = torch.optim.Adam(model.parameters())


######################################################################
# When computing the validation score for model selection, you can usually
# also use neighbor sampling. To do that, you need to define another data
# loader.
#

valid_dataloader = dgl.dataloading.DataLoader(
    graph,
    valid_nids,
    sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    device=device,
)


import sklearn.metrics

######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
for epoch in range(10):
    model.train()

    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]
            labels = mfgs[-1].dstdata["label"]

            predictions = model(mfgs, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(
                labels.cpu().numpy(),
                predictions.argmax(1).detach().cpu().numpy(),
            )

            tq.set_postfix(
                {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                refresh=False,
            )

    model.eval()

    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, mfgs in tq:
            inputs = mfgs[0].srcdata["feat"]
            labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
            predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

        # Note that this tutorial does not train the model to convergence.
        break

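######################################################################
# After training, you can reload the checkpoint with the best validation
# accuracy before evaluating on held-out nodes (a short sketch; evaluation on
# ``test_nids`` is left to the reader).
#

model.load_state_dict(torch.load(best_model_path))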

######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# with neighbor sampling.
#
# What’s next?
# ------------
#
# -  :doc:`Stochastic training of GNN for link
#    prediction <L2_large_link_prediction>`.
# -  :doc:`Adapting your custom GNN module for stochastic
#    training <L4_message_passing>`.
# -  During inference you may wish to disable neighbor sampling. If so,
#    please refer to the :ref:`user guide on exact offline
#    inference <guide-minibatch-inference>`; a minimal sketch follows below.
#
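# As a minimal sketch of disabling neighbor sampling (not the memory-efficient
# layer-wise method from the user guide), you could swap in a full-neighbor
# sampler so that every MFG contains all neighbors instead of a sampled
# subset:
#

full_neighbor_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
test_dataloader = dgl.dataloading.DataLoader(
    graph,
    test_nids,
    full_neighbor_sampler,
    device=device,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)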


# Thumbnail credits: Stanford CS224W Notes
# sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'