"src/vscode:/vscode.git/clone" did not exist on "6f28e1adb70e3054eacd5b02d459fc11c572128a"
link_prediction.py 5.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"""
This example shows how to build a GraphBolt dataloader and use it to sample
mini-batches and train a link prediction model on the Cora dataset.

Disclaimer: the test edges are not excluded from the original graph in the
dataset, so training can leak test information. We ignore this issue here
because the example focuses on demonstrating usability.
"""

import dgl.graphbolt as gb
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
from torcheval.metrics import BinaryAUROC


############################################################################
# (HIGHLIGHT) Create a single-process dataloader with the dgl.graphbolt
# package.
############################################################################
def create_dataloader(dataset, device, is_train=True):
    # The second of two tasks in the dataset is link prediction.
    task = dataset.tasks[1]
    itemset = task.train_set if is_train else task.test_set

    # Sample seed edges from the itemset.
    datapipe = gb.ItemSampler(itemset, batch_size=256)

    # Copy the mini-batch to the designated device for sampling and training.
    datapipe = datapipe.copy_to(device)

    if is_train:
        # Sample negative edges for the seed edges.
        datapipe = datapipe.sample_uniform_negative(
            dataset.graph, negative_ratio=1
        )

        # Sample neighbors for the seed nodes.
        datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])

        # Exclude the seed edges themselves from the sampled subgraph so the
        # model cannot simply read off the edges it is asked to predict.
        datapipe = datapipe.transform(gb.exclude_seed_edges)

    else:
        # Sample all neighbors (fanout -1) of the seed nodes for evaluation.
        datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[-1, -1])

    # Fetch features for sampled nodes.
    datapipe = datapipe.fetch_feature(
        dataset.feature, node_feature_keys=["feat"]
    )

    # Instantiate the dataloader for the datapipe.
    return gb.DataLoader(datapipe)


class GraphSAGE(nn.Module):
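    """Two-layer GraphSAGE encoder with an MLP link predictor.

    The encoder computes node embeddings with mean aggregation; the
    predictor scores a candidate edge from the element-wise product of
    its two endpoint embeddings (applied in train() and evaluate()).
    """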
    def __init__(self, in_size, hidden_size=16):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(SAGEConv(hidden_size, hidden_size, "mean"))
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
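        # `blocks` holds one bipartite message-flow graph per layer, produced
        # by neighbor sampling; `x` holds the input features of the sampled
        # nodes.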
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x


@torch.no_grad()
def evaluate(model, dataset, device):
    model.eval()
    dataloader = create_dataloader(dataset, device, is_train=False)

    logits = []
    labels = []
    for data in dataloader:
        # Get node pairs with labels for metric computation.
        compacted_seeds = data.compacted_seeds.T
        label = data.labels

        # The features of sampled nodes.
        x = data.node_features["feat"]

        # Forward.
        y = model(data.blocks, x)
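        # Score each seed pair: the MLP predictor is applied to the
        # element-wise product of the two endpoint embeddings.
        # compacted_seeds indexes the seeds into the sampled subgraph.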
        logit = (
            model.predictor(
                y[compacted_seeds[0].long()] * y[compacted_seeds[1].long()]
            )
            .squeeze()
            .detach()
        )

        logits.append(logit)
        labels.append(label)

    logits = torch.cat(logits, dim=0)
    labels = torch.cat(labels, dim=0)

    # Compute the AUROC score.
    metric = BinaryAUROC()
    metric.update(logits, labels)
    score = metric.compute().item()
    print(f"AUC: {score:.3f}")


def train(model, dataset, device):
    dataloader = create_dataloader(dataset, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for epoch in range(10):
        model.train()
        total_loss = 0
        ########################################################################
        # (HIGHLIGHT) Iterate over the dataloader and train the model with all
        # mini-batches.
        ########################################################################
        for step, data in enumerate(dataloader):
            # Get node pairs with labels for loss calculation.
            compacted_seeds = data.compacted_seeds.T
            labels = data.labels

            # The features of sampled nodes.
            x = data.node_features["feat"]

            # Forward.
            y = model(data.blocks, x)
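            # Score each positive and negative seed pair, as in evaluate():
            # MLP over the element-wise product of endpoint embeddings.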
            logits = model.predictor(
                y[compacted_seeds[0].long()] * y[compacted_seeds[1].long()]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels.float())

            # Backward.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}")


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Training in {device} mode.")

    # Load and preprocess dataset.
    print("Loading data...")
    dataset = gb.BuiltinDataset("cora").load()

    # If a CUDA device is selected, we pin the graph and the features so that
    # the GPU can access them.
    if device == torch.device("cuda:0"):
        dataset.graph.pin_memory_()
        dataset.feature.pin_memory_()

    # The input size equals the dimensionality of the "feat" node feature.
    in_size = dataset.feature.size("node", None, "feat")[0]
    model = GraphSAGE(in_size).to(device)

    # Model training.
    print("Training...")
    train(model, dataset, device)

    # Test the model.
    print("Testing...")
    evaluate(model, dataset, device)