"""
Stochastic Training of GNN for Link Prediction
==============================================

This tutorial will show how to train a multi-layer GraphSAGE for link
prediction on ``ogbn-arxiv`` provided by `Open Graph Benchmark
(OGB) <https://ogb.stanford.edu/>`__. The dataset
contains around 170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for link prediction on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>` and :doc:`Neighbor
Sampling for Node Classification <L1_large_node_classification>`.

"""


######################################################################
# Link Prediction Overview
# ------------------------
#
# Link prediction requires the model to predict the probability that an
# edge exists. This tutorial does so by computing a dot product
# between the representations of both incident nodes.
#
# .. math::
#
#
#    \hat{y}_{u\sim v} = \sigma(h_u^T h_v)
#
# It then minimizes the following binary cross entropy loss.
#
# .. math::
#
#
#    \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v})\right)
#
# This is identical to the link prediction formulation in :doc:`the previous
# tutorial on link prediction <../blitz/4_link_predict>`.
#
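
######################################################################
# As a quick, self-contained illustration of the two formulas above,
# here is a minimal sketch that scores a few toy node-pair embeddings
# with a dot product and computes the binary cross entropy loss on
# them.  The tensors are random placeholders, not part of the
# tutorial's pipeline.
#

import torch
import torch.nn.functional as F

h_u = torch.randn(4, 16)  # toy representations of 4 source nodes
h_v = torch.randn(4, 16)  # toy representations of 4 destination nodes
y = torch.tensor([1.0, 1.0, 0.0, 0.0])  # toy edge-existence labels

logits = (h_u * h_v).sum(dim=1)  # h_u^T h_v for each pair
loss = F.binary_cross_entropy_with_logits(logits, y)  # applies the sigmoid internally
print("Toy link prediction loss:", loss.item())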


######################################################################
# Loading Dataset
# ---------------
#
# This tutorial loads the dataset from the ``ogb`` package as in the
# :doc:`previous tutorial <L1_large_node_classification>`.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
print(graph)
print(node_labels)

node_features = graph.ndata["feat"]
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print("Number of classes:", num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# Different from the :doc:`link prediction tutorial for full
# graph <../blitz/4_link_predict>`, a common practice to train GNNs on large graphs is
# to iterate over the edges
# in minibatches, since computing the probabilities of all edges at once
# is usually infeasible. For each minibatch of edges, you compute the output
# representation of their incident nodes using neighbor sampling and GNN,
# in a similar fashion introduced in the :doc:`large-scale node classification
# tutorial <L1_large_node_classification>`.
#
# DGL provides ``dgl.dataloading.as_edge_prediction_sampler`` to
# iterate over edges for edge classification or link prediction tasks.
#
# To perform link prediction, you need to specify a negative sampler. DGL
# provides builtin negative samplers such as
# ``dgl.dataloading.negative_sampler.Uniform``.  Here this tutorial uniformly
# draws 5 negative examples per positive example.
#

negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)
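
######################################################################
# Conceptually, for every positive edge ``(u, v)`` in a minibatch,
# ``Uniform(5)`` keeps the source node ``u`` and draws 5 destination
# nodes uniformly at random to form corrupted (negative) pairs.  The
# following minimal sketch mimics that behavior by hand; ``src_example``
# is a hypothetical batch of source nodes, not part of the tutorial's
# pipeline.
#

k = 5
src_example = torch.tensor([0, 1])  # two positive source nodes
neg_src = src_example.repeat_interleave(k)  # each source repeated k times
neg_dst = torch.randint(0, graph.num_nodes(), (len(neg_src),))  # uniform destinations
print("Negative sources:", neg_src)
print("Negative destinations:", neg_dst)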


######################################################################
# After defining the negative sampler, one can then define the edge data
# loader with neighbor sampling.  To create a ``DataLoader`` for
# link prediction, provide a neighbor sampler object as well as the negative
# sampler object created above.
#

sampler = dgl.dataloading.NeighborSampler([4, 4])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler
)
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DataLoader.
    graph,  # The graph
    torch.arange(graph.num_edges()),  # The edges to iterate over
    sampler,  # The neighbor sampler
    device=device,  # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,  # Batch size
    shuffle=True,  # Whether to shuffle the nodes for every epoch
    drop_last=False,  # Whether to drop the last incomplete batch
    num_workers=0,  # Number of sampler processes
)


######################################################################
# You can peek one minibatch from ``train_dataloader`` and see what it
# will give you.
#

input_nodes, pos_graph, neg_graph, mfgs = next(iter(train_dataloader))
print("Number of input nodes:", len(input_nodes))
print(
    "Positive graph # nodes:",
    pos_graph.num_nodes(),
    "# edges:",
    pos_graph.num_edges(),
)
print(
    "Negative graph # nodes:",
    neg_graph.num_nodes(),
    "# edges:",
    neg_graph.num_edges(),
)
print(mfgs)


######################################################################
# The example minibatch consists of four elements.
#
# The first element is an ID tensor for the input nodes, i.e., nodes
# whose input features are needed on the first GNN layer for this minibatch.
#
# The second element and the third element are the positive graph and the
# negative graph for this minibatch.
# The concept of positive and negative graphs has been introduced in the
# :doc:`full-graph link prediction tutorial <../blitz/4_link_predict>`.  In minibatch
# training, the positive graph and the negative graph only contain nodes
# necessary for computing the pair-wise scores of positive and negative examples
# in the current minibatch.
#
# The last element is a list of :doc:`MFGs <L0_neighbor_sampling_overview>`
# storing the computation dependencies for each GNN layer.
# The MFGs are used to compute the GNN outputs of the nodes
# involved in the positive and negative graphs.
#
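
######################################################################
# As a sanity check on the numbers printed above: the positive graph has
# exactly one edge per minibatch example, so its edge count equals the
# batch size, and the negative graph has 5 times as many edges because
# 5 negative examples were drawn per positive example.
#

print(pos_graph.num_edges())  # equals the batch size, 1024
print(neg_graph.num_edges())  # 5x the batch size, one group of 5 negatives per positive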


######################################################################
# Defining Model for Node Representation
# --------------------------------------
#
# The model is almost identical to the one in the :doc:`node classification
# tutorial <L1_large_node_classification>`. The only difference is
# that since you are doing link prediction, the output dimension will not
# be the number of classes in the dataset.
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h


model = Model(num_features, 128).to(device)
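
######################################################################
# As a quick shape check, you can run the (still untrained) model on the
# minibatch peeked above: it should produce one 128-dimensional
# representation per destination node of the last MFG.  ``example_out``
# is a throwaway name used only for this sketch.
#

with torch.no_grad():
    example_out = model(mfgs, mfgs[0].srcdata["feat"])
print(example_out.shape)  # (number of output nodes, 128)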


######################################################################
# Defining the Score Predictor for Edges
# --------------------------------------
#
# After getting the node representations necessary for the minibatch, the
# last thing to do is to predict the scores of the edges and non-existent
# edges in the sampled minibatch.
#
# The following score predictor, copied from the :doc:`link prediction
# tutorial <../blitz/4_link_predict>`, takes a dot product between the
# incident nodes’ representations.
#

import dgl.function as fn


class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]
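
######################################################################
# Putting the model and the predictor together on the minibatch peeked
# earlier: the predictor turns node representations into one scalar
# score per edge of the positive and negative graphs.  This is only a
# sketch with the untrained model, so the scores themselves are
# meaningless at this point.
#

example_predictor = DotPredictor()
with torch.no_grad():
    h = model(mfgs, mfgs[0].srcdata["feat"])
    print(example_predictor(pos_graph, h).shape)  # one score per positive edge
    print(example_predictor(neg_graph, h).shape)  # one score per negative edge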


######################################################################
# Evaluating Performance with Unsupervised Learning (Optional)
# ------------------------------------------------------------
#
# There are various ways to evaluate the performance of link prediction.
# This tutorial follows the practice of `GraphSAGE
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__.
# Basically, it first trains a GNN via link prediction to get an embedding
# for each node.  Then it trains a downstream linear classifier on top of
# this embedding and computes the accuracy as an assessment of the
# embedding quality.
#


######################################################################
# To obtain the representations of all the nodes, this tutorial uses
# neighbor sampling as introduced in the :doc:`node classification
# tutorial <L1_large_node_classification>`.
#
# .. note::
#
#    If you would like to obtain node representations without
#    neighbor sampling during inference, please refer to this :ref:`user
#    guide <guide-minibatch-inference>`.
#


def inference(model, graph, node_features):
    with torch.no_grad():

        sampler = dgl.dataloading.NeighborSampler([4, 4])
        dataloader = dgl.dataloading.DataLoader(
            graph,
            torch.arange(graph.num_nodes()),
            sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            device=device,
        )

        result = []
        for input_nodes, output_nodes, mfgs in dataloader:
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]
            result.append(model(mfgs, inputs))

        return torch.cat(result)


import sklearn.metrics


def evaluate(emb, label, train_nids, valid_nids, test_nids):
    classifier = nn.Linear(emb.shape[1], num_classes).to(device)
    opt = torch.optim.LBFGS(classifier.parameters())

    def compute_loss():
        pred = classifier(emb[train_nids].to(device))
        loss = F.cross_entropy(pred, label[train_nids].to(device))
        return loss

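    # LBFGS may re-evaluate the objective several times within a single
    # step, so it requires a closure that recomputes the loss and its
    # gradients on each call.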
    def closure():
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss

    prev_loss = float("inf")
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print("Converges at iteration", i)
                break
            else:
                prev_loss = loss

    with torch.no_grad():
        pred = classifier(emb.to(device)).cpu()
        valid_acc = sklearn.metrics.accuracy_score(
            label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1)
        )
        test_acc = sklearn.metrics.accuracy_score(
            label[test_nids].numpy(), pred[test_nids].numpy().argmax(1)
        )
    return valid_acc, test_acc


######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#

model = Model(node_features.shape[1], 128).to(device)
predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))



######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the
# validation set:
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]

            outputs = model(mfgs, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            score = torch.cat([pos_score, neg_score])
            label = torch.cat(
                [torch.ones_like(pos_score), torch.zeros_like(neg_score)]
            )
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({"loss": "%.03f" % loss.item()}, refresh=False)

            if (step + 1) % 500 == 0:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(
                    emb, node_labels, train_nids, valid_nids, test_nids
                )
                print(
                    "Epoch {} Validation Accuracy {} Test Accuracy {}".format(
                        epoch, valid_acc, test_acc
                    )
                )
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

                # Note that this tutorial does not train the whole model to the end.
                break


######################################################################
# Evaluating Performance with Link Prediction (Optional)
# ------------------------------------------------------
#
# In practice, it is more common to evaluate the link prediction
# model to see whether it can predict new edges. There are different
# evaluation metrics such as
# `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
# or `various metrics from information retrieval <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)>`__.
# Ultimately, they require the model to predict one scalar score given
# a node pair among a set of node pairs.
#
# Assuming that you have the following test set with labels, where
# ``test_pos_src`` and ``test_pos_dst`` are ground truth node pairs
# with edges in between (or *positive* pairs), and ``test_neg_src``
# and ``test_neg_dst`` are ground truth node pairs without edges
# in between (or *negative* pairs).
#

# Positive pairs
# These are randomly generated as an example.  You will need to
# replace them with your own ground truth.
n_test_pos = 1000
test_pos_src, test_pos_dst = (
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
)
# Negative pairs.  Likewise, you will need to replace them with your
# own ground truth.
test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,))


######################################################################
# First you need to compute the node representations for all the nodes
# with the ``inference`` function defined above:
#

node_reprs = inference(model, graph, node_features)

######################################################################
# Since the predictor is a dot product, you can now easily compute the
# scores of the positive and negative test pairs to evaluate metrics such
# as AUC:
#

h_pos_src = node_reprs[test_pos_src]
h_pos_dst = node_reprs[test_pos_dst]
h_neg_src = node_reprs[test_neg_src]
h_neg_dst = node_reprs[test_neg_dst]
score_pos = (h_pos_src * h_pos_dst).sum(1)
score_neg = (h_neg_src * h_neg_dst).sum(1)
test_preds = torch.cat([score_pos, score_neg]).cpu().numpy()
test_labels = (
    torch.cat([torch.ones_like(score_pos), torch.zeros_like(score_neg)])
    .cpu()
    .numpy()
)

auc = sklearn.metrics.roc_auc_score(test_labels, test_preds)
print("Link Prediction AUC:", auc)


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# for link prediction with neighbor sampling.
#


# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'