# Link prediction on ogbl-citation2 with GraphSAGE (DGL example).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler, as_edge_prediction_sampler, negative_sampler
import tqdm
import argparse
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator

def to_bidirected_with_reverse_mapping(g):
    """Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``
14
    is the reverse edge of edge ID ``i``. Does not work with graphs that have self-loops.
15
16
17
18
19
20
21
22
23
24
25
    """
    g_simple, mapping = dgl.to_simple(
        dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
    c = g_simple.edata['count']
    num_edges = g.num_edges()
    mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype)
    mapping_offset[1:] = c.cumsum(0)
    idx = mapping.argsort()
    idx_uniq = idx[mapping_offset[:-1]]
    reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
    reverse_mapping = mapping[reverse_idx]
26
    # sanity check
27
28
29
30
31
    src1, dst1 = g_simple.edges()
    src2, dst2 = g_simple.find_edges(reverse_mapping)
    assert torch.equal(src1, dst2)
    assert torch.equal(src2, dst1)
    return g_simple, reverse_mapping

class SAGE(nn.Module):
34
    def __init__(self, in_size, hid_size):
35
36
        super().__init__()
        self.layers = nn.ModuleList()
37
38
39
40
41
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.hid_size = hid_size
42
        self.predictor = nn.Sequential(
43
            nn.Linear(hid_size, hid_size),
44
            nn.ReLU(),
45
            nn.Linear(hid_size, hid_size),
46
            nn.ReLU(),
47
            nn.Linear(hid_size, 1))
48
49
50
51
52
53
54

    def forward(self, pair_graph, neg_pair_graph, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
55
56
        pos_src, pos_dst = pair_graph.edges()
        neg_src, neg_dst = neg_pair_graph.edges()
57
58
        h_pos = self.predictor(h[pos_src] * h[pos_dst])
        h_neg = self.predictor(h[neg_src] * h[neg_dst])
59
60
        return h_pos, h_neg

61
62
    def inference(self, g, device, batch_size):
        """Layer-wise inference algorithm to compute GNN node embeddings."""
63
        feat = g.ndata['feat']
64
65
66
67
68
69
70
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = DataLoader(
            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
            batch_size=batch_size, shuffle=False, drop_last=False,
            num_workers=0)
        buffer_device = torch.device('cpu')
        pin_memory = (buffer_device != device)
71
        for l, layer in enumerate(self.layers):
72
73
            y = torch.empty(g.num_nodes(), self.hid_size, device=buffer_device,
                            pin_memory=pin_memory)
74
            feat = feat.to(device)
75
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc='Inference'):
76
                x = feat[input_nodes]
77
78
79
80
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                y[output_nodes] = h.to(buffer_device)
81
            feat = y
82
83
        return y

def compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500):
    """Compute Mean Reciprocal Rank (MRR) in batches."""
86
    rr = torch.zeros(src.shape[0])
87
    for start in tqdm.trange(0, src.shape[0], batch_size, desc='Evaluate'):
88
89
90
91
        end = min(start + batch_size, src.shape[0])
        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
        h_src = node_emb[src[start:end]][:, None, :].to(device)
        h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
92
93
94
        pred = model.predictor(h_src*h_dst).squeeze(-1)
        input_dict = {'y_pred_pos': pred[:,0], 'y_pred_neg': pred[:,1:]}
        rr[start:end] = evaluator.eval(input_dict)['mrr_list']
95
96
    return rr.mean()

def evaluate(device, graph, edge_split, model, batch_size):
    model.eval()
    evaluator = Evaluator(name='ogbl-citation2')
100
    with torch.no_grad():
101
        node_emb = model.inference(graph, device, batch_size)
102
103
        results = []
        for split in ['valid', 'test']:
104
105
106
            src = edge_split[split]['source_node'].to(node_emb.device)
            dst = edge_split[split]['target_node'].to(node_emb.device)
            neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
107
            results.append(compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device))
108
109
    return results

def train(args, device, g, reverse_eids, seed_edges, model):
    # create sampler & dataloader
    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
    sampler = as_edge_prediction_sampler(
114
        sampler, exclude='reverse_id', reverse_eids=reverse_eids,
115
116
117
118
        negative_sampler=negative_sampler.Uniform(1))
    use_uva = (args.mode == 'mixed')
    dataloader = DataLoader(
        g, seed_edges, sampler,
119
        device=device, batch_size=512, shuffle=True,
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
        drop_last=False, num_workers=0, use_uva=use_uva)
    opt = torch.optim.Adam(model.parameters(), lr=0.0005)
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
            x = blocks[0].srcdata['feat']
            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
            score = torch.cat([pos_score, neg_score])
            pos_label = torch.ones_like(pos_score)
            neg_label = torch.zeros_like(neg_score)
            labels = torch.cat([pos_label, neg_label])
            loss = F.binary_cross_entropy_with_logits(score, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            if (it+1) == 1000: break
        print("Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it+1)))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
                             "'puregpu' for pure-GPU training.")
    args = parser.parse_args()
    if not torch.cuda.is_available():
        args.mode = 'cpu'
    print(f'Training in {args.mode} mode.')

    # load and preprocess dataset
    print('Loading data')
    dataset = DglLinkPropPredDataset('ogbl-citation2')
    g = dataset[0]
    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
    g, reverse_eids = to_bidirected_with_reverse_mapping(g)
    reverse_eids = reverse_eids.to(device)
    seed_edges = torch.arange(g.num_edges()).to(device)
    edge_split = dataset.get_edge_split()

    # create GraphSAGE model
    in_size = g.ndata['feat'].shape[1]
    model = SAGE(in_size, 256).to(device)

    # model training
    print('Training...')
    train(args, device, g, reverse_eids, seed_edges, model)

    # validate/test the model
    print('Validation/Testing...')
    valid_mrr, test_mrr = evaluate(device, g, edge_split, model, batch_size=1000)
    print('Validation MRR {:.4f}, Test MRR {:.4f}'.format(valid_mrr.item(),test_mrr.item()))