link_pred.py 7.94 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
3
4
import argparse

import dgl
import dgl.nn as dglnn
5
6
7
8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
9
import tqdm
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
10
11
12
13
14
15
16
from dgl.dataloading import (
    as_edge_prediction_sampler,
    DataLoader,
    MultiLayerFullNeighborSampler,
    negative_sampler,
    NeighborSampler,
)
17
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
18

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
19

20
21
def to_bidirected_with_reverse_mapping(g):
    """Return a bidirected copy of ``g`` together with a reverse-edge map.

    The returned ``reverse_mapping`` satisfies: edge ``reverse_mapping[i]``
    of the simplified graph is the reverse of its edge ``i``.  Does not
    work with graphs that have self-loops.
    """
    # Double every edge, then coalesce duplicates.  ``raw_to_simple`` maps
    # each raw (doubled) edge ID to its edge ID in the simplified graph.
    bidirected, raw_to_simple = dgl.to_simple(
        dgl.add_reverse_edges(g), return_counts="count", writeback_mapping=True
    )
    dup_counts = bidirected.edata["count"]
    n_orig_edges = g.num_edges()

    # CSR-style offsets: offsets[e] is the position, inside the sorted raw
    # IDs, of the first raw edge that collapsed into simplified edge e.
    offsets = torch.zeros(bidirected.num_edges() + 1, dtype=bidirected.idtype)
    offsets[1:] = dup_counts.cumsum(0)
    sorted_raw = raw_to_simple.argsort()
    # One representative raw edge ID per simplified edge.
    representative = sorted_raw[offsets[:-1]]

    # In the doubled graph, raw IDs below n_orig_edges are original edges
    # and their reverses live at ID + n_orig_edges (and vice versa).
    flipped = torch.where(
        representative >= n_orig_edges,
        representative - n_orig_edges,
        representative + n_orig_edges,
    )
    reverse_mapping = raw_to_simple[flipped]

    # Sanity check: each edge and its claimed reverse must connect the
    # same endpoints in opposite directions.
    fwd_src, fwd_dst = bidirected.edges()
    rev_src, rev_dst = bidirected.find_edges(reverse_mapping)
    assert torch.equal(fwd_src, rev_dst)
    assert torch.equal(rev_src, fwd_dst)
    return bidirected, reverse_mapping
45

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
46

47
class SAGE(nn.Module):
    """Three-layer GraphSAGE-mean encoder with an MLP link predictor.

    ``forward`` scores sampled positive/negative edge mini-batches during
    training; ``inference`` computes embeddings for every node with
    layer-wise full-neighbor propagation for evaluation.
    """

    def __init__(self, in_size, hid_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.hid_size = hid_size
        # Edge scorer: maps the elementwise product of two node embeddings
        # to a single logit.
        self.predictor = nn.Sequential(
            nn.Linear(hid_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, 1),
        )

    def forward(self, pair_graph, neg_pair_graph, blocks, x):
        """Return (positive, negative) edge logits for one mini-batch.

        ``blocks`` are the sampled message-flow graphs, one per layer;
        ``x`` holds the input features of the blocks' source nodes.
        """
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # No nonlinearity after the final layer.
            if l != len(self.layers) - 1:
                h = F.relu(h)
        pos_src, pos_dst = pair_graph.edges()
        neg_src, neg_dst = neg_pair_graph.edges()
        # Score each edge from the product of its endpoint embeddings.
        h_pos = self.predictor(h[pos_src] * h[pos_dst])
        h_neg = self.predictor(h[neg_src] * h[neg_dst])
        return h_pos, h_neg

    def inference(self, g, device, batch_size):
        """Layer-wise inference algorithm to compute GNN node embeddings."""
        feat = g.ndata["feat"]
        # One layer at a time over ALL neighbors (no sampling noise).
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
        dataloader = DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        # Intermediate embeddings are buffered on CPU; pin the buffer when
        # compute happens on another device so transfers are faster.
        buffer_device = torch.device("cpu")
        pin_memory = buffer_device != device
        for l, layer in enumerate(self.layers):
            # Output buffer for this layer, covering every node.
            y = torch.empty(
                g.num_nodes(),
                self.hid_size,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(
                dataloader, desc="Inference"
            ):
                x = feat[input_nodes]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                y[output_nodes] = h.to(buffer_device)
            # This layer's output becomes the next layer's input.
            feat = y
        return y

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
111
112
113
114

def compute_mrr(
    model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500
):
    """Compute Mean Reciprocal Rank (MRR) in batches."""
    n = src.shape[0]
    reciprocal_ranks = torch.zeros(n)
    for lo in tqdm.trange(0, n, batch_size, desc="Evaluate"):
        hi = min(lo + batch_size, n)
        # Column 0 holds the true target; remaining columns are negatives.
        candidates = torch.cat([dst[lo:hi, None], neg_dst[lo:hi]], 1)
        src_emb = node_emb[src[lo:hi]][:, None, :].to(device)
        dst_emb = (
            node_emb[candidates.view(-1)].view(*candidates.shape, -1).to(device)
        )
        # One logit per (source, candidate) pair.
        scores = model.predictor(src_emb * dst_emb).squeeze(-1)
        batch_rr = evaluator.eval(
            {"y_pred_pos": scores[:, 0], "y_pred_neg": scores[:, 1:]}
        )["mrr_list"]
        reciprocal_ranks[lo:hi] = batch_rr
    return reciprocal_ranks.mean()

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
127

128
129
def evaluate(device, graph, edge_split, model, batch_size):
    """Score the model on the OGB validation and test splits.

    Returns ``[valid_mrr, test_mrr]`` as 0-dim tensors.
    """
    model.eval()
    evaluator = Evaluator(name="ogbl-citation2")
    with torch.no_grad():
        # Full-graph node embeddings, computed once and reused per split.
        node_emb = model.inference(graph, device, batch_size)
        results = []
        for split in ["valid", "test"]:
            split_edges = edge_split[split]
            src = split_edges["source_node"].to(node_emb.device)
            dst = split_edges["target_node"].to(node_emb.device)
            neg_dst = split_edges["target_node_neg"].to(node_emb.device)
            mrr = compute_mrr(
                model, evaluator, node_emb, src, dst, neg_dst, device
            )
            results.append(mrr)
    return results

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
145

146
147
def train(args, device, g, reverse_eids, seed_edges, model):
    """Run mini-batch link-prediction training for 10 epochs.

    Each positive edge gets one uniformly sampled negative; the reverse of
    every seed edge is excluded from its sampled subgraph so the model
    cannot trivially see the answer.
    """
    # Neighbor fan-outs for the three GraphSAGE layers.
    base_sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=["feat"])
    edge_sampler = as_edge_prediction_sampler(
        base_sampler,
        exclude="reverse_id",
        reverse_eids=reverse_eids,
        negative_sampler=negative_sampler.Uniform(1),
    )
    # Mixed mode keeps the graph on CPU and samples via unified virtual
    # addressing from the GPU.
    use_uva = args.mode == "mixed"
    dataloader = DataLoader(
        g,
        seed_edges,
        edge_sampler,
        device=device,
        batch_size=512,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.0005)
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(
            dataloader
        ):
            x = blocks[0].srcdata["feat"]
            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
            # Binary classification: positive pairs -> 1, negatives -> 0.
            score = torch.cat([pos_score, neg_score])
            labels = torch.cat(
                [torch.ones_like(pos_score), torch.zeros_like(neg_score)]
            )
            loss = F.binary_cross_entropy_with_logits(score, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            # Cap each epoch at 1000 iterations to bound training time.
            if (it + 1) == 1000:
                break
        print("Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it + 1)))

189

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
190
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["cpu", "mixed", "puregpu"],
        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    args = parser.parse_args()
    # Fall back to CPU when no GPU is available, whatever was requested.
    if not torch.cuda.is_available():
        args.mode = "cpu"
    print(f"Training in {args.mode} mode.")

    # load and preprocess dataset
    print("Loading data")
    dataset = DglLinkPropPredDataset("ogbl-citation2")
    g = dataset[0]
    # The graph itself lives on GPU only in pure-GPU mode; in mixed mode it
    # stays on CPU while computation runs on the GPU.
    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
    device = torch.device("cpu" if args.mode == "cpu" else "cuda")
    # Make the citation graph bidirectional and remember each edge's
    # reverse so the sampler can exclude it during training.
    g, reverse_eids = to_bidirected_with_reverse_mapping(g)
    reverse_eids = reverse_eids.to(device)
    # Train on every edge of the bidirected graph.
    seed_edges = torch.arange(g.num_edges()).to(device)
    edge_split = dataset.get_edge_split()

    # create GraphSAGE model
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256).to(device)

    # model training
    print("Training...")
    train(args, device, g, reverse_eids, seed_edges, model)

    # validate/test the model
    print("Validation/Testing...")
    valid_mrr, test_mrr = evaluate(
        device, g, edge_split, model, batch_size=1000
    )
    print(
        "Validation MRR {:.4f}, Test MRR {:.4f}".format(
            valid_mrr.item(), test_mrr.item()
        )
    )