"""
Stochastic Training of GNN for Link Prediction
==============================================

This tutorial will show how to train a multi-layer GraphSAGE for link
prediction on ``ogbn-arxiv`` provided by `Open Graph Benchmark
(OGB) <https://ogb.stanford.edu/>`__. The dataset
contains around 170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for link prediction on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>` and :doc:`Neighbor
Sampling for Node Classification <L1_large_node_classification>`.

"""


######################################################################
# Link Prediction Overview
# ------------------------
#
# Link prediction requires the model to predict the probability that an
# edge exists between two given nodes. This tutorial does so by computing a dot
# product between the representations of both incident nodes.
#
# .. math::
#
#
#    \hat{y}_{u\sim v} = \sigma(h_u^T h_v)
#
# It then minimizes the following binary cross entropy loss.
#
# .. math::
#
#
#    \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
#
# This is identical to the link prediction formulation in :doc:`the previous
# tutorial on link prediction <../blitz/4_link_predict>`.
#
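
######################################################################
# As a quick standalone illustration of this objective (a toy sketch with
# random tensors, not part of the tutorial pipeline), the per-pair score
# and the loss can be computed as follows:
#

import torch
import torch.nn.functional as F

# Hypothetical representations for 4 node pairs.
h_u = torch.randn(4, 16)
h_v = torch.randn(4, 16)
y = torch.tensor([1.0, 1.0, 0.0, 0.0])  # 1 = edge exists, 0 = no edge

logits = (h_u * h_v).sum(dim=1)  # equivalent to h_u^T h_v for each pair
# binary_cross_entropy_with_logits applies the sigmoid internally.
loss = F.binary_cross_entropy_with_logits(logits, y)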


######################################################################
# Loading Dataset
# ---------------
#
# This tutorial loads the dataset from the ``ogb`` package as in the
# :doc:`previous tutorial <L1_large_node_classification>`.
#

import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
print(graph)
print(node_labels)

node_features = graph.ndata["feat"]
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print("Number of classes:", num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# Different from the :doc:`link prediction tutorial for full
# graph <../blitz/4_link_predict>`, a common practice for training GNNs on large
# graphs is to iterate over the edges
# in minibatches, since computing the scores of all edges is usually
# infeasible. For each minibatch of edges, you compute the output
# representations of their incident nodes using neighbor sampling and the GNN,
# in a fashion similar to that introduced in the :doc:`large-scale node classification
# tutorial <L1_large_node_classification>`.
#
# DGL provides ``dgl.dataloading.as_edge_prediction_sampler`` to
# iterate over edges for edge classification or link prediction tasks.
#
# To perform link prediction, you need to specify a negative sampler. DGL
# provides builtin negative samplers such as
# ``dgl.dataloading.negative_sampler.Uniform``.  Here this tutorial uniformly
# draws 5 negative examples per positive example.
#

negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)


######################################################################
# After defining the negative sampler, one can then define the edge data
# loader with neighbor sampling.  To create a ``DataLoader`` for
# link prediction, provide a neighbor sampler object as well as the negative
# sampler object created above.
#

sampler = dgl.dataloading.NeighborSampler([4, 4])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler
)
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DataLoader.
    graph,  # The graph
    torch.arange(graph.number_of_edges()),  # The edges to iterate over
    sampler,  # The neighbor sampler
    device=device,  # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,  # Batch size
    shuffle=True,  # Whether to shuffle the edges for every epoch
    drop_last=False,  # Whether to drop the last incomplete batch
    num_workers=0,  # Number of sampler processes
)


######################################################################
# You can peek at one minibatch from ``train_dataloader`` and see what it
# gives you.
#

input_nodes, pos_graph, neg_graph, mfgs = next(iter(train_dataloader))
print("Number of input nodes:", len(input_nodes))
print(
    "Positive graph # nodes:",
    pos_graph.number_of_nodes(),
    "# edges:",
    pos_graph.number_of_edges(),
)
print(
    "Negative graph # nodes:",
    neg_graph.number_of_nodes(),
    "# edges:",
    neg_graph.number_of_edges(),
)
print(mfgs)


######################################################################
# The example minibatch consists of four elements.
#
# The first element is an ID tensor for the input nodes, i.e., nodes
# whose input features are needed on the first GNN layer for this minibatch.
#
# The second element and the third element are the positive graph and the
# negative graph for this minibatch.
# The concept of positive and negative graphs has been introduced in the
# :doc:`full-graph link prediction tutorial <../blitz/4_link_predict>`.  In minibatch
# training, the positive graph and the negative graph only contain nodes
# necessary for computing the pair-wise scores of positive and negative examples
# in the current minibatch.
#
# The last element is a list of :doc:`MFGs <L0_neighbor_sampling_overview>`
# storing the computation dependencies for each GNN layer.
# The MFGs are used to compute the GNN outputs of the nodes
# involved in the positive/negative graphs.
#
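
######################################################################
# For instance, you can recover the original IDs of the node pairs being
# scored in this minibatch.  This is a quick inspection sketch; it assumes
# that the compacted positive graph stores the original node IDs in
# ``ndata[dgl.NID]``, as the pair graphs returned by the edge prediction
# sampler do.
#

src, dst = pos_graph.edges()
orig_src = pos_graph.ndata[dgl.NID][src]  # map local IDs back to graph IDs
orig_dst = pos_graph.ndata[dgl.NID][dst]
print("First positive pairs:", list(zip(orig_src[:3].tolist(), orig_dst[:3].tolist())))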


######################################################################
# Defining Model for Node Representation
# --------------------------------------
#
# The model is almost identical to the one in the :doc:`node classification
# tutorial <L1_large_node_classification>`. The only difference is
# that since you are doing link prediction, the output dimension will not
# be the number of classes in the dataset.
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h

model = Model(num_features, 128).to(device)


######################################################################
# Defining the Score Predictor for Edges
# --------------------------------------
#
# After getting the node representations necessary for the minibatch, the
# last thing to do is to predict the scores of the edges and non-existent
# edges in the sampled minibatch.
#
# The following score predictor, copied from the :doc:`link prediction
# tutorial <../blitz/4_link_predict>`, takes a dot product between the
# incident nodes’ representations.
#

import dgl.function as fn

class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]


######################################################################
# Evaluating Performance with Unsupervised Learning (Optional)
# ------------------------------------------------------------
#
# There are various ways to evaluate the performance of link prediction.
# This tutorial follows the practice of the `GraphSAGE
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__.
# Basically, it first trains a GNN via link prediction and obtains an embedding
# for each node.  Then it trains a downstream classifier on top of the
# embeddings and computes the accuracy as an assessment of the embedding
# quality.
#


######################################################################
# To obtain the representations of all the nodes, this tutorial uses
# neighbor sampling as introduced in the :doc:`node classification
# tutorial <L1_large_node_classification>`.
#
# .. note::
#
#    If you would like to obtain node representations without
#    neighbor sampling during inference, please refer to this :ref:`user
#    guide <guide-minibatch-inference>`.
#
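
######################################################################
# As a rough sketch of that alternative (a hedged example; it assumes your
# DGL version provides ``dgl.dataloading.MultiLayerFullNeighborSampler``),
# full-neighbor inference would simply swap the sampler used below,
# gathering every neighbor at each layer at the cost of more memory:
#

# Each of the two layers gathers all neighbors instead of sampling 4.
full_neighbor_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)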

def inference(model, graph, node_features):
    with torch.no_grad():
        nodes = torch.arange(graph.number_of_nodes())

        sampler = dgl.dataloading.NeighborSampler([4, 4])
        inference_dataloader = dgl.dataloading.DataLoader(
            graph,
            nodes,
            sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            device=device,
        )

        result = []
        for input_nodes, output_nodes, mfgs in inference_dataloader:
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]
            result.append(model(mfgs, inputs))

        return torch.cat(result)

import sklearn.metrics

def evaluate(emb, label, train_nids, valid_nids, test_nids):
    classifier = nn.Linear(emb.shape[1], num_classes).to(device)
    opt = torch.optim.LBFGS(classifier.parameters())

    def compute_loss():
        pred = classifier(emb[train_nids].to(device))
        loss = F.cross_entropy(pred, label[train_nids].to(device))
        return loss

    def closure():
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss

    prev_loss = float("inf")
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print("Converges at iteration", i)
                break
            else:
                prev_loss = loss

    with torch.no_grad():
        pred = classifier(emb.to(device)).cpu()
        valid_acc = sklearn.metrics.accuracy_score(
            label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1)
        )
        test_acc = sklearn.metrics.accuracy_score(
            label[test_nids].numpy(), pred[test_nids].numpy().argmax(1)
        )
    return valid_acc, test_acc


######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#

model = Model(node_features.shape[1], 128).to(device)
predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))


######################################################################
# The following is the training loop for link prediction and
# evaluation. It also saves the model that performs best on the
# validation set:
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]

            outputs = model(mfgs, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            score = torch.cat([pos_score, neg_score])
            label = torch.cat(
                [torch.ones_like(pos_score), torch.zeros_like(neg_score)]
            )
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({"loss": "%.03f" % loss.item()}, refresh=False)

            if (step + 1) % 500 == 0:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(
                    emb, node_labels, train_nids, valid_nids, test_nids
                )
                print(
                    "Epoch {} Validation Accuracy {} Test Accuracy {}".format(
                        epoch, valid_acc, test_acc
                    )
                )
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

                # Note that this tutorial does not train the whole model to the end.
                break


######################################################################
# Evaluating Performance with Link Prediction (Optional)
# ------------------------------------------------------
#
# In practice, it is more common to evaluate the link prediction
# model to see whether it can predict new edges. There are different
# evaluation metrics such as
# `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
# or `various metrics from information retrieval <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)>`__.
# Ultimately, they require the model to predict one scalar score given
# a node pair among a set of node pairs.
#
# Assume that you have the following test set with labels, where
# ``test_pos_src`` and ``test_pos_dst`` are ground truth node pairs
# with edges in between (or *positive* pairs), and ``test_neg_src``
# and ``test_neg_dst`` are ground truth node pairs without edges
# in between (or *negative* pairs).
#

# Positive pairs
# These are randomly generated as an example.  You will need to
# replace them with your own ground truth.
n_test_pos = 1000
test_pos_src, test_pos_dst = (
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
)
# Negative pairs.  Likewise, you will need to replace them with your
# own ground truth.
test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,))


######################################################################
# First you need to compute the node representations for all the nodes
# with the ``inference`` function defined above:
#

node_reprs = inference(model, graph, node_features)

######################################################################
# Since the predictor is a dot product, you can now easily compute the
# scores of the positive and negative test pairs and obtain metrics such
# as AUC:
#

h_pos_src = node_reprs[test_pos_src]
h_pos_dst = node_reprs[test_pos_dst]
h_neg_src = node_reprs[test_neg_src]
h_neg_dst = node_reprs[test_neg_dst]
score_pos = (h_pos_src * h_pos_dst).sum(1)
score_neg = (h_neg_src * h_neg_dst).sum(1)
test_preds = torch.cat([score_pos, score_neg]).cpu().numpy()
test_labels = (
    torch.cat([torch.ones_like(score_pos), torch.zeros_like(score_neg)])
    .cpu()
    .numpy()
)

auc = sklearn.metrics.roc_auc_score(test_labels, test_preds)
print("Link Prediction AUC:", auc)


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# for link prediction with neighbor sampling.
#


# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'