"""
Link Prediction using Graph Neural Networks
===========================================

In the :doc:`introduction <1_introduction>`, you have already learned
the basic workflow of using GNNs for node classification,
i.e. predicting the category of a node in a graph. This tutorial will
teach you how to train a GNN for link prediction, i.e. predicting the
existence of an edge between two arbitrary nodes in a graph.

By the end of this tutorial you will be able to

-  Build a GNN-based link prediction model.
-  Train and evaluate the model on a small DGL-provided dataset.

(Time estimate: 28 minutes)

"""

import itertools
import os

# The backend must be selected before dgl is imported.
os.environ['DGLBACKEND'] = 'pytorch'

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.data

######################################################################
# Overview of Link Prediction with GNN
# ------------------------------------
#
# Many applications such as social recommendation, item recommendation,
# knowledge graph completion, etc., can be formulated as link prediction,
# which predicts whether an edge exists between two particular nodes. This
# tutorial shows an example of predicting whether a citation relationship,
# either citing or being cited, between two papers exists in a citation
# network.
#
# This tutorial formulates the link prediction problem as a binary classification
# problem as follows:
#
# -  Treat the edges in the graph as *positive examples*.
# -  Sample a number of non-existent edges (i.e. node pairs with no edges
#    between them) as *negative* examples.
# -  Divide the positive examples and negative examples into a training
#    set and a test set.
# -  Evaluate the model with any binary classification metric such as Area
#    Under Curve (AUC).
#
# .. note::
#
#    The practice comes from
#    `SEAL <https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf>`__,
#    although the model here does not use their idea of node labeling.
#
# In some domains such as large-scale recommender systems or information
# retrieval, you may favor metrics that emphasize good performance of
# top-K predictions. In these cases you may want to consider other metrics
# such as mean average precision, and use other negative sampling methods,
# which are beyond the scope of this tutorial.
#
# Loading graph and features
# --------------------------
#
# Following the :doc:`introduction <1_introduction>`, this tutorial
# first loads the Cora dataset.
#


# Load the Cora citation-network dataset; dataset[0] is its single graph.
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]


######################################################################
# Prepare training and testing sets
# ---------------------------------
#
# This tutorial randomly picks 10% of the edges for positive examples in
# the test set, and leaves the rest for the training set. It then samples
# the same number of edges for negative examples in both sets.
#

# Split edge set for training and testing
u, v = g.edges()

eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1)
train_size = g.number_of_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

# Find all negative edges and split them for training and testing.
# Pass the matrix shape explicitly: if the highest-numbered nodes happen to
# have no edges, the shape inferred from the coordinate arrays would be
# smaller than (N, N) and the dense complement below would be wrong.
num_nodes = g.number_of_nodes()
adj = sp.coo_matrix(
    (np.ones(len(u)), (u.numpy(), v.numpy())),
    shape=(num_nodes, num_nodes),
)
# Complement of (adjacency + self-loops): 1 exactly where no edge exists.
adj_neg = 1 - adj.todense() - np.eye(num_nodes)
neg_u, neg_v = np.where(adj_neg != 0)

# Sample as many negative pairs as there are positive edges.
neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = (
    neg_u[neg_eids[:test_size]],
    neg_v[neg_eids[:test_size]],
)
train_neg_u, train_neg_v = (
    neg_u[neg_eids[test_size:]],
    neg_v[neg_eids[test_size:]],
)
######################################################################
# When training, you will need to remove the edges in the test set from
# the original graph. You can do this via ``dgl.remove_edges``.
#
# .. note::
#
#    ``dgl.remove_edges`` works by creating a subgraph from the
#    original graph, resulting in a copy and therefore could be slow for
#    large graphs. If so, you could save the training and test graph to
#    disk, as you would do for preprocessing.
#

# Keep only the training edges: drop the held-out 10% test edges from the graph.
train_g = dgl.remove_edges(g, eids[:test_size])


######################################################################
# Define a GraphSAGE model
# ------------------------
#
# This tutorial builds a model consisting of two
# `GraphSAGE <https://arxiv.org/abs/1706.02216>`__ layers, each computes
# new node representations by averaging neighbor information. DGL provides
# ``dgl.nn.SAGEConv`` that conveniently creates a GraphSAGE layer.
#

from dgl.nn import SAGEConv

# ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE encoder producing node representations.

    Parameters
    ----------
    in_feats : int
        Size of the input node features.
    h_feats : int
        Size of the hidden (and output) node representations.
    """

    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        # "mean" aggregation averages neighbor features.
        self.conv1 = SAGEConv(in_feats, h_feats, "mean")
        self.conv2 = SAGEConv(h_feats, h_feats, "mean")

    def forward(self, g, in_feat):
        """Return h_feats-dimensional representations for all nodes of ``g``."""
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


######################################################################
# The model then predicts the probability of existence of an edge by
# computing a score between the representations of both incident nodes
# with a function (e.g. an MLP or a dot product), which you will see in
# the next section.
#
# .. math::
#
#    \hat{y}_{u\sim v} = f(h_u, h_v)
#


######################################################################
# Positive graph, negative graph, and ``apply_edges``
# ---------------------------------------------------
#
# In previous tutorials you have learned how to compute node
# representations with a GNN. However, link prediction requires you to
# compute representation of *pairs of nodes*.
#
# DGL recommends you to treat the pairs of nodes as another graph, since
# you can describe a pair of nodes with an edge. In link prediction, you
# will have a *positive graph* consisting of all the positive examples as
# edges, and a *negative graph* consisting of all the negative examples.
# The *positive graph* and the *negative graph* will contain the same set
# of nodes as the original graph.  This makes it easier to pass node
# features among multiple graphs for computation.  As you will see later,
# you can directly feed the node representations computed on the entire
# graph to the positive and the negative graphs for computing pair-wise
# scores.
#
# The following code constructs the positive graph and the negative graph
# for the training set and the test set respectively.
#

# The positive/negative node pairs become edges of dedicated graphs that
# share the node set of the original graph, so node embeddings computed on
# the full graph can be fed to them directly.
train_pos_g = dgl.graph(
    (train_pos_u, train_pos_v), num_nodes=g.number_of_nodes()
)
train_neg_g = dgl.graph(
    (train_neg_u, train_neg_v), num_nodes=g.number_of_nodes()
)

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())


######################################################################
# The benefit of treating the pairs of nodes as a graph is that you can
# use the ``DGLGraph.apply_edges`` method, which conveniently computes new
# edge features based on the incident nodes’ features and the original
# edge features (if applicable).
#
# DGL provides a set of optimized builtin functions to compute new
# edge features based on the original node/edge features. For example,
# ``dgl.function.u_dot_v`` computes a dot product of the incident nodes’
# representations for each edge.
#

import dgl.function as fn

class DotPredictor(nn.Module):
    """Score each edge by the dot product of its incident node embeddings."""

    def forward(self, g, h):
        # local_scope guarantees the temporary 'h' / 'score' features do not
        # leak into the caller's graph.
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]


######################################################################
# You can also write your own function if it is complex.
# For instance, the following module produces a scalar score on each edge
# by concatenating the incident nodes’ features and passing it to an MLP.
#

class MLPPredictor(nn.Module):
    """Score each edge with a 2-layer MLP over the concatenated incident
    node embeddings.

    Parameters
    ----------
    h_feats : int
        Dimensionality of the node embeddings fed to ``forward``.
    """

    def __init__(self, h_feats):
        super().__init__()
        # Input is the concatenation of source and destination embeddings.
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """
        Computes a scalar score for each edge of the given graph.

        Parameters
        ----------
        edges :
            Has three members ``src``, ``dst`` and ``data``, each of
            which is a dictionary representing the features of the
            source nodes, the destination nodes, and the edges
            themselves.

        Returns
        -------
        dict
            A dictionary of new edge features.
        """
        h = torch.cat([edges.src["h"], edges.dst["h"]], 1)
        # squeeze(1) turns the (E, 1) MLP output into a 1-D score per edge.
        return {"score": self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            g.apply_edges(self.apply_edges)
            return g.edata["score"]


######################################################################
# .. note::
#
#    The builtin functions are optimized for both speed and memory.
#    We recommend using builtin functions whenever possible.
#
# .. note::
#
#    If you have read the :doc:`message passing
#    tutorial <3_message_passing>`, you will notice that the
#    argument ``apply_edges`` takes has exactly the same form as a message
#    function in ``update_all``.
#


######################################################################
# Training loop
# -------------
#
# After you defined the node representation computation and the edge score
# computation, you can go ahead and define the overall model, loss
# function, and evaluation metric.
#
# The loss function is simply binary cross entropy loss.
#
# .. math::
#
#    \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
#
# The evaluation metric in this tutorial is AUC.
#

# The node encoder: 16-dimensional GraphSAGE embeddings of the Cora features.
model = GraphSAGE(train_g.ndata["feat"].shape[1], 16)
# You can replace DotPredictor with MLPPredictor.
# pred = MLPPredictor(16)
pred = DotPredictor()

def compute_loss(pos_score, neg_score):
    """Binary cross-entropy on raw edge scores.

    Positive-edge scores are labeled 1, negative-edge scores 0.

    Parameters
    ----------
    pos_score, neg_score : torch.Tensor
        1-D tensors of raw (pre-sigmoid) scores.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    )
    # with_logits applies the sigmoid internally, which is numerically
    # more stable than sigmoid followed by binary_cross_entropy.
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    """Area under the ROC curve, treating positive edges as label 1.

    Parameters
    ----------
    pos_score, neg_score : torch.Tensor
        1-D tensors of raw edge scores (any monotone score works for AUC).

    Returns
    -------
    float
        ROC-AUC over the combined positive and negative examples.
    """
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).numpy()
    return roc_auc_score(labels, scores)


######################################################################
# The training loop goes as follows:
#
# .. note::
#
#    This tutorial does not include evaluation on a validation
#    set. In practice you should save and evaluate the best model based on
#    performance on the validation set.
#

# ----------- 3. set up loss and optimizer -------------- #
# The loss itself is computed inside the training loop; the optimizer
# updates the encoder and the predictor jointly.
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
    # forward: recompute node embeddings on the training graph ...
    h = model(train_g, train_g.ndata["feat"])
    # ... then score the positive and negative node pairs with them.
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print("In epoch {}, loss: {}".format(e, loss))
# ----------- 5. check results ------------------------ #
from sklearn.metrics import roc_auc_score

# Score the held-out test pairs with the final embeddings; no_grad avoids
# building autograd state during evaluation.
with torch.no_grad():
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print("AUC", compute_auc(pos_score, neg_score))


# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'