link_pred.py 7.9 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
3
4
import argparse

import dgl
import dgl.nn as dglnn
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F
8
import tqdm
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
9
10
11
12
13
14
15
from dgl.dataloading import (
    as_edge_prediction_sampler,
    DataLoader,
    MultiLayerFullNeighborSampler,
    negative_sampler,
    NeighborSampler,
)
16
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
17

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
18

19
20
def to_bidirected_with_reverse_mapping(g):
    """Make a graph bidirectional and compute a reverse-edge lookup table.

    Reverse edges are added to ``g`` and the result is collapsed with
    ``dgl.to_simple`` so that parallel edges are merged. Does not work with
    graphs that have self-loops.

    Parameters
    ----------
    g : DGLGraph
        The input graph.

    Returns
    -------
    g_simple : DGLGraph
        The simple graph built from ``g`` plus its reverse edges. The number of
        merged duplicates per edge is stored in ``g_simple.edata["count"]``.
    reverse_mapping : torch.Tensor
        ``reverse_mapping[i]`` is the ID (in ``g_simple``) of the reverse edge
        of edge ID ``i``.

    Raises
    ------
    AssertionError
        If the computed mapping fails the final consistency check, i.e. the
        endpoints of ``reverse_mapping[i]`` are not the swapped endpoints of
        edge ``i``.
    """
    g_simple, mapping = dgl.to_simple(
        dgl.add_reverse_edges(g), return_counts="count", writeback_mapping=True
    )
    c = g_simple.edata["count"]
    num_edges = g.num_edges()
    # Prefix sums of the per-edge duplicate counts: entry ``k`` is where the
    # group of pre-merge edges collapsed into simple edge ``k`` starts in the
    # argsort of ``mapping``. Allocated on the same device as the counts so a
    # GPU-resident graph needs no host round trip and no cross-device indexing.
    mapping_offset = torch.zeros(
        g_simple.num_edges() + 1, dtype=g_simple.idtype, device=c.device
    )
    mapping_offset[1:] = c.cumsum(0)
    idx = mapping.argsort()
    # One representative pre-merge edge ID for every edge of the simple graph.
    idx_uniq = idx[mapping_offset[:-1]]
    # The arithmetic below treats pre-merge edge ``i + num_edges`` as the
    # reverse of edge ``i`` (the layout expected from dgl.add_reverse_edges);
    # the check at the end fails if that does not hold.
    reverse_idx = torch.where(
        idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges
    )
    reverse_mapping = mapping[reverse_idx]
    # sanity check
    src1, dst1 = g_simple.edges()
    src2, dst2 = g_simple.find_edges(reverse_mapping)
    assert torch.equal(src1, dst2)
    assert torch.equal(src2, dst1)
    return g_simple, reverse_mapping
44

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
45

46
class SAGE(nn.Module):
    """Three-layer GraphSAGE (mean aggregator) with an MLP pair scorer.

    ``layers`` holds the graph convolutions; ``predictor`` maps the
    element-wise product of two node representations to a single score.
    """

    def __init__(self, in_size, hid_size):
        """Build the convolution stack and the scoring MLP.

        Parameters
        ----------
        in_size : int
            Size of the input node features.
        hid_size : int
            Size of every hidden (and output) node representation.
        """
        super().__init__()
        self.hid_size = hid_size
        # three-layer GraphSAGE-mean
        layer_dims = [
            (in_size, hid_size),
            (hid_size, hid_size),
            (hid_size, hid_size),
        ]
        self.layers = nn.ModuleList(
            [dglnn.SAGEConv(d_in, d_out, "mean") for d_in, d_out in layer_dims]
        )
        self.predictor = nn.Sequential(
            nn.Linear(hid_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, 1),
        )

    def forward(self, pair_graph, neg_pair_graph, blocks, x):
        """Score the edges of a positive and a negative pair graph.

        ``x`` is pushed through the convolution layers, one block per layer,
        with ReLU after every layer except the last. Each edge of the two pair
        graphs is then scored by feeding the product of its endpoint
        representations to ``self.predictor``.

        Returns
        -------
        tuple
            ``(h_pos, h_neg)``: predictor outputs for the edges of
            ``pair_graph`` and ``neg_pair_graph`` respectively.
        """
        final_depth = len(self.layers) - 1
        h = x
        for depth, (conv, block) in enumerate(zip(self.layers, blocks)):
            h = conv(block, h)
            if depth < final_depth:
                h = F.relu(h)

        def score_edges(graph):
            # Endpoint IDs index directly into the computed representations.
            src, dst = graph.edges()
            return self.predictor(h[src] * h[dst])

        h_pos = score_edges(pair_graph)
        h_neg = score_edges(neg_pair_graph)
        return h_pos, h_neg

    def inference(self, g, device, batch_size):
        """Layer-wise inference algorithm to compute GNN node embeddings.

        Every layer is applied to all nodes of ``g`` (in mini-batches of
        ``batch_size`` using full one-hop neighborhoods) before the next layer
        starts; the output of one layer becomes the input features of the next.

        Returns
        -------
        torch.Tensor
            A ``(g.num_nodes(), hid_size)`` tensor on the CPU holding the
            output of the last layer.
        """
        feat = g.ndata["feat"]
        all_nodes = torch.arange(g.num_nodes()).to(g.device)
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
        dataloader = DataLoader(
            g,
            all_nodes,
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        # Layer outputs are collected on the CPU; the buffer is pinned whenever
        # the compute device is not the CPU itself.
        buffer_device = torch.device("cpu")
        pin_memory = buffer_device != device
        final_depth = len(self.layers) - 1

        for depth, conv in enumerate(self.layers):
            y = torch.empty(
                g.num_nodes(),
                self.hid_size,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feat = feat.to(device)
            batches = tqdm.tqdm(dataloader, desc="Inference")
            for input_nodes, output_nodes, blocks in batches:
                h = conv(blocks[0], feat[input_nodes])
                if depth != final_depth:
                    h = F.relu(h)
                y[output_nodes] = h.to(buffer_device)
            # This layer's output feeds the next layer.
            feat = y
        return y

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
110
111
112
113

def compute_mrr(
    model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500
):
    """Compute Mean Reciprocal Rank (MRR) in batches.

    For each source node the true destination and its negative destinations
    are scored with ``model.predictor`` on the element-wise product of the
    source and candidate embeddings; ``evaluator.eval`` turns the scores into
    per-query reciprocal ranks, whose mean is returned.
    """
    num_queries = src.shape[0]
    rr = torch.zeros(num_queries)
    for start in tqdm.trange(0, num_queries, batch_size, desc="Evaluate"):
        stop = min(start + batch_size, num_queries)
        batch = slice(start, stop)
        # Column 0 holds the true destination, the rest are the negatives.
        candidates = torch.cat([dst[batch, None], neg_dst[batch]], 1)
        h_src = node_emb[src[batch]][:, None, :].to(device)
        flat_emb = node_emb[candidates.view(-1)]
        h_dst = flat_emb.view(*candidates.shape, -1).to(device)
        pred = model.predictor(h_src * h_dst).squeeze(-1)
        scores = {"y_pred_pos": pred[:, 0], "y_pred_neg": pred[:, 1:]}
        rr[batch] = evaluator.eval(scores)["mrr_list"]
    return rr.mean()

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
126

127
128
def evaluate(device, graph, edge_split, model, batch_size):
    """Return ``[valid_mrr, test_mrr]`` for ``model`` on ``graph``.

    The model is switched to eval mode, node embeddings are computed once with
    ``model.inference`` (using ``batch_size``), and the MRR of the "valid" and
    "test" entries of ``edge_split`` is computed from those embeddings, all
    with gradient tracking disabled.
    """
    model.eval()
    evaluator = Evaluator(name="ogbl-citation2")
    with torch.no_grad():
        node_emb = model.inference(graph, device, batch_size)
        emb_device = node_emb.device

        def split_mrr(split):
            # Move this split's node IDs to where the embeddings live so they
            # can be used to index them.
            edges = edge_split[split]
            src = edges["source_node"].to(emb_device)
            dst = edges["target_node"].to(emb_device)
            neg_dst = edges["target_node_neg"].to(emb_device)
            return compute_mrr(
                model, evaluator, node_emb, src, dst, neg_dst, device
            )

        results = [split_mrr(split) for split in ("valid", "test")]
    return results

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
144

145
146
def train(args, device, g, reverse_eids, seed_edges, model):
    """Train ``model`` for link prediction with sampled mini-batches.

    Parameters
    ----------
    args : argparse.Namespace
        Only ``args.mode`` is read; UVA sampling is enabled when it equals
        ``"mixed"``.
    device : torch.device
        Device the sampled mini-batches are placed on.
    g : DGLGraph
        Training graph; node features are read from ``"feat"``.
    reverse_eids : torch.Tensor
        For each edge ID, the ID of its reverse edge. Passed to the sampler so
        reverse edges of the seed edges are excluded (``exclude="reverse_id"``).
    seed_edges : torch.Tensor
        Edge IDs to iterate over.
    model : torch.nn.Module
        Called as ``model(pair_graph, neg_pair_graph, blocks, x)`` and expected
        to return positive and negative scores (logits).

    Runs 10 epochs of at most 1000 mini-batches each and prints the mean loss
    of every epoch. An epoch that yields no mini-batch reports NaN.
    """
    # create sampler & dataloader
    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=["feat"])
    sampler = as_edge_prediction_sampler(
        sampler,
        exclude="reverse_id",
        reverse_eids=reverse_eids,
        negative_sampler=negative_sampler.Uniform(1),
    )
    use_uva = args.mode == "mixed"
    dataloader = DataLoader(
        g,
        seed_edges,
        sampler,
        device=device,
        batch_size=512,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.0005)
    for epoch in range(10):
        model.train()
        total_loss = 0
        # Counted explicitly rather than read from the loop variable after the
        # loop, which is undefined when the dataloader yields nothing.
        num_batches = 0
        for _input_nodes, pair_graph, neg_pair_graph, blocks in dataloader:
            x = blocks[0].srcdata["feat"]
            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
            score = torch.cat([pos_score, neg_score])
            # Positive pairs are labelled 1, sampled negatives 0.
            pos_label = torch.ones_like(pos_score)
            neg_label = torch.zeros_like(neg_score)
            labels = torch.cat([pos_label, neg_label])
            loss = F.binary_cross_entropy_with_logits(score, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_batches += 1
            # Cap the epoch at 1000 mini-batches.
            if num_batches == 1000:
                break
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print("Epoch {:05d} | Loss {:.4f}".format(epoch, mean_loss))

188

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
189
if __name__ == "__main__":
    # command-line options
    mode_help = (
        "Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training."
    )
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["cpu", "mixed", "puregpu"],
        help=mode_help,
    )
    args = parser.parse_args()
    # Without CUDA every mode falls back to CPU training.
    if not torch.cuda.is_available():
        args.mode = "cpu"
    print(f"Training in {args.mode} mode.")

    # load and preprocess dataset
    print("Loading data")
    dataset = DglLinkPropPredDataset("ogbl-citation2")
    # The graph lives on the GPU only in pure-GPU mode; computation runs on
    # the GPU in every mode except "cpu".
    graph_device = "cuda" if args.mode == "puregpu" else "cpu"
    compute_device = "cpu" if args.mode == "cpu" else "cuda"
    g = dataset[0].to(graph_device)
    device = torch.device(compute_device)
    g, reverse_eids = to_bidirected_with_reverse_mapping(g)
    reverse_eids = reverse_eids.to(device)
    seed_edges = torch.arange(g.num_edges()).to(device)
    edge_split = dataset.get_edge_split()

    # create GraphSAGE model
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256).to(device)

    # model training
    print("Training...")
    train(args, device, g, reverse_eids, seed_edges, model)

    # validate/test the model
    print("Validation/Testing...")
    valid_mrr, test_mrr = evaluate(
        device, g, edge_split, model, batch_size=1000
    )
    summary = "Validation MRR {:.4f}, Test MRR {:.4f}".format(
        valid_mrr.item(), test_mrr.item()
    )
    print(summary)