"docs/source/vscode:/vscode.git/clone" did not exist on "7fe6d0c85732d57a95cd2260fce1a2e1fd93489c"
Unverified Commit 533afa85 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Refactor] Graphsage link prediction example refactor (#4526)

* Refactor link pred example for graphsage

* Use ogb evaluator + README update

* Update

* Add comments
parent db64cc37
......@@ -54,8 +54,13 @@ python3 lightning/node_classification.py
### Minibatch training for link prediction
Train w/ mini-batch sampling for link prediction on OGB-Citation2:
Train w/ mini-batch sampling for link prediction on OGB-citation2:
```bash
python3 link_pred.py
```
Results (10 epochs):
```
Test MRR: 0.7386
```
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler, as_edge_prediction_sampler, negative_sampler
import tqdm
# OGB must follow DGL if both DGL and PyG are installed. Otherwise DataLoader will hang.
# (This is a long-standing issue)
from ogb.linkproppred import DglLinkPropPredDataset
parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
help='Perform both sampling and training on GPU.')
args = parser.parse_args()
device = 'cuda'
import argparse
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
def to_bidirected_with_reverse_mapping(g):
"""Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``
is the reverse edge of edge ID ``i``.
Does not work with graphs that have self-loops.
is the reverse edge of edge ID ``i``. Does not work with graphs that have self-loops.
"""
g_simple, mapping = dgl.to_simple(
dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
......@@ -34,8 +23,7 @@ def to_bidirected_with_reverse_mapping(g):
idx_uniq = idx[mapping_offset[:-1]]
reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
reverse_mapping = mapping[reverse_idx]
# Correctness check
# sanity check
src1, dst1 = g_simple.edges()
src2, dst2 = g_simple.find_edges(reverse_mapping)
assert torch.equal(src1, dst2)
......@@ -43,22 +31,20 @@ def to_bidirected_with_reverse_mapping(g):
return g_simple, reverse_mapping
class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden):
def __init__(self, in_size, hid_size):
super().__init__()
self.n_hidden = n_hidden
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
# three-layer GraphSAGE-mean
self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
self.hid_size = hid_size
self.predictor = nn.Sequential(
nn.Linear(n_hidden, n_hidden),
nn.Linear(hid_size, hid_size),
nn.ReLU(),
nn.Linear(n_hidden, n_hidden),
nn.Linear(hid_size, hid_size),
nn.ReLU(),
nn.Linear(n_hidden, 1))
def predict(self, h_src, h_dst):
return self.predictor(h_src * h_dst)
nn.Linear(hid_size, 1))
def forward(self, pair_graph, neg_pair_graph, blocks, x):
h = x
......@@ -68,27 +54,25 @@ class SAGE(nn.Module):
h = F.relu(h)
pos_src, pos_dst = pair_graph.edges()
neg_src, neg_dst = neg_pair_graph.edges()
h_pos = self.predict(h[pos_src], h[pos_dst])
h_neg = self.predict(h[neg_src], h[neg_dst])
h_pos = self.predictor(h[pos_src] * h[pos_dst])
h_neg = self.predictor(h[neg_src] * h[neg_dst])
return h_pos, h_neg
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching.
def inference(self, g, device, batch_size):
"""Layer-wise inference algorithm to compute GNN node embeddings."""
feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.DataLoader(
sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers)
if buffer_device is None:
buffer_device = device
batch_size=batch_size, shuffle=False, drop_last=False,
num_workers=0)
buffer_device = torch.device('cpu')
pin_memory = (buffer_device != device)
for l, layer in enumerate(self.layers):
y = torch.zeros(g.num_nodes(), self.n_hidden, device=buffer_device,
pin_memory=args.pure_gpu)
y = torch.empty(g.num_nodes(), self.hid_size, device=buffer_device,
pin_memory=pin_memory)
feat = feat.to(device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc='Inference'):
x = feat[input_nodes]
h = layer(blocks[0], x)
if l != len(self.layers) - 1:
......@@ -97,78 +81,92 @@ class SAGE(nn.Module):
feat = y
return y
def compute_mrr(model, node_emb, src, dst, neg_dst, device, batch_size=500):
def compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500):
"""Compute Mean Reciprocal Rank (MRR) in batches."""
rr = torch.zeros(src.shape[0])
for start in tqdm.trange(0, src.shape[0], batch_size):
for start in tqdm.trange(0, src.shape[0], batch_size, desc='Evaluate'):
end = min(start + batch_size, src.shape[0])
all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
h_src = node_emb[src[start:end]][:, None, :].to(device)
h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
pred = model.predict(h_src, h_dst).squeeze(-1)
relevance = torch.zeros(*pred.shape, dtype=torch.bool).to(pred.device)
relevance[:, 0] = True
rr[start:end] = MF.retrieval_reciprocal_rank(pred, relevance)
pred = model.predictor(h_src*h_dst).squeeze(-1)
input_dict = {'y_pred_pos': pred[:,0], 'y_pred_neg': pred[:,1:]}
rr[start:end] = evaluator.eval(input_dict)['mrr_list']
return rr.mean()
def evaluate(model, edge_split, device, num_workers):
def evaluate(device, graph, edge_split, model, batch_size):
model.eval()
evaluator = Evaluator(name='ogbl-citation2')
with torch.no_grad():
node_emb = model.inference(graph, device, 4096, num_workers, 'cpu')
node_emb = model.inference(graph, device, batch_size)
results = []
for split in ['valid', 'test']:
src = edge_split[split]['source_node'].to(node_emb.device)
dst = edge_split[split]['target_node'].to(node_emb.device)
neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
results.append(compute_mrr(model, node_emb, src, dst, neg_dst, device))
results.append(compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device))
return results
dataset = DglLinkPropPredDataset('ogbl-citation2')
graph = dataset[0]
graph, reverse_eids = to_bidirected_with_reverse_mapping(graph)
graph = graph.to('cuda' if args.pure_gpu else 'cpu')
reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(graph.num_edges()).to(device)
edge_split = dataset.get_edge_split()
model = SAGE(graph.ndata['feat'].shape[1], 256).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
sampler = dgl.dataloading.NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
sampler = dgl.dataloading.as_edge_prediction_sampler(
def train(args, device, g, reverse_eids, seed_edges, model):
# create sampler & dataloader
sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
sampler = as_edge_prediction_sampler(
sampler, exclude='reverse_id', reverse_eids=reverse_eids,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(1))
dataloader = dgl.dataloading.DataLoader(
graph, seed_edges, sampler,
negative_sampler=negative_sampler.Uniform(1))
use_uva = (args.mode == 'mixed')
dataloader = DataLoader(
g, seed_edges, sampler,
device=device, batch_size=512, shuffle=True,
drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
durations = []
for epoch in range(10):
drop_last=False, num_workers=0, use_uva=use_uva)
opt = torch.optim.Adam(model.parameters(), lr=0.0005)
for epoch in range(10):
model.train()
t0 = time.time()
total_loss = 0
for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
x = blocks[0].srcdata['feat']
pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
score = torch.cat([pos_score, neg_score])
pos_label = torch.ones_like(pos_score)
neg_label = torch.zeros_like(neg_score)
score = torch.cat([pos_score, neg_score])
labels = torch.cat([pos_label, neg_label])
loss = F.binary_cross_entropy_with_logits(score, labels)
opt.zero_grad()
loss.backward()
opt.step()
if (it + 1) % 20 == 0:
mem = torch.cuda.max_memory_allocated() / 1000000
print('Loss', loss.item(), 'GPU Mem', mem, 'MB')
if (it + 1) == 1000:
tt = time.time()
print(tt - t0)
durations.append(tt - t0)
break
if epoch % 10 == 0:
model.eval()
valid_mrr, test_mrr = evaluate(model, edge_split, device, 0 if args.pure_gpu else 12)
print('Validation MRR:', valid_mrr.item(), 'Test MRR:', test_mrr.item())
print(np.mean(durations[4:]), np.std(durations[4:]))
total_loss += loss.item()
if (it+1) == 1000: break
print("Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it+1)))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
"'puregpu' for pure-GPU training.")
args = parser.parse_args()
if not torch.cuda.is_available():
args.mode = 'cpu'
print(f'Training in {args.mode} mode.')
# load and preprocess dataset
print('Loading data')
dataset = DglLinkPropPredDataset('ogbl-citation2')
g = dataset[0]
g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
g, reverse_eids = to_bidirected_with_reverse_mapping(g)
reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(g.num_edges()).to(device)
edge_split = dataset.get_edge_split()
# create GraphSAGE model
in_size = g.ndata['feat'].shape[1]
model = SAGE(in_size, 256).to(device)
# model training
print('Training...')
train(args, device, g, reverse_eids, seed_edges, model)
# validate/test the model
print('Validation/Testing...')
valid_mrr, test_mrr = evaluate(device, g, edge_split, model, batch_size=1000)
print('Validation MRR {:.4f}, Test MRR {:.4f}'.format(valid_mrr.item(),test_mrr.item()))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment