Unverified Commit 72cfb934 authored by Chang Liu, committed by GitHub

[Example][Refactor] RGCN link prediction example refactor (#4548)

parent 14cafe73
@@ -15,7 +15,7 @@ class GAT(nn.Module):
        self.gat_layers.append(dglnn.GATConv(in_size, hid_size, heads[0], activation=F.elu))
        self.gat_layers.append(dglnn.GATConv(hid_size*heads[0], hid_size, heads[1], residual=True, activation=F.elu))
        self.gat_layers.append(dglnn.GATConv(hid_size*heads[1], out_size, heads[2], residual=True, activation=None))
    def forward(self, g, inputs):
        h = inputs
        for i, layer in enumerate(self.gat_layers):
@@ -43,13 +43,13 @@ def evaluate_in_batches(dataloader, device, model):
        score = evaluate(batched_graph, features, labels, model)
        total_score += score
    return total_score / (batch_id + 1)  # return average score
def train(train_dataloader, val_dataloader, device, model):
    # define loss function and optimizer
    loss_fcn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=0)
    # training loop
    for epoch in range(400):
        model.train()
        logits = []
@@ -66,27 +66,27 @@ def train(train_dataloader, val_dataloader, device, model):
            optimizer.step()
            total_loss += loss.item()
        print("Epoch {:05d} | Loss {:.4f} |".format(epoch, total_loss / (batch_id + 1)))
        if (epoch + 1) % 5 == 0:
            avg_score = evaluate_in_batches(val_dataloader, device, model)  # evaluate F1-score instead of loss
            print(" Acc. (F1-score) {:.4f} ".format(avg_score))
if __name__ == '__main__':
    print(f'Training PPI Dataset with DGL built-in GATConv module.')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load and preprocess datasets
    train_dataset = PPIDataset(mode='train')
    val_dataset = PPIDataset(mode='valid')
    test_dataset = PPIDataset(mode='test')
    features = train_dataset[0].ndata['feat']
    # create GAT model
    in_size = features.shape[1]
    out_size = train_dataset.num_labels
    model = GAT(in_size, 256, out_size, heads=[4,4,6]).to(device)
    # model training
    print('Training...')
    train_dataloader = GraphDataLoader(train_dataset, batch_size=2)
...
@@ -25,9 +25,9 @@ python3 train.py --dataset MUTAG
Summary (10-fold cross-validation)
-------
| Dataset | Result
| ------------- | -------
| MUTAG | ~89.4
| PTC | ~68.5
| NCI1 | ~82.9
| PROTEINS | ~74.1
@@ -57,7 +57,7 @@ class SAGE(nn.Module):
                y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
            feat = y
        return y
def evaluate(model, graph, dataloader):
    model.eval()
    ys = []
@@ -96,7 +96,7 @@ def train(args, device, g, dataset, model):
                                use_uva=use_uva)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
    for epoch in range(10):
        model.train()
        total_loss = 0
@@ -122,7 +122,7 @@ if __name__ == '__main__':
    if not torch.cuda.is_available():
        args.mode = 'cpu'
    print(f'Training in {args.mode} mode.')
    # load and preprocess dataset
    print('Loading data')
    dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
...
@@ -28,23 +28,22 @@ For mini-batch training, run with the following (available datasets are the same
```bash
python3 entity_sample.py --dataset aifb
```
For multi-gpu training (with sampling), run with the following (same datasets and GPU IDs separated by comma)
```bash
python3 entity_sample_multi_gpu.py --dataset aifb --gpu 0,1
```
### Link Prediction
-FB15k-237 in RAW-MRR
-```
-python link.py --gpu 0 --eval-protocol raw
-```
-FB15k-237 in Filtered-MRR
-```
-python link.py --gpu 0 --eval-protocol filtered
+Run with the following for link prediction on dataset FB15k-237 with filtered-MRR
+```bash
+python link.py
```
+> **_NOTE:_** By default, we use uniform edge sampling instead of neighbor-based edge sampling as in [author's code](https://github.com/MichSchli/RelationPrediction). In practice, we find that it can achieve similar MRR.
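As a rough illustration of what uniform edge sampling means here, the sketch below draws a fixed number of edge IDs uniformly at random and builds a training subgraph from them. It is a toy example with made-up sizes, not the example's actual code; the real implementation is the `GlobalUniform` and `SubgraphIterator` classes in `link.py` shown further down in this diff.

```python
import dgl
import numpy as np
import torch

# Toy graph with 6 nodes and 8 edges (sizes chosen only for illustration).
g = dgl.graph((torch.tensor([0, 1, 2, 3, 4, 5, 0, 2]),
               torch.tensor([1, 2, 3, 4, 5, 0, 3, 5])))

# Uniformly sample 4 of the 8 edge IDs.
sample_size = 4
eids = torch.from_numpy(
    np.random.choice(np.arange(g.num_edges()), sample_size, replace=False))

# Build a subgraph from the sampled edges (the example builds it manually,
# with node relabeling and added reverse edges; dgl.edge_subgraph is used
# here only to keep the sketch short).
sub_g = dgl.edge_subgraph(g, eids)
print(sub_g.num_edges())  # 4
```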
Summary
-------
### Entity Classification
@@ -55,3 +54,8 @@ Summary
| mutag | ~0.70 | ~0.50
| bgs | ~0.86 | ~0.64
| am | ~0.78 | ~0.42
+### Link Prediction
+| Dataset | Best MRR
+| ------------- | -------
+| FB15k-237 | ~0.2439
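For reference, the MRR reported above is the mean reciprocal rank of the ground-truth entity among the (filtered) candidate entities, as computed by `calc_mrr` in `link.py`. A minimal worked example of the metric itself, with made-up ranks:

```python
import torch

# Suppose the true entity is ranked 1st, 4th and 10th (1-indexed) in three test queries.
ranks = torch.tensor([1, 4, 10], dtype=torch.float)

# MRR = mean of the reciprocal ranks.
mrr = torch.mean(1.0 / ranks).item()
print(mrr)  # (1/1 + 1/4 + 1/10) / 3 = 0.45
```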
@@ -16,13 +16,13 @@ class RGCN(nn.Module):
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
    def forward(self, g):
        x = self.emb.weight
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
        return h
def evaluate(g, target_idx, labels, test_mask, model):
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    model.eval()
...
@@ -17,13 +17,13 @@ class RGCN(nn.Module):
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
    def forward(self, g):
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm']))
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
        return h
def evaluate(model, label, dataloader, inv_target):
    model.eval()
    eval_logits = []
@@ -68,7 +68,7 @@ def train(device, g, target_idx, labels, train_mask, model):
        acc = evaluate(model, labels, val_loader, inv_target)
        print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} "
              .format(epoch, total_loss / (it+1), acc))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling')
    parser.add_argument("--dataset", type=str, default="aifb",
@@ -105,12 +105,12 @@ if __name__ == '__main__':
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype).to(device)
    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)
    train(device, g, target_idx, labels, train_mask, model)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler([-1, -1])  # -1 for sampling all neighbors
...
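The `inv_target` construction shown above (scattering an `arange` into a buffer indexed by `target_idx`) builds the inverse of the global-ID-to-type-specific-ID mapping, so predictions on sampled seed nodes can be matched against type-specific labels. A minimal standalone sketch of the same idea, with made-up sizes:

```python
import torch

# Hypothetical example: 10 nodes in the full graph, 4 of which are target-type nodes.
target_idx = torch.tensor([2, 5, 6, 9])  # global IDs of the target-type nodes

# inv_target[global_id] -> position of that node within the target-type node set
inv_target = torch.empty((10,), dtype=torch.int64)
inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype)

# Map a batch of sampled global IDs back to type-specific label indices.
seed_nodes = torch.tensor([6, 2])
print(inv_target[seed_nodes])  # tensor([2, 0])
```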
""" import numpy as np
Differences compared to MichSchli/RelationPrediction import torch
* Report raw metrics instead of filtered metrics.
* By default, we use uniform edge sampling instead of neighbor-based edge
sampling used in author's code. In practice, we find it achieves similar MRR.
"""
import argparse
import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from dgl.data.knowledge_graph import FB15k237Dataset from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import RelGraphConv
import tqdm
# for building training/testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
src, dst = g.edges()
sub_src = src[mask]
sub_dst = dst[mask]
sub_rel = g.edata['etype'][mask]
if bidirected:
sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat([sub_dst, sub_src])
sub_rel = torch.cat([sub_rel, sub_rel + num_rels])
sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
sub_g.edata[dgl.ETYPE] = sub_rel
return sub_g
class GlobalUniform:
def __init__(self, g, sample_size):
self.sample_size = sample_size
self.eids = np.arange(g.num_edges())
def sample(self):
return torch.from_numpy(np.random.choice(self.eids, self.sample_size))
class NegativeSampler:
def __init__(self, k=10): # negative sampling rate = 10
self.k = k
def sample(self, pos_samples, num_nodes):
batch_size = len(pos_samples)
neg_batch_size = batch_size * self.k
neg_samples = np.tile(pos_samples, (self.k, 1))
values = np.random.randint(num_nodes, size=neg_batch_size)
choices = np.random.uniform(size=neg_batch_size)
subj = choices > 0.5
obj = choices <= 0.5
neg_samples[subj, 0] = values[subj]
neg_samples[obj, 2] = values[obj]
samples = np.concatenate((pos_samples, neg_samples))
# binary labels indicating positive and negative samples
labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32)
labels[:batch_size] = 1
return torch.from_numpy(samples), torch.from_numpy(labels)
class SubgraphIterator:
def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000):
self.g = g
self.num_rels = num_rels
self.sample_size = sample_size
self.num_epochs = num_epochs
self.pos_sampler = GlobalUniform(g, sample_size)
self.neg_sampler = NegativeSampler()
from link_utils import preprocess, SubgraphIterator, calc_mrr def __len__(self):
from model import RGCN return self.num_epochs
def __getitem__(self, i):
eids = self.pos_sampler.sample()
src, dst = self.g.find_edges(eids)
src, dst = src.numpy(), dst.numpy()
rel = self.g.edata[dgl.ETYPE][eids].numpy()
# relabel nodes to have consecutive node IDs
uniq_v, edges = np.unique((src, dst), return_inverse=True)
num_nodes = len(uniq_v)
# edges is the concatenation of src, dst with relabeled ID
src, dst = np.reshape(edges, (2, -1))
relabeled_data = np.stack((src, rel, dst)).transpose()
samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)
# use only half of the positive edges
chosen_ids = np.random.choice(np.arange(self.sample_size),
size=int(self.sample_size / 2),
replace=False)
src = src[chosen_ids]
dst = dst[chosen_ids]
rel = rel[chosen_ids]
src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
rel = np.concatenate((rel, rel + self.num_rels))
sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
sub_g.edata[dgl.ETYPE] = torch.from_numpy(rel)
sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
uniq_v = torch.from_numpy(uniq_v).view(-1).long()
return sub_g, uniq_v, samples, labels
class RGCN(nn.Module):
def __init__(self, num_nodes, h_dim, num_rels):
super().__init__()
# two-layer RGCN
self.emb = nn.Embedding(num_nodes, h_dim)
self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
num_bases=100, self_loop=True)
self.conv2 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
num_bases=100, self_loop=True)
self.dropout = nn.Dropout(0.2)
def forward(self, g, nids):
x = self.emb(nids)
h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
h = self.dropout(h)
h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
return self.dropout(h)
class LinkPredict(nn.Module): class LinkPredict(nn.Module):
def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100, dropout=0.2, reg_param=0.01): def __init__(self, num_nodes, num_rels, h_dim = 500, reg_param=0.01):
super(LinkPredict, self).__init__() super().__init__()
self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, regularizer="bdd", self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)
num_bases=num_bases, dropout=dropout, self_loop=True)
self.dropout = nn.Dropout(dropout)
self.reg_param = reg_param self.reg_param = reg_param
self.w_relation = nn.Parameter(th.Tensor(num_rels, h_dim)) self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
nn.init.xavier_uniform_(self.w_relation, nn.init.xavier_uniform_(self.w_relation,
gain=nn.init.calculate_gain('relu')) gain=nn.init.calculate_gain('relu'))
def calc_score(self, embedding, triplets): def calc_score(self, embedding, triplets):
# DistMult
s = embedding[triplets[:,0]] s = embedding[triplets[:,0]]
r = self.w_relation[triplets[:,1]] r = self.w_relation[triplets[:,1]]
o = embedding[triplets[:,2]] o = embedding[triplets[:,2]]
score = th.sum(s * r * o, dim=1) score = torch.sum(s * r * o, dim=1)
return score return score
def forward(self, g, nids): def forward(self, g, nids):
return self.dropout(self.rgcn(g, nids=nids)) return self.rgcn(g, nids)
def regularization_loss(self, embedding): def regularization_loss(self, embedding):
return th.mean(embedding.pow(2)) + th.mean(self.w_relation.pow(2)) return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))
def get_loss(self, embed, triplets, labels): def get_loss(self, embed, triplets, labels):
# each row in the triplets is a 3-tuple of (source, relation, destination) # each row in the triplets is a 3-tuple of (source, relation, destination)
...@@ -48,36 +144,67 @@ class LinkPredict(nn.Module): ...@@ -48,36 +144,67 @@ class LinkPredict(nn.Module):
reg_loss = self.regularization_loss(embed) reg_loss = self.regularization_loss(embed)
return predict_loss + self.reg_param * reg_loss return predict_loss + self.reg_param * reg_loss
def main(args): def filter(triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True):
data = FB15k237Dataset(reverse=False) """Get candidate heads or tails to score"""
graph = data[0] target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
num_nodes = graph.num_nodes() # Add the ground truth node first
num_rels = data.num_rels if filter_o:
candidate_nodes = [target_o]
train_g, test_g = preprocess(graph, num_rels) else:
test_nids = th.arange(0, num_nodes) candidate_nodes = [target_s]
test_mask = graph.edata['test_mask'] for e in range(num_nodes):
subg_iter = SubgraphIterator(train_g, num_rels, args.edge_sampler) triplet = (target_s, target_r, e) if filter_o else (e, target_r, target_o)
dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0]) # Do not consider a node if it leads to a real triplet
if triplet not in triplets_to_filter:
candidate_nodes.append(e)
return torch.LongTensor(candidate_nodes)
# Prepare data for metric computation def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True):
src, dst = graph.edges() """Perturb subject or object in the triplets"""
triplets = th.stack([src, graph.edata['etype'], dst], dim=1) num_nodes = emb.shape[0]
ranks = []
for idx in tqdm.tqdm(range(test_size), desc="Evaluate"):
target_s = s[idx]
target_r = r[idx]
target_o = o[idx]
candidate_nodes = filter(triplets_to_filter, target_s, target_r,
target_o, num_nodes, filter_o=filter_o)
if filter_o:
emb_s = emb[target_s]
emb_o = emb[candidate_nodes]
else:
emb_s = emb[candidate_nodes]
emb_o = emb[target_o]
target_idx = 0
emb_r = w[target_r]
emb_triplet = emb_s * emb_r * emb_o
scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
model = LinkPredict(num_nodes, num_rels) _, indices = torch.sort(scores, descending=True)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2) rank = int((indices == target_idx).nonzero())
ranks.append(rank)
return torch.LongTensor(ranks)
if args.gpu >= 0 and th.cuda.is_available(): def calc_mrr(emb, w, test_mask, triplets_to_filter, batch_size=100, filter=True):
device = th.device(args.gpu) with torch.no_grad():
else: test_triplets = triplets_to_filter[test_mask]
device = th.device('cpu') s, r, o = test_triplets[:,0], test_triplets[:,1], test_triplets[:,2]
model = model.to(device) test_size = len(s)
triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter.tolist()}
ranks_s = perturb_and_get_filtered_rank(emb, w, s, r, o, test_size,
triplets_to_filter, filter_o=False)
ranks_o = perturb_and_get_filtered_rank(emb, w, s, r, o,
test_size, triplets_to_filter)
ranks = torch.cat([ranks_s, ranks_o])
ranks += 1 # change to 1-indexed
mrr = torch.mean(1.0 / ranks.float()).item()
return mrr
def train(dataloader, test_g, test_nids, test_mask, triplets, device, model_state_file, model):
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
best_mrr = 0 best_mrr = 0
model_state_file = 'model_state.pth' for epoch, batch_data in enumerate(dataloader): # single graph batch
for epoch, batch_data in enumerate(dataloader):
model.train() model.train()
g, train_nids, edges, labels = batch_data g, train_nids, edges, labels = batch_data
g = g.to(device) g = g.to(device)
train_nids = train_nids.to(device) train_nids = train_nids.to(device)
...@@ -90,47 +217,55 @@ def main(args): ...@@ -90,47 +217,55 @@ def main(args):
loss.backward() loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # clip gradients nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # clip gradients
optimizer.step() optimizer.step()
print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr)) print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))
if (epoch + 1) % 500 == 0: if (epoch + 1) % 500 == 0:
# perform validation on CPU because full graph is too large # perform validation on CPU because full graph is too large
model = model.cpu() model = model.cpu()
model.eval() model.eval()
print("start eval")
embed = model(test_g, test_nids) embed = model(test_g, test_nids)
mrr = calc_mrr(embed, model.w_relation, test_mask, triplets, mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500, eval_p=args.eval_protocol) batch_size=500)
# save best model # save best model
if best_mrr < mrr: if best_mrr < mrr:
best_mrr = mrr best_mrr = mrr
th.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file) torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
model = model.to(device) model = model.to(device)
print("Start testing:") if __name__ == '__main__':
# use best model checkpoint device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = th.load(model_state_file) print(f'Training with DGL built-in RGCN module')
# load and preprocess dataset
data = FB15k237Dataset(reverse=False)
g = data[0]
num_nodes = g.num_nodes()
num_rels = data.num_rels
train_g = get_subset_g(g, g.edata['train_mask'], num_rels)
test_g = get_subset_g(g, g.edata['train_mask'], num_rels, bidirected=True)
test_g.edata['norm'] = dgl.norm_by_dst(test_g).unsqueeze(-1)
test_nids = torch.arange(0, num_nodes)
test_mask = g.edata['test_mask']
subg_iter = SubgraphIterator(train_g, num_rels) # uniform edge sampling
dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])
# Prepare data for metric computation
src, dst = g.edges()
triplets = torch.stack([src, g.edata['etype'], dst], dim=1)
# create RGCN model
model = LinkPredict(num_nodes, num_rels).to(device)
# train
model_state_file = 'model_state.pth'
train(dataloader, test_g, test_nids, test_mask, triplets, device, model_state_file, model)
# testing
print("Testing...")
checkpoint = torch.load(model_state_file)
model = model.cpu() # test on CPU model = model.cpu() # test on CPU
model.eval() model.eval()
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch']))
embed = model(test_g, test_nids) embed = model(test_g, test_nids)
calc_mrr(embed, model.w_relation, test_mask, triplets, best_mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500, eval_p=args.eval_protocol) batch_size=500)
print("Best MRR {:.4f} achieved using the epoch {:04d}".format(best_mrr, checkpoint['epoch']))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN for link prediction')
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--eval-protocol", type=str, default='filtered',
choices=['filtered', 'raw'],
help="Whether to use 'filtered' or 'raw' MRR for evaluation")
parser.add_argument("--edge-sampler", type=str, default='uniform',
choices=['uniform', 'neighbor'],
help="Type of edge sampler: 'uniform' or 'neighbor'"
"The original implementation uses neighbor sampler.")
args = parser.parse_args()
print(args)
main(args)
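The `calc_score` method of `LinkPredict` above is a DistMult scorer: a triplet (s, r, o) is scored by the elementwise product of the subject embedding, the relation vector, and the object embedding, summed over dimensions. A tiny self-contained sketch of just that formula on random embeddings (illustrative only; the sizes are made up):

```python
import torch

# Hypothetical sizes: 5 entities, 3 relations, 4-dimensional embeddings.
entity_emb = torch.randn(5, 4)
relation_emb = torch.randn(3, 4)

# Each triplet is a row of (source entity, relation, destination entity).
triplets = torch.tensor([[0, 2, 3],
                         [4, 1, 0]])

s = entity_emb[triplets[:, 0]]
r = relation_emb[triplets[:, 1]]
o = entity_emb[triplets[:, 2]]

# DistMult: score(s, r, o) = sum over dimensions of s_d * r_d * o_d
scores = torch.sum(s * r * o, dim=1)
print(scores.shape)  # torch.Size([2])
```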
"""
Utility functions for link prediction
Most code is adapted from authors' implementation of RGCN link prediction:
https://github.com/MichSchli/RelationPrediction
"""
import numpy as np
import torch as th
import dgl
# Utility function for building training and testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
src, dst = g.edges()
sub_src = src[mask]
sub_dst = dst[mask]
sub_rel = g.edata['etype'][mask]
if bidirected:
sub_src, sub_dst = th.cat([sub_src, sub_dst]), th.cat([sub_dst, sub_src])
sub_rel = th.cat([sub_rel, sub_rel + num_rels])
sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
sub_g.edata[dgl.ETYPE] = sub_rel
return sub_g
def preprocess(g, num_rels):
# Get train graph
train_g = get_subset_g(g, g.edata['train_mask'], num_rels)
# Get test graph
test_g = get_subset_g(g, g.edata['train_mask'], num_rels, bidirected=True)
test_g.edata['norm'] = dgl.norm_by_dst(test_g).unsqueeze(-1)
return train_g, test_g
class GlobalUniform:
def __init__(self, g, sample_size):
self.sample_size = sample_size
self.eids = np.arange(g.num_edges())
def sample(self):
return th.from_numpy(np.random.choice(self.eids, self.sample_size))
class NeighborExpand:
"""Sample a connected component by neighborhood expansion"""
def __init__(self, g, sample_size):
self.g = g
self.nids = np.arange(g.num_nodes())
self.sample_size = sample_size
def sample(self):
edges = th.zeros((self.sample_size), dtype=th.int64)
neighbor_counts = (self.g.in_degrees() + self.g.out_degrees()).numpy()
seen_edge = np.array([False] * self.g.num_edges())
seen_node = np.array([False] * self.g.num_nodes())
for i in range(self.sample_size):
if np.sum(seen_node) == 0:
node_weights = np.ones_like(neighbor_counts)
node_weights[np.where(neighbor_counts == 0)] = 0
else:
# Sample a visited node if applicable.
# This guarantees a connected component.
node_weights = neighbor_counts * seen_node
node_probs = node_weights / np.sum(node_weights)
chosen_node = np.random.choice(self.nids, p=node_probs)
# Sample a neighbor of the sampled node
u1, v1, eid1 = self.g.in_edges(chosen_node, form='all')
u2, v2, eid2 = self.g.out_edges(chosen_node, form='all')
u = th.cat([u1, u2])
v = th.cat([v1, v2])
eid = th.cat([eid1, eid2])
to_pick = True
while to_pick:
random_id = th.randint(high=eid.shape[0], size=(1,))
chosen_eid = eid[random_id]
to_pick = seen_edge[chosen_eid]
chosen_u = u[random_id]
chosen_v = v[random_id]
edges[i] = chosen_eid
seen_node[chosen_u] = True
seen_node[chosen_v] = True
seen_edge[chosen_eid] = True
neighbor_counts[chosen_u] -= 1
neighbor_counts[chosen_v] -= 1
return edges
class NegativeSampler:
def __init__(self, k=10):
self.k = k
def sample(self, pos_samples, num_nodes):
batch_size = len(pos_samples)
neg_batch_size = batch_size * self.k
neg_samples = np.tile(pos_samples, (self.k, 1))
values = np.random.randint(num_nodes, size=neg_batch_size)
choices = np.random.uniform(size=neg_batch_size)
subj = choices > 0.5
obj = choices <= 0.5
neg_samples[subj, 0] = values[subj]
neg_samples[obj, 2] = values[obj]
samples = np.concatenate((pos_samples, neg_samples))
# binary labels indicating positive and negative samples
labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32)
labels[:batch_size] = 1
return th.from_numpy(samples), th.from_numpy(labels)
class SubgraphIterator:
def __init__(self, g, num_rels, pos_sampler, sample_size=30000, num_epochs=6000):
self.g = g
self.num_rels = num_rels
self.sample_size = sample_size
self.num_epochs = num_epochs
if pos_sampler == 'neighbor':
self.pos_sampler = NeighborExpand(g, sample_size)
else:
self.pos_sampler = GlobalUniform(g, sample_size)
self.neg_sampler = NegativeSampler()
def __len__(self):
return self.num_epochs
def __getitem__(self, i):
eids = self.pos_sampler.sample()
src, dst = self.g.find_edges(eids)
src, dst = src.numpy(), dst.numpy()
rel = self.g.edata[dgl.ETYPE][eids].numpy()
# relabel nodes to have consecutive node IDs
uniq_v, edges = np.unique((src, dst), return_inverse=True)
num_nodes = len(uniq_v)
# edges is the concatenation of src, dst with relabeled ID
src, dst = np.reshape(edges, (2, -1))
relabeled_data = np.stack((src, rel, dst)).transpose()
samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)
# Use only half of the positive edges
chosen_ids = np.random.choice(np.arange(self.sample_size),
size=int(self.sample_size / 2),
replace=False)
src = src[chosen_ids]
dst = dst[chosen_ids]
rel = rel[chosen_ids]
src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
rel = np.concatenate((rel, rel + self.num_rels))
sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
sub_g.edata[dgl.ETYPE] = th.from_numpy(rel)
sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
uniq_v = th.from_numpy(uniq_v).view(-1).long()
return sub_g, uniq_v, samples, labels
# Utility functions for evaluations (raw)
def perturb_and_get_raw_rank(emb, w, a, r, b, test_size, batch_size=100):
""" Perturb one element in the triplets"""
n_batch = (test_size + batch_size - 1) // batch_size
ranks = []
emb = emb.transpose(0, 1) # size D x V
w = w.transpose(0, 1) # size D x R
for idx in range(n_batch):
print("batch {} / {}".format(idx, n_batch))
batch_start = idx * batch_size
batch_end = (idx + 1) * batch_size
batch_a = a[batch_start: batch_end]
batch_r = r[batch_start: batch_end]
emb_ar = emb[:,batch_a] * w[:,batch_r] # size D x E
emb_ar = emb_ar.unsqueeze(2) # size D x E x 1
emb_c = emb.unsqueeze(1) # size D x 1 x V
# out-prod and reduce sum
out_prod = th.bmm(emb_ar, emb_c) # size D x E x V
score = th.sum(out_prod, dim=0).sigmoid() # size E x V
target = b[batch_start: batch_end]
_, indices = th.sort(score, dim=1, descending=True)
indices = th.nonzero(indices == target.view(-1, 1), as_tuple=False)
ranks.append(indices[:, 1].view(-1))
return th.cat(ranks)
# Utility functions for evaluations (filtered)
def filter(triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True):
"""Get candidate heads or tails to score"""
target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
# Add the ground truth node first
if filter_o:
candidate_nodes = [target_o]
else:
candidate_nodes = [target_s]
for e in range(num_nodes):
triplet = (target_s, target_r, e) if filter_o else (e, target_r, target_o)
# Do not consider a node if it leads to a real triplet
if triplet not in triplets_to_filter:
candidate_nodes.append(e)
return th.LongTensor(candidate_nodes)
def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True):
"""Perturb subject or object in the triplets"""
num_nodes = emb.shape[0]
ranks = []
for idx in range(test_size):
if idx % 100 == 0:
print("test triplet {} / {}".format(idx, test_size))
target_s = s[idx]
target_r = r[idx]
target_o = o[idx]
candidate_nodes = filter(triplets_to_filter, target_s, target_r,
target_o, num_nodes, filter_o=filter_o)
if filter_o:
emb_s = emb[target_s]
emb_o = emb[candidate_nodes]
else:
emb_s = emb[candidate_nodes]
emb_o = emb[target_o]
target_idx = 0
emb_r = w[target_r]
emb_triplet = emb_s * emb_r * emb_o
scores = th.sigmoid(th.sum(emb_triplet, dim=1))
_, indices = th.sort(scores, descending=True)
rank = int((indices == target_idx).nonzero())
ranks.append(rank)
return th.LongTensor(ranks)
def _calc_mrr(emb, w, test_mask, triplets_to_filter, batch_size, filter=False):
with th.no_grad():
test_triplets = triplets_to_filter[test_mask]
s, r, o = test_triplets[:,0], test_triplets[:,1], test_triplets[:,2]
test_size = len(s)
if filter:
metric_name = 'MRR (filtered)'
triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter.tolist()}
ranks_s = perturb_and_get_filtered_rank(emb, w, s, r, o, test_size,
triplets_to_filter, filter_o=False)
ranks_o = perturb_and_get_filtered_rank(emb, w, s, r, o,
test_size, triplets_to_filter)
else:
metric_name = 'MRR (raw)'
ranks_s = perturb_and_get_raw_rank(emb, w, o, r, s, test_size, batch_size)
ranks_o = perturb_and_get_raw_rank(emb, w, s, r, o, test_size, batch_size)
ranks = th.cat([ranks_s, ranks_o])
ranks += 1 # change to 1-indexed
mrr = th.mean(1.0 / ranks.float()).item()
print("{}: {:.6f}".format(metric_name, mrr))
return mrr
# Main evaluation function
def calc_mrr(emb, w, test_mask, triplets, batch_size=100, eval_p="filtered"):
if eval_p == "filtered":
mrr = _calc_mrr(emb, w, test_mask, triplets, batch_size, filter=True)
else:
mrr = _calc_mrr(emb, w, test_mask, triplets, batch_size)
return mrr