import argparse
import time

import dgl
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

# OGB must follow DGL if both DGL and PyG are installed. Otherwise DataLoader will hang.
# (This is a long-standing issue)
from ogb.linkproppred import DglLinkPropPredDataset

parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
                    help='Perform both sampling and training on GPU.')
args = parser.parse_args()

# Training always runs on the GPU; --pure-gpu additionally keeps the graph in
# GPU memory, otherwise the graph stays on the CPU and is sampled through UVA.
device = 'cuda'

def to_bidirected_with_reverse_mapping(g):
    """Make a graph bidirected and return it together with an array
    ``reverse_mapping`` in which ``reverse_mapping[i]`` is the edge ID of the
    reverse of edge ``i``.

    Does not work with graphs that have self-loops.
    """
    g_simple, mapping = dgl.to_simple(
        dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
    c = g_simple.edata['count']
    num_edges = g.num_edges()
    # ``mapping`` sends every edge of the reverse-augmented graph to its edge
    # ID in the simple graph.  Pick one representative augmented edge per
    # simple edge; its counterpart in the other direction sits ``num_edges``
    # away, since add_reverse_edges appends reverse edges after the originals.
    mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype)
    mapping_offset[1:] = c.cumsum(0)
    idx = mapping.argsort()
    idx_uniq = idx[mapping_offset[:-1]]
    reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
    reverse_mapping = mapping[reverse_idx]

    # Correctness check
    src1, dst1 = g_simple.edges()
    src2, dst2 = g_simple.find_edges(reverse_mapping)
    assert torch.equal(src1, dst2)
    assert torch.equal(src2, dst1)
    return g_simple, reverse_mapping
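
# A minimal sanity sketch (not part of the original script): on a toy
# two-edge path graph, every edge of the bidirected result must be the
# endpoint-swapped twin of the edge its mapping entry points to.
_toy_g, _toy_rev = to_bidirected_with_reverse_mapping(
    dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2]))))
_toy_src, _toy_dst = _toy_g.edges()
assert torch.equal(_toy_src[_toy_rev], _toy_dst)
assert torch.equal(_toy_dst[_toy_rev], _toy_src)
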

class SAGE(nn.Module):
    """Three-layer GraphSAGE encoder with an MLP link predictor on top."""

    def __init__(self, in_feats, n_hidden):
        super().__init__()
        self.n_hidden = n_hidden
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.predictor = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1))

    def predict(self, h_src, h_dst):
        # Score a (source, destination) pair from the element-wise product of
        # the two node embeddings.  Broadcasting lets one source be scored
        # against many candidate destinations at once (see compute_mrr below).
        return self.predictor(h_src * h_dst)

    def forward(self, pair_graph, neg_pair_graph, blocks, x):
        # Encode the sampled subgraph with the GraphSAGE layers.
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
        # Score the positive and the sampled negative node pairs.
        pos_src, pos_dst = pair_graph.edges()
        neg_src, neg_dst = neg_pair_graph.edges()
        h_pos = self.predict(h[pos_src], h[pos_dst])
        h_neg = self.predict(h[neg_src], h[neg_dst])
        return h_pos, h_neg

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        """Layer-wise offline inference over all nodes.

        Unlike the inference function in the official example, the
        intermediate representations can also benefit from prefetching.
        """
        feat = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=num_workers)
        if buffer_device is None:
            buffer_device = device

        for l, layer in enumerate(self.layers):
            # Output buffer for this layer, held on buffer_device; it is
            # pinned when sampling runs purely on the GPU so the
            # device-to-host copies below are faster.
            y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device,
                            pin_memory=args.pure_gpu)
            feat = feat.to(device)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                y[output_nodes] = h.to(buffer_device)
            # This layer's output becomes the next layer's input features.
            feat = y
        return y


def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
    """Mean reciprocal rank of each true destination node against its 1000
    sampled negative destinations."""
    rr = torch.zeros(src.shape[0])
    for start in tqdm.trange(0, src.shape[0], batch_size):
        end = min(start + batch_size, src.shape[0])
        # Column 0 holds the true destination; the rest are negatives.
        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
        h_src = node_emb[src[start:end]][:, None, :].to(device)
        h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
        pred = model.predict(h_src, h_dst).squeeze(-1)
        # torchmetrics' functional retrieval_reciprocal_rank flattens its
        # inputs into a single query, which would mix all rows of the batch
        # together.  Compute the reciprocal rank of the true destination per
        # row instead (ties are counted against the model).
        rank = (pred >= pred[:, :1]).sum(1)
        rr[start:end] = (1.0 / rank.to(torch.float)).cpu()
    return rr.mean()
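
# A minimal illustrative check (not part of the original script) of the
# per-row reciprocal rank above: with the true destination in column 0
# outscored by both negatives, its rank is 3 and the reciprocal rank is 1/3.
_pred = torch.tensor([[0.2, 0.9, 0.5]])
_rank = (_pred >= _pred[:, :1]).sum(1)
assert torch.allclose(1.0 / _rank.to(torch.float), torch.tensor([1.0 / 3.0]))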


def evaluate(model, graph, edge_split, device, num_workers):
    """Compute validation and test MRR from offline-inferred node embeddings."""
    with torch.no_grad():
        node_emb = model.inference(graph, device, 4096, num_workers, 'cpu')
        results = []
        for split in ['valid', 'test']:
            src = edge_split[split]['source_node'].to(node_emb.device)
            dst = edge_split[split]['target_node'].to(node_emb.device)
            neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
            results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device))
    return results


dataset = DglLinkPropPredDataset('ogbl-citation2')
graph = dataset[0]
graph, reverse_eids = to_bidirected_with_reverse_mapping(graph)
graph = graph.to('cuda' if args.pure_gpu else 'cpu')
reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(graph.num_edges()).to(device)
edge_split = dataset.get_edge_split()

model = SAGE(graph.ndata['feat'].shape[1], 256).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Sample 15, 10 and 5 neighbors for the three GNN layers, prefetching the
# input features of each sampled block.
sampler = dgl.dataloading.NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
# Exclude each seed edge and its reverse from the sampled subgraph so the
# target link cannot leak into message passing, and pair every positive edge
# with one uniformly sampled negative destination.
sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, exclude='reverse_id', reverse_eids=reverse_eids,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1))
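
# A minimal sketch (not part of the original script) of what Uniform(k) does:
# for each given edge it keeps the source node and draws k destination nodes
# uniformly at random to form negative pairs.
_neg_sampler = dgl.dataloading.negative_sampler.Uniform(1)
_toy = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
_neg_s, _neg_d = _neg_sampler(_toy, torch.tensor([0, 1]))
assert torch.equal(_neg_s, torch.tensor([0, 1]))  # sources are preserved
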
dataloader = dgl.dataloading.DataLoader(
        graph, seed_edges, sampler,
        device=device, batch_size=512, shuffle=True,
        drop_last=False, num_workers=0, use_uva=not args.pure_gpu)

durations = []
for epoch in range(10):
    model.train()
    t0 = time.time()
    for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
        x = blocks[0].srcdata['feat']
        pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
        # Binary cross-entropy with positive pairs labeled 1 and sampled
        # negative pairs labeled 0.
        pos_label = torch.ones_like(pos_score)
        neg_label = torch.zeros_like(neg_score)
        score = torch.cat([pos_score, neg_score])
        labels = torch.cat([pos_label, neg_label])
        loss = F.binary_cross_entropy_with_logits(score, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (it + 1) % 20 == 0:
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'GPU Mem', mem, 'MB')
            # Time only the first 1000 iterations of each epoch.
            if (it + 1) == 1000:
                tt = time.time()
                print('Epoch time:', tt - t0)
                durations.append(tt - t0)
                break
    if epoch % 10 == 0:
        model.eval()
        valid_mrr, test_mrr = evaluate(model, graph, edge_split, device,
                                       0 if args.pure_gpu else 12)
        print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item())
# Report the mean and standard deviation of the per-epoch time, skipping the
# first four epochs as warm-up.
print(np.mean(durations[4:]), np.std(durations[4:]))