"""
Training GNN with Neighbor Sampling for Node Classification
===========================================================

This tutorial shows how to train a multi-layer GraphSAGE for node
classification on ``ogbn-arxiv`` provided by `Open Graph
Benchmark (OGB) <https://ogb.stanford.edu/>`__. The dataset contains around
170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for node classification on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>`.

"""


######################################################################
# Loading Dataset
# ---------------
#
# ``ogbn-arxiv`` is already prepared as ``BuiltinDataset`` in GraphBolt.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.graphbolt as gb
import numpy as np
import torch

dataset = gb.BuiltinDataset("ogbn-arxiv").load()
device = "cpu"  # change to 'cuda' for GPU


######################################################################
# The dataset consists of a graph, features, and tasks. You can get the
# training-validation-test sets from the tasks. Seed nodes and their
# corresponding labels are already stored in each training-validation-test
# set, and other metadata such as the number of classes is also stored in
# the tasks. In this dataset, there is only one task: ``node classification``.
#

graph = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
task_name = dataset.tasks[0].metadata["name"]
num_classes = dataset.tasks[0].metadata["num_classes"]
print(f"Task: {task_name}. Number of classes: {num_classes}")


######################################################################
# How DGL Handles Computation Dependency
# --------------------------------------
#
# In the :doc:`previous tutorial <L0_neighbor_sampling_overview>`, you
# have seen that the computation dependency for message passing of a
# single node can be described as a series of *message flow graphs* (MFG).
#
# |image1|
#
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#
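#
# For intuition only, here is a minimal sketch (not part of this tutorial's
# GraphBolt pipeline) of how one such MFG could be built by hand for a
# single layer, assuming an ordinary ``DGLGraph`` ``g`` and a tensor of
# seed node IDs ``seeds``:
#
# .. code:: python
#
#    # Sample up to 4 incoming neighbors for each seed node.
#    frontier = dgl.sampling.sample_neighbors(g, seeds, 4)
#    # Compact the frontier into a bipartite MFG whose destination nodes are
#    # the seeds and whose source nodes are the seeds plus sampled neighbors.
#    mfg = dgl.to_block(frontier, seeds)
#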


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# DGL provides tools to iterate over the dataset in minibatches
# while generating the computation dependencies to compute their outputs
# with the MFGs above. For node classification, you can use
# ``dgl.graphbolt.MultiProcessDataLoader`` for iterating over the dataset.
# It accepts a data pipe that generates minibatches of nodes and their
# labels, samples neighbors for each node, and generates the computation
# dependencies in the form of MFGs. Feature fetching, block creation and
# copying to the target device are also supported. All these operations are
# split into separate stages in the data pipe, so that you can customize
# the data pipeline by inserting your own operations.
#
# .. note::
#
#    To write your own neighbor sampler, please refer to :ref:`this user
#    guide section <guide-minibatch-customizing-neighborhood-sampler>`.
#
#
# Let’s say that each node will gather messages from 4 neighbors on each
# layer. The code defining the data loader and neighbor sampler will look
# like the following.
#

datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(graph, [4, 4])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
train_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
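

######################################################################
# Each stage above is just another data pipe, so you can splice in your own
# operation. The following is a minimal sketch, assuming the data pipe
# exposes a ``transform`` stage (as used elsewhere in GraphBolt examples);
# the ``cast_features`` helper is hypothetical.
#
# .. code:: python
#
#    def cast_features(minibatch):
#        # Hypothetical custom stage: cast node features to float32 before
#        # they reach the model.
#        minibatch.node_features["feat"] = minibatch.node_features["feat"].float()
#        return minibatch
#
#    # Insert after ``fetch_feature`` and before building the data loader.
#    datapipe = datapipe.transform(cast_features)
#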


######################################################################
# .. note::
#
#    In this example, neighborhood sampling runs on the CPU. If you are
#    interested in running it on GPU, please refer to
#    :ref:`guide-minibatch-gpu-sampling`.
#


######################################################################
# You can iterate over the data loader; each iteration yields a
# ``DGLMiniBatch`` object.
#

data = next(iter(train_dataloader))
print(data)


######################################################################
# You can get the input node IDs from MFGs.
#

mfgs = data.blocks
input_nodes = mfgs[0].srcdata[dgl.NID]
print(f"Input nodes: {input_nodes}.")

######################################################################
# Defining Model
# --------------
#
# Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The
# model can be written as follows:
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # Lines that are changed are marked with an arrow: "<---"

        h_dst = x[: mfgs[0].num_dst_nodes()]  # <---
        h = self.conv1(mfgs[0], (x, h_dst))  # <---
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]  # <---
        h = self.conv2(mfgs[1], (h, h_dst))  # <---
        return h


in_size = feature.size("node", None, "feat")[0]
model = Model(in_size, 64, num_classes).to(device)


######################################################################
# If you compare against the code in the
# :doc:`introduction <../blitz/1_introduction>`, you will notice several
# differences:
#
# -  **DGL GNN layers on MFGs**. Instead of computing on the
#    full graph:
#
#    .. code:: python
#
#       h = self.conv1(g, x)
#
#    you only compute on the sampled MFG:
#
#    .. code:: python
#
#       h = self.conv1(mfgs[0], (x, h_dst))
#
#    All DGL’s GNN modules support message passing on MFGs,
#    where you supply a pair of features, one for source nodes and another
#    for destination nodes.
#
# -  **Feature slicing for self-dependency**. There are statements that
#    perform slicing to obtain the previous-layer representation of the
#    nodes:
#
#    .. code:: python
#
#       h_dst = x[:mfgs[0].num_dst_nodes()]
#
#    The ``num_dst_nodes`` method works on MFGs and returns the number of
#    destination nodes.
#
#    Since the first few source nodes of the yielded MFG are
#    always the same as the destination nodes, these statements obtain the
#    representations of the destination nodes on the previous layer. They are
#    then combined with neighbor aggregation in the ``dgl.nn.SAGEConv`` layer.
#
# .. note::
#
#    See the :doc:`custom message passing
#    tutorial <L4_message_passing>` for more details on how to
#    manipulate MFGs produced in this way, such as the usage
#    of ``num_dst_nodes``.
#
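#
# As a sanity check (a sketch reusing the ``mfgs`` obtained earlier), the
# leading source nodes of the first MFG should coincide with the source
# nodes of the second MFG, i.e. with the destination nodes of the first:
#
# .. code:: python
#
#    assert torch.equal(
#        mfgs[0].srcdata[dgl.NID][: mfgs[0].num_dst_nodes()],
#        mfgs[1].srcdata[dgl.NID],
#    )
#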


######################################################################
# Defining Training Loop
# ----------------------
#
# The following defines the optimizer over the model parameters.
#

opt = torch.optim.Adam(model.parameters())


######################################################################
# When computing the validation score for model selection, usually you can
# also do neighbor sampling. To do that, you need to define another data
# loader.
#

datapipe = gb.ItemSampler(valid_set, batch_size=1024, shuffle=False)
datapipe = datapipe.sample_neighbor(graph, [4, 4])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
valid_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)


import sklearn.metrics

######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
for epoch in range(10):
    model.train()

    with tqdm.tqdm(train_dataloader) as tq:
        for step, data in enumerate(tq):
            x = data.node_features["feat"]
            labels = data.labels

            predictions = model(data.blocks, x)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(
                labels.cpu().numpy(),
                predictions.argmax(1).detach().cpu().numpy(),
            )

            tq.set_postfix(
                {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                refresh=False,
            )

    model.eval()

    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for data in tq:
            x = data.node_features["feat"]
            labels.append(data.labels.cpu().numpy())
            predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

        # Note that this tutorial does not train the model to convergence;
        # it stops after the first epoch for brevity.
        break


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# with neighbor sampling.
#
# What’s next?
# ------------
#
# -  :doc:`Stochastic training of GNN for link
#    prediction <L2_large_link_prediction>`.
# -  :doc:`Adapting your custom GNN module for stochastic
#    training <L4_message_passing>`.
# -  During inference you may wish to disable neighbor sampling. If so,
#    please refer to the :ref:`user guide on exact offline
#    inference <guide-minibatch-inference>`. A rough sketch follows below.
#
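#
# For the last point, here is a minimal sketch of an "all-neighbor" data
# pipe for the test set, assuming GraphBolt follows the usual DGL
# convention that a fanout of ``-1`` keeps every neighbor:
#
# .. code:: python
#
#    datapipe = gb.ItemSampler(test_set, batch_size=1024, shuffle=False)
#    datapipe = datapipe.sample_neighbor(graph, [-1, -1])  # keep all neighbors
#    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
#    datapipe = datapipe.to_dgl()
#    datapipe = datapipe.copy_to(device)
#    test_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
#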