"""
Stochastic Training of GNN for Link Prediction
==============================================

This tutorial will show how to train a multi-layer GraphSAGE for link
prediction on `CoraGraphDataset <https://data.dgl.ai/dataset/cora_v2.zip>`__.
The dataset contains 2708 nodes and 10556 edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for link prediction on target device with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>` and :doc:`Neighbor
Sampling for Node Classification <L1_large_node_classification>`.

"""


######################################################################
# Link Prediction Overview
# ------------------------
#
# Unlike node classification, which predicts labels for nodes based on
# their local neighborhoods, link prediction assesses the likelihood of
# an edge existing between two nodes, necessitating different sampling
# strategies that account for pairs of nodes and their joint
# neighborhoods.
#


######################################################################
# Loading Dataset
# ---------------
#
# `cora` is already prepared as ``BuiltinDataset`` in GraphBolt.
#

import os

# The backend must be chosen before any DGL module is imported.
os.environ["DGLBACKEND"] = "pytorch"
import dgl.graphbolt as gb
import numpy as np
import torch
import tqdm

dataset = gb.BuiltinDataset("cora").load()
device = torch.device("cpu")  # change to 'cuda' for GPU
######################################################################
# Dataset consists of graph, feature and tasks. You can get the
# training-validation-test set from the tasks. Seed nodes and corresponding
# labels are already stored in each training-validation-test set. This
# dataset contains 2 tasks, one for node classification and the other for
# link prediction. We will use the link prediction task.
#

graph = dataset.graph
feature = dataset.feature
# tasks[1] is the link prediction task (tasks[0] is node classification).
train_set = dataset.tasks[1].train_set
test_set = dataset.tasks[1].test_set
task_name = dataset.tasks[1].metadata["name"]
print(f"Task: {task_name}.")


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# Different from the :doc:`link prediction tutorial for full
# graph <../blitz/4_link_predict>`, a common practice to train GNN on large graphs is
# to iterate over the edges
# in minibatches, since computing the probability of all edges is usually
# impossible. For each minibatch of edges, you compute the output
# representation of their incident nodes using neighbor sampling and GNN,
# in a similar fashion introduced in the :doc:`large-scale node classification
# tutorial <L1_large_node_classification>`.
#
# To perform link prediction, you need to specify a negative sampler. DGL
# provides builtin negative samplers such as
# ``dgl.graphbolt.UniformNegativeSampler``.  Here this tutorial uniformly
# draws 5 negative examples per positive example.
#
# Except for the negative sampler, the rest of the code is identical to
# the :doc:`node classification tutorial <L1_large_node_classification>`.
#

datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
# Uniformly draw 5 negative node pairs per positive edge.
datapipe = datapipe.sample_uniform_negative(graph, 5)
# Three-hop neighbor sampling with a fanout of 5 at each hop.
datapipe = datapipe.sample_neighbor(graph, [5, 5, 5])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
train_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)


######################################################################
# You can peek one minibatch from ``train_dataloader`` and see what it
# will give you.
#

data = next(iter(train_dataloader))
print(f"DGLMiniBatch: {data}")


######################################################################
# Defining Model for Node Representation
# --------------------------------------
#

import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    """Two-layer GraphSAGE encoder with an MLP link-score predictor.

    Parameters
    ----------
    in_size : int
        Dimensionality of the input node features.
    hidden_size : int
        Dimensionality of the hidden and output node embeddings.
    """

    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        # Scores a node pair from the elementwise product of its two
        # node embeddings (applied by the training/evaluation code).
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        """Compute node embeddings from sampled blocks and input features."""
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            # ReLU between layers only; the final embedding stays linear.
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x


######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#

# Input feature size is read from the stored "feat" node feature.
in_size = feature.size("node", None, "feat")[0]
model = SAGE(in_size, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


#####################################################################
# Convert the minibatch to a training pair and a label tensor.
#


def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
    """Convert the minibatch to a training pair and a label tensor.

    Returns
    -------
    tuple
        ``(node_pairs, labels)`` where ``node_pairs`` is a ``(src, dst)``
        pair of ID tensors with positive pairs first and negative pairs
        after, and ``labels`` is a float tensor (1.0 positive, 0.0
        negative) aligned with ``node_pairs``.
    """
    pos_src, pos_dst = data.positive_node_pairs
    neg_src, neg_dst = data.negative_node_pairs
    # Positives first, then negatives, so the labels below line up by
    # position with the concatenated pairs.
    node_pairs = (
        torch.cat((pos_src, neg_src), dim=0),
        torch.cat((pos_dst, neg_dst), dim=0),
    )
    pos_label = torch.ones_like(pos_src)
    neg_label = torch.zeros_like(neg_src)
    labels = torch.cat([pos_label, neg_label], dim=0)
    # Labels are cast to float for binary_cross_entropy_with_logits.
    return (node_pairs, labels.float())


######################################################################
# The following is the training loop for link prediction and
# evaluation.
#

for epoch in range(10):
    model.train()
    total_loss = 0
    for step, data in tqdm.tqdm(enumerate(train_dataloader)):
        # Unpack MiniBatch into (src, dst) pairs and binary labels.
        compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        # Score each pair from the elementwise product of its two
        # endpoint embeddings.
        logits = model.predictor(
            y[compacted_pairs[0]] * y[compacted_pairs[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # step + 1 is the number of minibatches seen this epoch.
    print(f"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}")
######################################################################
# Evaluating Performance with Link Prediction
# -------------------------------------------
#

model.eval()

datapipe = gb.ItemSampler(test_set, batch_size=256, shuffle=False)
# Since we need to use all neighborhoods for evaluation, we set the fanout
# to -1.
datapipe = datapipe.sample_neighbor(graph, [-1, -1])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
eval_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
logits = []
labels = []
for step, data in enumerate(eval_dataloader):
    # Unpack MiniBatch into (src, dst) pairs and binary labels.
    compacted_pairs, label = to_binary_link_dgl_computing_pack(data)

    # The features of sampled nodes.
    x = data.node_features["feat"]

    # Forward.
    y = model(data.blocks, x)
    # Detach so the evaluation scores carry no autograd history.
    logit = (
        model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])
        .squeeze()
        .detach()
    )

    logits.append(logit)
    labels.append(label)

logits = torch.cat(logits, dim=0)
labels = torch.cat(labels, dim=0)


# Compute the AUROC score.
from sklearn.metrics import roc_auc_score

auc = roc_auc_score(labels, logits)
print("Link Prediction AUC:", auc)

######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# for link prediction with neighbor sampling.
#