"""
Stochastic Training of GNN for Link Prediction
==============================================

This tutorial will show how to train a multi-layer GraphSAGE for link
prediction on ``ogbn-arxiv`` provided by `Open Graph Benchmark
(OGB) <https://ogb.stanford.edu/>`__. The dataset
contains around 170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for link prediction on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>` and :doc:`Neighbor
Sampling for Node Classification <L1_large_node_classification>`.

"""


######################################################################
# Link Prediction Overview
# ------------------------
#
# Link prediction requires the model to predict the probability of
# existence of an edge. This tutorial does so by computing a dot product
# between the representations of both incident nodes.
#
# .. math::
#
#
#    \hat{y}_{u\sim v} = \sigma(h_u^T h_v)
#
# It then minimizes the following binary cross entropy loss.
#
# .. math::
#
#
#    \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
#
# This is identical to the link prediction formulation in :doc:`the previous
# tutorial on link prediction <../blitz/4_link_predict>`.
#
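

######################################################################
# As a minimal sketch of this scoring function (using toy random
# embeddings rather than GNN outputs; all sizes here are arbitrary
# illustration values), the following computes the dot product score
# and the binary cross entropy loss for a small batch of node pairs.
#

import torch
import torch.nn.functional as F

toy_h = torch.randn(10, 16)  # toy embeddings for 10 nodes
toy_src = torch.tensor([0, 1, 2])  # source node of each pair
toy_dst = torch.tensor([3, 4, 5])  # destination node of each pair
toy_label = torch.tensor([1.0, 1.0, 0.0])  # 1 = edge exists, 0 = no edge
toy_score = (toy_h[toy_src] * toy_h[toy_dst]).sum(1)  # h_u^T h_v per pair
toy_loss = F.binary_cross_entropy_with_logits(toy_score, toy_label)
print("Toy BCE loss:", toy_loss.item())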


######################################################################
# Loading Dataset
# ---------------
#
# This tutorial loads the dataset from the ``ogb`` package as in the
# :doc:`previous tutorial <L1_large_node_classification>`.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
print(graph)
print(node_labels)

node_features = graph.ndata["feat"]
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print("Number of classes:", num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]
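

######################################################################
# As a quick sanity check (a sketch), you can print the sizes of the
# three node ID splits returned by OGB:
#

print(
    "Train/valid/test sizes:",
    len(train_nids),
    len(valid_nids),
    len(test_nids),
)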


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# Different from the :doc:`link prediction tutorial for full
# graph <../blitz/4_link_predict>`, a common practice for training GNNs
# on large graphs is to iterate over the edges in minibatches, since
# computing the probabilities of all edges is usually infeasible. For
# each minibatch of edges, you compute the output
# representation of their incident nodes using neighbor sampling and GNN,
# in a similar fashion introduced in the :doc:`large-scale node classification
# tutorial <L1_large_node_classification>`.
#
# DGL provides ``dgl.dataloading.as_edge_prediction_sampler`` to
# iterate over edges for edge classification or link prediction tasks.
#
# To perform link prediction, you need to specify a negative sampler. DGL
# provides builtin negative samplers such as
# ``dgl.dataloading.negative_sampler.Uniform``.  Here this tutorial uniformly
# draws 5 negative examples per positive example.
#

negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)
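

######################################################################
# To see what the negative sampler produces, you can call it directly
# on a small toy graph with a few edge IDs (a sketch; the toy graph
# here is made up for illustration).  For each given edge, it keeps
# the source node and draws 5 destination nodes uniformly at random.
#

toy_g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
toy_neg_src, toy_neg_dst = negative_sampler(toy_g, torch.tensor([0, 1]))
print("Negative sources:", toy_neg_src)
print("Negative destinations:", toy_neg_dst)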


######################################################################
# After defining the negative sampler, one can then define the edge data
# loader with neighbor sampling.  To create a ``DataLoader`` for
# link prediction, provide a neighbor sampler object as well as the negative
# sampler object created above.
#

sampler = dgl.dataloading.NeighborSampler([4, 4])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler
)

train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DataLoader.
    graph,  # The graph
    torch.arange(graph.number_of_edges()),  # The edges to iterate over
    sampler,  # The neighbor sampler
    device=device,  # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,  # Batch size
    shuffle=True,  # Whether to shuffle the edges for every epoch
    drop_last=False,  # Whether to drop the last incomplete batch
    num_workers=0,  # Number of sampler processes
)


######################################################################
# You can peek one minibatch from ``train_dataloader`` and see what it
# will give you.
#

input_nodes, pos_graph, neg_graph, mfgs = next(iter(train_dataloader))
print("Number of input nodes:", len(input_nodes))
print(
    "Positive graph # nodes:",
    pos_graph.number_of_nodes(),
    "# edges:",
    pos_graph.number_of_edges(),
)
print(
    "Negative graph # nodes:",
    neg_graph.number_of_nodes(),
    "# edges:",
    neg_graph.number_of_edges(),
)
print(mfgs)


######################################################################
# The example minibatch consists of four elements.
#
# The first element is an ID tensor for the input nodes, i.e., nodes
# whose input features are needed on the first GNN layer for this minibatch.
#
# The second element and the third element are the positive graph and the
# negative graph for this minibatch.
# The concept of positive and negative graphs has been introduced in the
# :doc:`full-graph link prediction tutorial <../blitz/4_link_predict>`.  In minibatch
# training, the positive graph and the negative graph only contain nodes
# necessary for computing the pair-wise scores of positive and negative examples
# in the current minibatch.
#
# The last element is a list of :doc:`MFGs <L0_neighbor_sampling_overview>`
# storing the computation dependencies for each GNN layer.
# The MFGs are used to compute the GNN outputs of the nodes
# involved in the positive and negative graphs.
#
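

######################################################################
# Note that the nodes in the positive and negative graphs are relabeled
# into a compact range.  As a sketch, you can recover their IDs in the
# original graph via ``ndata[dgl.NID]`` (assuming the compacted pair
# graphs store the original node IDs there, as DGL's graph compaction
# does):
#

print(pos_graph.ndata[dgl.NID])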


######################################################################
# Defining Model for Node Representation
# --------------------------------------
#
# The model is almost identical to the one in the :doc:`node classification
# tutorial <L1_large_node_classification>`. The only difference is
# that since you are doing link prediction, the output dimension will not
# be the number of classes in the dataset.
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # The destination nodes of an MFG come first among its source
        # nodes, so slicing the first ``num_dst_nodes`` rows gives the
        # destination node features.
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h

model = Model(num_features, 128).to(device)


######################################################################
# Defining the Score Predictor for Edges
# --------------------------------------
#
# After getting the node representation necessary for the minibatch, the
# last thing to do is to predict the score of the edges and non-existent
# edges in the sampled minibatch.
#
# The following score predictor, copied from the :doc:`link prediction
# tutorial <../blitz/4_link_predict>`, takes a dot product between the
# incident nodes’ representations.
#

import dgl.function as fn


class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product
            # between the source node feature 'h' and destination node
            # feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you
            # need to squeeze it.
            return g.edata["score"][:, 0]
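

######################################################################
# As a quick sanity check, you can run the predictor on a small graph
# with random node representations (a sketch; the graph and feature
# size are made up for illustration):
#

toy_pred_g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
toy_scores = DotPredictor()(toy_pred_g, torch.randn(toy_pred_g.num_nodes(), 8))
print("Toy edge scores:", toy_scores)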


######################################################################
# Evaluating Performance with Unsupervised Learning (Optional)
# ------------------------------------------------------------
#
# There are various ways to evaluate the performance of link prediction.
# This tutorial follows the practice of the `GraphSAGE
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__.
# Basically, it first trains a GNN via link prediction to obtain an
# embedding for each node.  Then it trains a downstream classifier on
# top of the embeddings and computes the accuracy as an assessment of
# the embedding quality.
#


######################################################################
# To obtain the representations of all the nodes, this tutorial uses
# neighbor sampling as introduced in the :doc:`node classification
# tutorial <L1_large_node_classification>`.
#
# .. note::
#
#    If you would like to obtain node representations without
#    neighbor sampling during inference, please refer to this :ref:`user
#    guide <guide-minibatch-inference>`.
#


def inference(model, graph, node_features):
    with torch.no_grad():
        # This dataloader iterates over all nodes rather than the
        # training edges, so it is named accordingly.
        sampler = dgl.dataloading.NeighborSampler([4, 4])
        eval_dataloader = dgl.dataloading.DataLoader(
            graph,
            torch.arange(graph.number_of_nodes()),
            sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            device=device,
        )

        result = []
        for input_nodes, output_nodes, mfgs in eval_dataloader:
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]
            result.append(model(mfgs, inputs))

        return torch.cat(result)

import sklearn.metrics

def evaluate(emb, label, train_nids, valid_nids, test_nids):
    classifier = nn.Linear(emb.shape[1], num_classes).to(device)
    opt = torch.optim.LBFGS(classifier.parameters())

    def compute_loss():
        pred = classifier(emb[train_nids].to(device))
        loss = F.cross_entropy(pred, label[train_nids].to(device))
        return loss

    def closure():
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss

    prev_loss = float("inf")
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print("Converges at iteration", i)
                break
            else:
                prev_loss = loss

    with torch.no_grad():
        pred = classifier(emb.to(device)).cpu()
        valid_acc = sklearn.metrics.accuracy_score(
            label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1)
        )
        test_acc = sklearn.metrics.accuracy_score(
            label[test_nids].numpy(), pred[test_nids].numpy().argmax(1)
        )
    return valid_acc, test_acc


######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#

model = Model(node_features.shape[1], 128).to(device)
predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))


######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the
# validation set:
#

import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = "model.pt"
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]

            outputs = model(mfgs, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            score = torch.cat([pos_score, neg_score])
            label = torch.cat(
                [torch.ones_like(pos_score), torch.zeros_like(neg_score)]
            )
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({"loss": "%.03f" % loss.item()}, refresh=False)

            if (step + 1) % 500 == 0:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(
                    emb, node_labels, train_nids, valid_nids, test_nids
                )
                print(
                    "Epoch {} Validation Accuracy {} Test Accuracy {}".format(
                        epoch, valid_acc, test_acc
                    )
                )
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

                # Note that this tutorial does not train the whole model to the end.
                break
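

######################################################################
# After training, you can load the checkpoint that scored best on the
# validation set back into the model (a sketch, assuming the
# checkpoint saved above exists):
#

model.load_state_dict(torch.load(best_model_path))
model.eval()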


######################################################################
# Evaluating Performance with Link Prediction (Optional)
# ------------------------------------------------------
#
# In practice, it is more common to evaluate the link prediction
# model to see whether it can predict new edges. There are different
# evaluation metrics such as
# `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
# or `various metrics from information retrieval <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)>`__.
# Ultimately, they require the model to predict one scalar score given
# a node pair among a set of node pairs.
#
# Assume that you have the following test set with labels, where
# ``test_pos_src`` and ``test_pos_dst`` are ground truth node pairs
# with edges in between (or *positive* pairs), and ``test_neg_src``
# and ``test_neg_dst`` are ground truth node pairs without edges
# in between (or *negative* pairs).
#

# Positive pairs
# These are randomly generated as an example.  You will need to
# replace them with your own ground truth.
n_test_pos = 1000
test_pos_src, test_pos_dst = (
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
)

# Negative pairs.  Likewise, you will need to replace them with your
# own ground truth.
test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,))


######################################################################
# First you need to compute the node representations for all the nodes
# with the ``inference`` method above:
#

node_reprs = inference(model, graph, node_features)

######################################################################
# Since the predictor is a dot product, you can now easily compute the
# score of positive and negative test pairs to compute metrics such
# as AUC:
#

h_pos_src = node_reprs[test_pos_src]
h_pos_dst = node_reprs[test_pos_dst]
h_neg_src = node_reprs[test_neg_src]
h_neg_dst = node_reprs[test_neg_dst]
score_pos = (h_pos_src * h_pos_dst).sum(1)
score_neg = (h_neg_src * h_neg_dst).sum(1)
test_preds = torch.cat([score_pos, score_neg]).cpu().numpy()
test_labels = (
    torch.cat([torch.ones_like(score_pos), torch.zeros_like(score_neg)])
    .cpu()
    .numpy()
)

auc = sklearn.metrics.roc_auc_score(test_labels, test_preds)
print("Link Prediction AUC:", auc)


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# for link prediction with neighbor sampling.
#


# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'