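"""link_pred.py

Stochastic GraphSAGE training for link prediction on ogbl-citation2:
minibatch neighbor sampling with uniform negative sampling, followed by
layer-wise offline inference and MRR evaluation on the official
validation/test splits.

Usage:
    python link_pred.py             # graph stays on CPU, GPU samples via UVA
    python link_pred.py --pure-gpu  # graph and sampling both on GPU
"""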
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
import time
import numpy as np
import tqdm
# OGB must be imported after DGL when both DGL and PyG are installed;
# otherwise the DataLoader will hang.  (This is a long-standing issue.)
from ogb.linkproppred import DglLinkPropPredDataset

parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
                    help='Perform both sampling and training on GPU.')
args = parser.parse_args()

device = 'cuda'

def to_bidirected_with_reverse_mapping(g):
    """Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``
    is the reverse edge of edge ID ``i``.
    Does not work with graphs that have self-loops.
    """
    g_simple, mapping = dgl.to_simple(
        dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
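    # ``mapping`` sends each edge of the doubled graph to its edge ID in the
    # deduplicated graph.  Invert it below: group the original edge IDs by
    # deduplicated edge (argsort plus the ``count`` offsets), pick one original
    # edge per group, flip its direction (offset by ``num_edges``), and map the
    # flipped edge back to a deduplicated edge ID.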
    c = g_simple.edata['count']
    num_edges = g.num_edges()
    mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype)
    mapping_offset[1:] = c.cumsum(0)
    idx = mapping.argsort()
    idx_uniq = idx[mapping_offset[:-1]]
    reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
    reverse_mapping = mapping[reverse_idx]

    # Correctness check
    src1, dst1 = g_simple.edges()
    src2, dst2 = g_simple.find_edges(reverse_mapping)
    assert torch.equal(src1, dst2)
    assert torch.equal(src2, dst1)
    return g_simple, reverse_mapping
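
# A quick illustration (hypothetical toy graph, not used below): for a graph
# with edges 0->1 and 1->2,
#
#     g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
#     bg, rev = to_bidirected_with_reverse_mapping(g)
#     src, dst = bg.edges()
#     assert torch.equal(src[rev], dst) and torch.equal(dst[rev], src)
#
# i.e. edge ``rev[i]`` always runs in the opposite direction of edge ``i``.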

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden):
        super().__init__()
        self.n_hidden = n_hidden
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.predictor = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1))

    def predict(self, h_src, h_dst):
        # Score an edge as an MLP over the elementwise product of its two
        # endpoint embeddings.  Broadcasting lets this also score one source
        # against many candidate destinations during MRR evaluation.
        return self.predictor(h_src * h_dst)

    def forward(self, pair_graph, neg_pair_graph, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
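        # Score the sampled positive and negative edges with the final-layer
        # representations.  The compacted pair graphs index into ``h`` with
        # the same local node IDs as the blocks' output nodes.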
        pos_src, pos_dst = pair_graph.edges()
        neg_src, neg_dst = neg_pair_graph.edges()
        h_pos = self.predict(h[pos_src], h[pos_dst])
        h_neg = self.predict(h[neg_src], h[neg_dst])
        return h_pos, h_neg

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
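        # Layer-wise inference: instead of sampling a multi-hop neighborhood
        # per node, compute one layer at a time over all nodes with their full
        # neighborhoods, feeding each layer's output to the next layer.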
        feat = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=num_workers)
        if buffer_device is None:
            buffer_device = device

        for l, layer in enumerate(self.layers):
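            # Output buffer for this layer, covering every node in the graph.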
            y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device,
                            pin_memory=args.pure_gpu)
            feat = feat.to(device)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                y[output_nodes] = h.to(buffer_device)
            feat = y
        return y


def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
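    """Mean reciprocal rank: for every source node, rank the true destination
    (column 0 of ``all_dst``) against its sampled negative destinations."""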
    rr = torch.zeros(src.shape[0])
    for start in tqdm.trange(0, src.shape[0], batch_size):
        end = min(start + batch_size, src.shape[0])
        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
        h_src = node_emb[src[start:end]][:, None, :].to(device)
        h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
        pred = model.predict(h_src, h_dst).squeeze(-1)
        # Compute the reciprocal rank directly: the true destination sits in
        # column 0, and its pessimistic rank is the number of candidates
        # scoring at least as high (the comparison counts the positive itself,
        # so no extra +1 is needed).
        rank = (pred >= pred[:, :1]).sum(dim=1)
        rr[start:end] = (1.0 / rank).cpu()
    return rr.mean()


def evaluate(model, edge_split, device, num_workers):
    with torch.no_grad():
        node_emb = model.inference(graph, device, 4096, num_workers, 'cpu')
        results = []
        for split in ['valid', 'test']:
            # Keep the index tensors on the embeddings' device (CPU here);
            # compute_mrr moves each gathered minibatch to the GPU itself.
            src = edge_split[split]['source_node'].to(node_emb.device)
            dst = edge_split[split]['target_node'].to(node_emb.device)
            neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
            results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device))
    return results


dataset = DglLinkPropPredDataset('ogbl-citation2')
graph = dataset[0]
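# ogbl-citation2 is directed; make it bidirected for message passing and keep
# the reverse-edge mapping so training can exclude the reverse of each seed
# edge from its sampled subgraph.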
graph, reverse_eids = to_bidirected_with_reverse_mapping(graph)
graph = graph.to('cuda' if args.pure_gpu else 'cpu')
reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(graph.num_edges()).to(device)
edge_split = dataset.get_edge_split()

model = SAGE(graph.ndata['feat'].shape[1], 256).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

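# Sample 15, 10 and 5 neighbors for the first, second and third SAGE layer
# respectively, prefetching node features along with the sampled blocks.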
sampler = dgl.dataloading.NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
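# Wrap the node-wise sampler for edge prediction: every minibatch of seed
# edges gets one uniformly sampled negative destination per edge, and each
# seed edge's reverse is excluded so message passing cannot leak the edge
# being predicted.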
sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, exclude='reverse_id', reverse_eids=reverse_eids,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1))
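# Without --pure-gpu, the graph stays in CPU memory and the GPU samples from
# it directly through unified virtual addressing (UVA).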
dataloader = dgl.dataloading.DataLoader(
        graph, seed_edges, sampler,
        device=device, batch_size=512, shuffle=True,
        drop_last=False, num_workers=0, use_uva=not args.pure_gpu)

durations = []
for epoch in range(10):
    model.train()
    t0 = time.time()
    for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
        x = blocks[0].srcdata['feat']
        pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
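        # Binary cross-entropy with logits: observed edges are labeled 1,
        # sampled negative edges 0.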
        pos_label = torch.ones_like(pos_score)
        neg_label = torch.zeros_like(neg_score)
        score = torch.cat([pos_score, neg_score])
        labels = torch.cat([pos_label, neg_label])
        loss = F.binary_cross_entropy_with_logits(score, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (it + 1) % 20 == 0:
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'GPU Mem', mem, 'MB')
            if (it + 1) == 1000:
                tt = time.time()
                print(tt - t0)
                durations.append(tt - t0)
                break
    if epoch % 10 == 0:
        model.eval()
        valid_mrr, test_mrr = evaluate(model, edge_split, device, 0 if args.pure_gpu else 12)
        print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item())
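# Report mean/std of per-epoch times, treating the first four epochs as warm-up.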
print(np.mean(durations[4:]), np.std(durations[4:]))