import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

import dgl
import dgl.nn as dglnn

# OGB must be imported after DGL if both DGL and PyG are installed.
# Otherwise the DataLoader will hang. (This is a long-standing issue.)
from ogb.linkproppred import DglLinkPropPredDataset

device = 'cuda'

def to_bidirected_with_reverse_mapping(g):
    """Convert a graph to a bidirected graph, returning it together with a
    mapping array ``mapping`` where ``mapping[i]`` is the edge ID of the
    reverse edge of edge ``i``.

    Does not work with graphs that have self-loops.
    """
    g_simple, mapping = dgl.to_simple(
        dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
    c = g_simple.edata['count']
    num_edges = g.num_edges()
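    # For every edge in the simple graph, pick one representative original
    # edge (via argsort of the writeback mapping); flipping that original
    # edge ID across the original/reverse halves and mapping it back yields
    # the reverse edge in the simple graph.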
    mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype)
    mapping_offset[1:] = c.cumsum(0)
    idx = mapping.argsort()
    idx_uniq = idx[mapping_offset[:-1]]
    reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
    reverse_mapping = mapping[reverse_idx]

    # Correctness check
    src1, dst1 = g_simple.edges()
    src2, dst2 = g_simple.find_edges(reverse_mapping)
    assert torch.equal(src1, dst2)
    assert torch.equal(src2, dst1)
    return g_simple, reverse_mapping
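
# A minimal sketch of what the mapping looks like on a toy 2-node graph:
#
#   g = dgl.graph(([0], [1]))                    # one edge: 0 -> 1
#   bg, rev = to_bidirected_with_reverse_mapping(g)
#   # bg has edges [0 -> 1, 1 -> 0] and rev == tensor([1, 0]), i.e. edge 0's
#   # reverse is edge 1 and vice versa.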

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden):
        super().__init__()
        self.n_hidden = n_hidden
        # Three-layer GraphSAGE encoder with mean aggregation.
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        # MLP that scores a (src, dst) pair from the elementwise product of
        # the two node embeddings.
        self.predictor = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1))

    def predict(self, h_src, h_dst):
        return self.predictor(h_src * h_dst)

    def forward(self, pair_graph, neg_pair_graph, blocks, x):
        h = x
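        # Each block is a message-flow graph (MFG) produced by the neighbor
        # sampler; layer l consumes block l, and only the last layer skips
        # the ReLU.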
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
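        # Score the sampled positive and negative node pairs with the shared
        # predictor MLP.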
        pos_src, pos_dst = pair_graph.edges()
        neg_src, neg_dst = neg_pair_graph.edges()
        h_pos = self.predict(h[pos_src], h[pos_dst])
        h_neg = self.predict(h[neg_src], h[neg_dst])
        return h_pos, h_neg

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the
        # official example is that the intermediate results can also benefit
        # from prefetching.
        g.ndata['h'] = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
        dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=num_workers)
        if buffer_device is None:
            buffer_device = device

        for l, layer in enumerate(self.layers):
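            # 'y' buffers this layer's output for every node (on
            # 'buffer_device', CPU by default) so the full-graph embeddings
            # need not fit in GPU memory.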
            y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata['h']
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                # 'output_nodes' lives on the sampling device; move it onto
                # the same device as 'y' before indexing.
                y[output_nodes.to(buffer_device)] = h.to(buffer_device)
            g.ndata['h'] = y
        return y


def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
    """Compute the Mean Reciprocal Rank (MRR) of each positive target against
    the negative targets that ogbl-citation2 provides (1000 per source node).
    """
    rr = torch.zeros(src.shape[0])
    for start in tqdm.trange(0, src.shape[0], batch_size):
        end = min(start + batch_size, src.shape[0])
        # Column 0 holds the positive destination; the rest are negatives.
        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
        h_src = node_emb[src[start:end]][:, None, :].to(device)
        h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
        pred = model.predict(h_src, h_dst).squeeze(-1)
        # Reciprocal rank of the positive candidate within each row, counting
        # ties pessimistically. (torchmetrics' functional
        # retrieval_reciprocal_rank flattens its inputs and would score the
        # whole batch as a single query, so the rank is computed by hand.)
        rank = (pred >= pred[:, :1]).sum(dim=1)
        rr[start:end] = (1.0 / rank.float()).cpu()
    return rr.mean()


def evaluate(model, edge_split, device, num_workers):
    with torch.no_grad():
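        # Compute embeddings for all nodes once (buffered on CPU), then rank
        # the held-out edges against them. Note: 'graph' is the module-level
        # training graph.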
        node_emb = model.inference(graph, device, 4096, num_workers, 'cpu')
        results = []
        for split in ['valid', 'test']:
            src = edge_split[split]['source_node'].to(device)
            dst = edge_split[split]['target_node'].to(device)
            neg_dst = edge_split[split]['target_node_neg'].to(device)
            results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device))
    return results


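# ogbl-citation2 is a paper citation graph; the task is to predict a paper's
# missing citations, evaluated by MRR against sampled negative candidates.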
dataset = DglLinkPropPredDataset('ogbl-citation2')
graph = dataset[0]
graph, reverse_eids = to_bidirected_with_reverse_mapping(graph)
reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(graph.num_edges()).to(device)
edge_split = dataset.get_edge_split()

model = SAGE(graph.ndata['feat'].shape[1], 256).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

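# Per-layer neighbor fanouts for the three SAGE layers; input features are
# prefetched onto the sampled blocks.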
sampler = dgl.dataloading.NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
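# Wrap the node-wise sampler for edge prediction: exclude each seed edge's
# reverse counterpart from the sampled subgraph to avoid leakage, and draw
# one uniformly-sampled negative destination per positive edge.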
sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, exclude='reverse_id', reverse_eids=reverse_eids,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1))
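# use_uva=True keeps the graph in pinned host memory and lets the GPU sample
# from it directly via unified virtual addressing.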
dataloader = dgl.dataloading.DataLoader(
        graph, seed_edges, sampler,
        device=device, batch_size=512, shuffle=True,
        drop_last=False, num_workers=0, use_uva=True)

durations = []
for epoch in range(10):
    model.train()
    t0 = time.time()
    for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
        x = blocks[0].srcdata['feat']
        pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
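        # Binary cross-entropy with logits: positive pairs are labeled 1,
        # sampled negatives 0.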
        pos_label = torch.ones_like(pos_score)
        neg_label = torch.zeros_like(neg_score)
        score = torch.cat([pos_score, neg_score])
        labels = torch.cat([pos_label, neg_label])
        loss = F.binary_cross_entropy_with_logits(score, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (it + 1) % 20 == 0:
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'GPU Mem', mem, 'MB')
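            # For benchmarking, cap each epoch at 1000 iterations and record
            # its duration.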
            if (it + 1) == 1000:
                tt = time.time()
                print(tt - t0)
                durations.append(tt - t0)
                break
    if epoch % 10 == 0:
        model.eval()
        valid_mrr, test_mrr = evaluate(model, edge_split, device, 12)
        print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item())
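# Report the mean and standard deviation of per-epoch time, skipping the
# first four epochs as warm-up.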
print(np.mean(durations[4:]), np.std(durations[4:]))