"""
Stochastic Training of GNN for Link Prediction
==============================================

This tutorial will show how to train a multi-layer GraphSAGE for link
prediction on ``ogbn-arxiv`` provided by `Open Graph Benchmark
(OGB) <https://ogb.stanford.edu/>`__. The dataset
contains around 170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for link prediction on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>` and :doc:`Neighbor
Sampling for Node Classification <L1_large_node_classification>`.

"""


######################################################################
# Link Prediction Overview
# ------------------------
#
# Link prediction requires the model to predict the probability of
# existence of an edge. This tutorial does so by computing a dot product
# between the representations of both incident nodes.
#
# .. math::
#
#
#    \hat{y}_{u\sim v} = \sigma(h_u^T h_v)
#
# It then minimizes the following binary cross entropy loss.
#
# .. math::
#
#
#    \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
#
# This is identical to the link prediction formulation in :doc:`the previous
# tutorial on link prediction <../blitz/4_link_predict>`.
#
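
######################################################################
# As a quick illustration, the score and the loss above can be written
# directly with PyTorch operations.  The tensors below are random
# stand-ins for the representations of two incident nodes; only the
# formulas themselves are meaningful.
#

import torch
import torch.nn.functional as F

h_u = torch.randn(16)        # hypothetical representation of node u
h_v = torch.randn(16)        # hypothetical representation of node v
score = torch.dot(h_u, h_v)  # the logit \hat{y}_{u~v} before the sigmoid
prob = torch.sigmoid(score)  # predicted probability that the edge exists
# Binary cross entropy against a positive label (y = 1):
loss = F.binary_cross_entropy_with_logits(score, torch.tensor(1.))
print(prob.item(), loss.item())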


######################################################################
# Loading Dataset
# ---------------
#
# This tutorial loads the dataset from the ``ogb`` package as in the
# :doc:`previous tutorial <L1_large_node_classification>`.
#

import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-arxiv')
device = 'cpu'      # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
print(graph)
print(node_labels)

node_features = graph.ndata['feat']
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# Different from the :doc:`link prediction tutorial for full
# graph <../blitz/4_link_predict>`, a common practice for training GNNs on
# large graphs is to iterate over the edges in minibatches, since computing
# the probabilities of all edges is usually infeasible. For each minibatch
# of edges, you compute the output representations of their incident nodes
# using neighbor sampling and a GNN, in a fashion similar to the one
# introduced in the :doc:`large-scale node classification
# tutorial <L1_large_node_classification>`.
#
# DGL provides ``dgl.dataloading.as_edge_prediction_sampler`` to
# iterate over edges for edge classification or link prediction tasks.
#
# To perform link prediction, you need to specify a negative sampler. DGL
# provides builtin negative samplers such as
# ``dgl.dataloading.negative_sampler.Uniform``.  Here this tutorial uniformly
# draws 5 negative examples per positive example.
#

negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)
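
######################################################################
# Under the hood, a negative sampler is a callable that takes the graph
# and the minibatch's edge IDs and returns the source and destination
# node tensors of the negative pairs.  The following is a minimal sketch
# of how a uniform sampler could be written by hand, assuming this
# calling convention (the builtin ``Uniform`` above is what the rest of
# this tutorial actually uses):
#

class SketchUniformNegativeSampler:
    def __init__(self, k):
        self.k = k  # number of negative examples per positive edge

    def __call__(self, g, eids):
        src, _ = g.find_edges(eids)
        # Repeat each source node k times and pair it with k destination
        # nodes drawn uniformly at random from the whole graph.
        src = src.repeat_interleave(self.k)
        dst = torch.randint(0, g.num_nodes(), (len(src),))
        return src, dst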


######################################################################
# After defining the negative sampler, one can then define the edge data
# loader with neighbor sampling.  To create a ``DataLoader`` for
# link prediction, provide a neighbor sampler object as well as the negative
# sampler object created above.
#

sampler = dgl.dataloading.NeighborSampler([4, 4])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler)
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DataLoader.
    graph,                                  # The graph
    torch.arange(graph.number_of_edges()),  # The edges to iterate over
    sampler,                                # The neighbor sampler
    device=device,                          # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,    # Batch size
    shuffle=True,       # Whether to shuffle the edges for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)


######################################################################
# You can peek one minibatch from ``train_dataloader`` and see what it
# will give you.
#

input_nodes, pos_graph, neg_graph, mfgs = next(iter(train_dataloader))
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
print(mfgs)


######################################################################
# The example minibatch consists of four elements.
#
# The first element is an ID tensor for the input nodes, i.e., nodes
# whose input features are needed on the first GNN layer for this minibatch.
#
# The second element and the third element are the positive graph and the
# negative graph for this minibatch.
# The concept of positive and negative graphs has been introduced in the
# :doc:`full-graph link prediction tutorial <../blitz/4_link_predict>`.  In minibatch
# training, the positive graph and the negative graph only contain the nodes
# necessary for computing the pair-wise scores of the positive and negative
# examples in the current minibatch.
#
# The last element is a list of :doc:`MFGs <L0_neighbor_sampling_overview>`
# storing the computation dependencies for each GNN layer.
# The MFGs are used to compute the GNN outputs of the nodes
# involved in the positive/negative graphs.
#
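
######################################################################
# Note that the positive and negative graphs are *compacted*: their
# nodes are renumbered so that only the nodes involved in this minibatch
# remain.  To map them back to IDs in the original graph, the sketch
# below assumes the original IDs are stored in the ``dgl.NID`` node data
# field, as for other DGL subgraphs:
#

# Original IDs of the nodes appearing in the positive/negative graphs.
print(pos_graph.ndata[dgl.NID])
# Endpoints of the first positive edge, translated back to original IDs.
u, v = pos_graph.edges()
print(pos_graph.ndata[dgl.NID][u[0]], pos_graph.ndata[dgl.NID][v[0]])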


######################################################################
# Defining Model for Node Representation
# --------------------------------------
#
# The model is almost identical to the one in the :doc:`node classification
# tutorial <L1_large_node_classification>`. The only difference is
# that since you are doing link prediction, the output dimension will not
# be the number of classes in the dataset.
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # The first few source nodes of an MFG are exactly its
        # destination nodes, so slicing the source features gives the
        # features of the destination nodes.
        h_dst = x[:mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h

model = Model(num_features, 128).to(device)


######################################################################
# Defining the Score Predictor for Edges
# --------------------------------------
#
# After getting the node representations necessary for the minibatch, the
# last thing to do is to predict the scores of the edges and the
# non-existent edges in the sampled minibatch.
#
# The following score predictor, copied from the :doc:`link prediction
# tutorial <../blitz/4_link_predict>`, takes a dot product between the
# incident nodes’ representations.
#

import dgl.function as fn

class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata['score'][:, 0]
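
######################################################################
# The dot product is not the only choice: any module that produces one
# scalar per edge works.  For instance, the following is a sketch of an
# MLP-based predictor in the spirit of the one from the full-graph
# tutorial; it concatenates the incident nodes' representations and
# feeds them through a small feed-forward network:
#

class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        # Concatenate the source and destination representations of each
        # edge and compute a scalar score from them.
        h = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
        return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']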


######################################################################
# Evaluating Performance with Unsupervised Learning (Optional)
# ------------------------------------------------------------
#
# There are various ways to evaluate the performance of link prediction.
# This tutorial follows the practice of `GraphSAGE
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__.
# Basically, it first trains a GNN via link prediction and gets an embedding
# for each node.  Then it trains a downstream classifier on top of this
# embedding and computes the accuracy as an assessment of the embedding
# quality.
#


######################################################################
# To obtain the representations of all the nodes, this tutorial uses
# neighbor sampling as introduced in the :doc:`node classification
# tutorial <L1_large_node_classification>`.
#
# .. note::
#
#    If you would like to obtain node representations without
#    neighbor sampling during inference, please refer to this :ref:`user
#    guide <guide-minibatch-inference>`.
#
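
######################################################################
# For reference, the layer-wise approach described in that user guide
# chapter roughly looks like the sketch below, adapted to our two-layer
# model (this is an illustrative sketch, not the exact code from the
# guide; it assumes CPU tensors throughout):
#

def layerwise_inference(model, graph, feats, batch_size=1024):
    # Compute layer 1 for *all* nodes over their full neighborhoods,
    # then feed the result into layer 2, so no sampling error is
    # introduced at inference time.
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    dataloader = dgl.dataloading.DataLoader(
        graph, torch.arange(graph.num_nodes()), sampler,
        batch_size=batch_size, shuffle=False, drop_last=False)
    h = feats
    with torch.no_grad():
        for l, layer in enumerate([model.conv1, model.conv2]):
            out = torch.zeros(graph.num_nodes(), model.h_feats)
            for input_nodes, output_nodes, mfgs in dataloader:
                x = h[input_nodes]
                h_dst = x[:mfgs[0].num_dst_nodes()]
                y = layer(mfgs[0], (x, h_dst))
                if l == 0:  # non-linearity between the two layers only
                    y = F.relu(y)
                out[output_nodes] = y
            h = out
    return h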

def inference(model, graph, node_features):
    with torch.no_grad():

        sampler = dgl.dataloading.NeighborSampler([4, 4])
        dataloader = dgl.dataloading.DataLoader(
            graph, torch.arange(graph.number_of_nodes()), sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            device=device)

        result = []
        for input_nodes, output_nodes, mfgs in dataloader:
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata['feat']
            result.append(model(mfgs, inputs))

        return torch.cat(result)

import sklearn.metrics

def evaluate(emb, label, train_nids, valid_nids, test_nids):
    classifier = nn.Linear(emb.shape[1], num_classes).to(device)
    opt = torch.optim.LBFGS(classifier.parameters())

    def compute_loss():
        pred = classifier(emb[train_nids].to(device))
        loss = F.cross_entropy(pred, label[train_nids].to(device))
        return loss

    def closure():
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss

    prev_loss = float('inf')
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print('Converges at iteration', i)
                break
            else:
                prev_loss = loss

    with torch.no_grad():
        pred = classifier(emb.to(device)).cpu()
        valid_acc = sklearn.metrics.accuracy_score(label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1))
        test_acc = sklearn.metrics.accuracy_score(label[test_nids].numpy(), pred[test_nids].numpy().argmax(1))
    return valid_acc, test_acc


######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#

model = Model(node_features.shape[1], 128).to(device)
predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))


######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the
# validation set:
#

import tqdm

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata['feat']

            outputs = model(mfgs, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            score = torch.cat([pos_score, neg_score])
            label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({'loss': '%.03f' % loss.item()}, refresh=False)

            if (step + 1) % 500 == 0:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(emb, node_labels, train_nids, valid_nids, test_nids)
                print('Epoch {} Validation Accuracy {} Test Accuracy {}'.format(epoch, valid_acc, test_acc))
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

                # Note that this tutorial does not train the whole model to the end.
                break


######################################################################
# Evaluating Performance with Link Prediction (Optional)
# ------------------------------------------------------
#
# In practice, it is more common to evaluate the link prediction
# model to see whether it can predict new edges. There are different
# evaluation metrics such as
# `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
# or `various metrics from information retrieval <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)>`__.
# Ultimately, they require the model to predict one scalar score given
# a node pair among a set of node pairs.
#
# Suppose that you have the following test set with labels, where
# ``test_pos_src`` and ``test_pos_dst`` are ground truth node pairs
# with edges in between (or *positive* pairs), and ``test_neg_src``
# and ``test_neg_dst`` are ground truth node pairs without edges
# in between (or *negative* pairs).
#

# Positive pairs
# These are randomly generated as an example.  You will need to
# replace them with your own ground truth.
n_test_pos = 1000
test_pos_src, test_pos_dst = (
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
    torch.randint(0, graph.num_nodes(), (n_test_pos,)))
# Negative pairs.  Likewise, you will need to replace them with your
# own ground truth.
test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,))


######################################################################
# First you need to compute the node representations for all the nodes
# with the ``inference`` method above:
#

node_reprs = inference(model, graph, node_features)

######################################################################
# Since the predictor is a dot product, you can now easily compute the
# score of positive and negative test pairs to compute metrics such
# as AUC:
#

h_pos_src = node_reprs[test_pos_src]
h_pos_dst = node_reprs[test_pos_dst]
h_neg_src = node_reprs[test_neg_src]
h_neg_dst = node_reprs[test_neg_dst]
score_pos = (h_pos_src * h_pos_dst).sum(1)
score_neg = (h_neg_src * h_neg_dst).sum(1)
test_preds = torch.cat([score_pos, score_neg]).cpu().numpy()
test_labels = torch.cat([torch.ones_like(score_pos), torch.zeros_like(score_neg)]).cpu().numpy()

auc = sklearn.metrics.roc_auc_score(test_labels, test_preds)
print('Link Prediction AUC:', auc)
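
######################################################################
# As a sketch of a retrieval-style metric, you can also rank each
# positive pair's score against the negative scores and compute, say,
# the mean reciprocal rank (MRR).  The snippet below assumes the random
# negatives above serve as the shared candidate set for every positive
# pair:
#

# Rank of each positive score among all negative scores (1 = best).
ranks = (score_neg.unsqueeze(0) >= score_pos.unsqueeze(1)).sum(1) + 1
mrr = (1.0 / ranks.float()).mean().item()
print('Link Prediction MRR (sketch):', mrr)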


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# for link prediction with neighbor sampling.
#


# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'