"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f088027e937b2ee1acef1f6b2776b7b2fee7ffd6"
Unverified commit 533afa85 authored by Chang Liu, committed by GitHub

[Example][Refactor] Graphsage link prediction example refactor (#4526)

* Refactor link pred example for graphsage

* Use ogb evaluator + README update

* Update

* Add comments
parent db64cc37
@@ -54,8 +54,13 @@ python3 lightning/node_classification.py
### Minibatch training for link prediction
Train w/ mini-batch sampling for link prediction on OGB-citation2:
```bash
python3 link_pred.py
```
Results (10 epochs):
```
Test MRR: 0.7386
```
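The refactored script also selects its device through a `--mode` flag (defined in the argument parser at the bottom of `link_pred.py`, shown below); for example, assuming a CUDA-capable GPU with enough memory to hold the whole graph:
```bash
python3 link_pred.py --mode puregpu
```
The full refactored script follows.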
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler, as_edge_prediction_sampler, negative_sampler
import tqdm
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
def to_bidirected_with_reverse_mapping(g):
    """Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``
    is the reverse edge of edge ID ``i``. Does not work with graphs that have self-loops.
    """
    g_simple, mapping = dgl.to_simple(
        dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
@@ -34,8 +23,7 @@ def to_bidirected_with_reverse_mapping(g):
    idx_uniq = idx[mapping_offset[:-1]]
    reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
    reverse_mapping = mapping[reverse_idx]
    # Correctness check: each edge's reverse must connect the same endpoints, swapped.
    src1, dst1 = g_simple.edges()
    src2, dst2 = g_simple.find_edges(reverse_mapping)
    assert torch.equal(src1, dst2)
@@ -43,22 +31,20 @@ def to_bidirected_with_reverse_mapping(g):
    return g_simple, reverse_mapping
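# Small illustration (hypothetical numbers): for an input graph with edges
# 0->1 and 1->2, the bidirected graph has edges [0->1, 1->2, 1->0, 2->1],
# and reverse_mapping would be tensor([2, 3, 0, 1]): edge 0 (0->1) and
# edge 2 (1->0) are each other's reverses, as are edges 1 and 3.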
class SAGE(nn.Module):
    def __init__(self, in_size, hid_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.hid_size = hid_size
        # MLP that scores a candidate edge from the elementwise product of its
        # endpoint embeddings
        self.predictor = nn.Sequential(
            nn.Linear(hid_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, 1))
    def forward(self, pair_graph, neg_pair_graph, blocks, x):
        h = x
@@ -68,27 +54,25 @@ class SAGE(nn.Module):
                h = F.relu(h)
        pos_src, pos_dst = pair_graph.edges()
        neg_src, neg_dst = neg_pair_graph.edges()
        h_pos = self.predictor(h[pos_src] * h[pos_dst])
        h_neg = self.predictor(h[neg_src] * h[neg_dst])
        return h_pos, h_neg
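# Note: pair_graph (observed edges) and neg_pair_graph (uniformly sampled
# non-edges) are built over the same node set as the sampled blocks, so the
# output embeddings h can be indexed directly by the endpoints of both; the
# predictor MLP turns each endpoint product into a single logit per edge.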
    def inference(self, g, device, batch_size):
        """Layer-wise inference algorithm to compute GNN node embeddings."""
        feat = g.ndata['feat']
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = DataLoader(
            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
            batch_size=batch_size, shuffle=False, drop_last=False,
            num_workers=0)
        buffer_device = torch.device('cpu')
        pin_memory = (buffer_device != device)
        for l, layer in enumerate(self.layers):
            y = torch.empty(g.num_nodes(), self.hid_size, device=buffer_device,
                            pin_memory=pin_memory)
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc='Inference'):
                x = feat[input_nodes]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
@@ -97,78 +81,92 @@ class SAGE(nn.Module):
            feat = y
        return y
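# Why layer-wise inference: sampling an L-hop neighborhood per node at test
# time would recompute shared intermediate embeddings many times. Instead,
# each layer is applied once over all nodes with their full 1-hop
# neighborhoods, and its output buffer becomes the next layer's input.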
def compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500):
    """Compute Mean Reciprocal Rank (MRR) in batches."""
    rr = torch.zeros(src.shape[0])
    for start in tqdm.trange(0, src.shape[0], batch_size, desc='Evaluate'):
        end = min(start + batch_size, src.shape[0])
        # Concatenate the true target (column 0) with the negative targets.
        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
        h_src = node_emb[src[start:end]][:, None, :].to(device)
        h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
        pred = model.predictor(h_src * h_dst).squeeze(-1)
        input_dict = {'y_pred_pos': pred[:, 0], 'y_pred_neg': pred[:, 1:]}
        rr[start:end] = evaluator.eval(input_dict)['mrr_list']
    return rr.mean()
def evaluate(device, graph, edge_split, model, batch_size):
    model.eval()
    evaluator = Evaluator(name='ogbl-citation2')
    with torch.no_grad():
        node_emb = model.inference(graph, device, batch_size)
        results = []
        for split in ['valid', 'test']:
            src = edge_split[split]['source_node'].to(node_emb.device)
            dst = edge_split[split]['target_node'].to(node_emb.device)
            neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
            results.append(compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device))
    return results
def train(args, device, g, reverse_eids, seed_edges, model):
    # create sampler & dataloader
    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
    # exclude='reverse_id' drops the reverse of each seed edge from the sampled
    # neighborhoods so the model cannot leak the label through the reverse edge
    sampler = as_edge_prediction_sampler(
        sampler, exclude='reverse_id', reverse_eids=reverse_eids,
        negative_sampler=negative_sampler.Uniform(1))
    # in mixed mode, UVA keeps the graph in pinned CPU memory and lets the GPU
    # sample from it directly
    use_uva = (args.mode == 'mixed')
    dataloader = DataLoader(
        g, seed_edges, sampler,
        device=device, batch_size=512, shuffle=True,
        drop_last=False, num_workers=0, use_uva=use_uva)
    opt = torch.optim.Adam(model.parameters(), lr=0.0005)
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
            x = blocks[0].srcdata['feat']
            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
            score = torch.cat([pos_score, neg_score])
            # binary cross-entropy: observed edges are positives, sampled non-edges are negatives
            pos_label = torch.ones_like(pos_score)
            neg_label = torch.zeros_like(neg_score)
            labels = torch.cat([pos_label, neg_label])
            loss = F.binary_cross_entropy_with_logits(score, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            # cap each epoch at 1000 iterations
            if (it + 1) == 1000:
                break
        print("Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it + 1)))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
                             "'puregpu' for pure-GPU training.")
    args = parser.parse_args()
    # fall back to CPU when no GPU is available
    if not torch.cuda.is_available():
        args.mode = 'cpu'
    print(f'Training in {args.mode} mode.')
    # load and preprocess dataset
    print('Loading data')
    dataset = DglLinkPropPredDataset('ogbl-citation2')
    g = dataset[0]
    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
    g, reverse_eids = to_bidirected_with_reverse_mapping(g)
    reverse_eids = reverse_eids.to(device)
    seed_edges = torch.arange(g.num_edges()).to(device)
    edge_split = dataset.get_edge_split()

    # create GraphSAGE model
    in_size = g.ndata['feat'].shape[1]
    model = SAGE(in_size, 256).to(device)

    # model training
    print('Training...')
    train(args, device, g, reverse_eids, seed_edges, model)

    # validate/test the model
    print('Validation/Testing...')
    valid_mrr, test_mrr = evaluate(device, g, edge_split, model, batch_size=1000)
    print('Validation MRR {:.4f}, Test MRR {:.4f}'.format(valid_mrr.item(), test_mrr.item()))
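# Expected result after 10 epochs (see the README snippet above): Test MRR ~0.7386.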