Unverified commit 72cfb934 authored by Chang Liu, committed by GitHub

[Example][Refactor] RGCN link prediction example refactor (#4548)

parent 14cafe73
@@ -15,7 +15,7 @@ class GAT(nn.Module):
self.gat_layers.append(dglnn.GATConv(in_size, hid_size, heads[0], activation=F.elu))
self.gat_layers.append(dglnn.GATConv(hid_size*heads[0], hid_size, heads[1], residual=True, activation=F.elu))
self.gat_layers.append(dglnn.GATConv(hid_size*heads[1], out_size, heads[2], residual=True, activation=None))
def forward(self, g, inputs):
h = inputs
for i, layer in enumerate(self.gat_layers):
@@ -43,13 +43,13 @@ def evaluate_in_batches(dataloader, device, model):
score = evaluate(batched_graph, features, labels, model)
total_score += score
return total_score / (batch_id + 1) # return average score
def train(train_dataloader, val_dataloader, device, model):
# define loss function and optimizer
loss_fcn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=0)
# training loop
for epoch in range(400):
model.train()
logits = []
@@ -66,27 +66,27 @@ def train(train_dataloader, val_dataloader, device, model):
optimizer.step()
total_loss += loss.item()
print("Epoch {:05d} | Loss {:.4f} |". format(epoch, total_loss / (batch_id + 1) ))
if (epoch + 1) % 5 == 0:
avg_score = evaluate_in_batches(val_dataloader, device, model) # evaluate F1-score instead of loss
print(" Acc. (F1-score) {:.4f} ". format(avg_score))
if __name__ == '__main__':
print(f'Training PPI Dataset with DGL built-in GATConv module.')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load and preprocess datasets
train_dataset = PPIDataset(mode='train')
val_dataset = PPIDataset(mode='valid')
test_dataset = PPIDataset(mode='test')
features = train_dataset[0].ndata['feat']
# create GAT model
in_size = features.shape[1]
out_size = train_dataset.num_labels
model = GAT(in_size, 256, out_size, heads=[4,4,6]).to(device)
# model training
print('Training...')
train_dataloader = GraphDataLoader(train_dataset, batch_size=2)
@@ -25,9 +25,9 @@ python3 train.py --dataset MUTAG
Summary (10-fold cross-validation)
-------
| Dataset | Result
| ------------- | -------
| MUTAG | ~89.4
| PTC | ~68.5
| NCI1 | ~82.9
| PROTEINS | ~74.1
@@ -57,7 +57,7 @@ class SAGE(nn.Module):
y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
feat = y
return y
def evaluate(model, graph, dataloader):
model.eval()
ys = []
@@ -96,7 +96,7 @@ def train(args, device, g, dataset, model):
use_uva=use_uva)
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
for epoch in range(10):
model.train()
total_loss = 0
@@ -122,7 +122,7 @@ if __name__ == '__main__':
if not torch.cuda.is_available():
args.mode = 'cpu'
print(f'Training in {args.mode} mode.')
# load and preprocess dataset
print('Loading data')
dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
@@ -28,23 +28,22 @@ For mini-batch training, run with the following (available datasets are the same
```bash
python3 entity_sample.py --dataset aifb
```
For multi-GPU training (with sampling), run the following (same datasets; GPU IDs separated by commas)
```bash
python3 entity_sample_multi_gpu.py --dataset aifb --gpu 0,1
```
### Link Prediction
FB15k-237 in RAW-MRR
```
python link.py --gpu 0 --eval-protocol raw
```
FB15k-237 in Filtered-MRR
```
python link.py --gpu 0 --eval-protocol filtered
```
Run the following for link prediction on the FB15k-237 dataset with filtered MRR:
```bash
python link.py
```
> **_NOTE:_** By default, we use uniform edge sampling instead of the neighbor-based edge sampling used in the [author's code](https://github.com/MichSchli/RelationPrediction). In practice, we find it achieves a similar MRR.
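For reference, "uniform edge sampling" here simply means drawing a fixed-size batch of edge IDs uniformly at random from the training graph for each training subgraph; the `GlobalUniform` sampler in `link_utils.py` does exactly this. A minimal sketch (the helper name and the sizes below are illustrative, not part of the example code):

```python
import numpy as np
import torch

def uniform_edge_sample(num_edges, sample_size):
    # hypothetical helper mirroring GlobalUniform: draw `sample_size` edge IDs
    # uniformly at random (with replacement) from the whole edge set
    return torch.from_numpy(np.random.choice(np.arange(num_edges), sample_size))

eids = uniform_edge_sample(num_edges=100_000, sample_size=30_000)  # toy sizes
```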
Summary
-------
### Entity Classification
@@ -55,3 +54,8 @@ Summary
| mutag | ~0.70 | ~0.50
| bgs | ~0.86 | ~0.64
| am | ~0.78 | ~0.42
### Link Prediction
| Dataset | Best MRR
| ------------- | -------
| FB15k-237 | ~0.2439
@@ -16,13 +16,13 @@ class RGCN(nn.Module):
num_bases=num_rels, self_loop=False)
self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
num_bases=num_rels, self_loop=False)
def forward(self, g):
x = self.emb.weight
h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
return h
def evaluate(g, target_idx, labels, test_mask, model):
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
model.eval()
@@ -17,13 +17,13 @@ class RGCN(nn.Module):
num_bases=num_rels, self_loop=False)
self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
num_bases=num_rels, self_loop=False)
def forward(self, g):
x = self.emb(g[0].srcdata[dgl.NID])
h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm']))
h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
return h
def evaluate(model, label, dataloader, inv_target):
model.eval()
eval_logits = []
@@ -68,7 +68,7 @@ def train(device, g, target_idx, labels, train_mask, model):
acc = evaluate(model, labels, val_loader, inv_target)
print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} "
. format(epoch, total_loss / (it+1), acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling')
parser.add_argument("--dataset", type=str, default="aifb",
@@ -105,12 +105,12 @@ if __name__ == '__main__':
# find the mapping (inv_target) from global node IDs to type-specific node IDs
inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype).to(device)
# create RGCN model
in_size = g.num_nodes() # featureless with one-hot encoding
out_size = data.num_classes
model = RGCN(in_size, 16, out_size, num_rels).to(device)
train(device, g, target_idx, labels, train_mask, model)
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
test_sampler = MultiLayerNeighborSampler([-1, -1]) # -1 for sampling all neighbors
"""
Differences compared to MichSchli/RelationPrediction
* Report raw metrics instead of filtered metrics.
* By default, we use uniform edge sampling instead of neighbor-based edge
sampling used in author's code. In practice, we find it achieves similar MRR.
"""
import argparse
import torch as th
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import RelGraphConv
import tqdm
# for building training/testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
src, dst = g.edges()
sub_src = src[mask]
sub_dst = dst[mask]
sub_rel = g.edata['etype'][mask]
if bidirected:
sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat([sub_dst, sub_src])
sub_rel = torch.cat([sub_rel, sub_rel + num_rels])
sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
sub_g.edata[dgl.ETYPE] = sub_rel
return sub_g
class GlobalUniform:
def __init__(self, g, sample_size):
self.sample_size = sample_size
self.eids = np.arange(g.num_edges())
def sample(self):
return torch.from_numpy(np.random.choice(self.eids, self.sample_size))
class NegativeSampler:
def __init__(self, k=10): # negative sampling rate = 10
self.k = k
def sample(self, pos_samples, num_nodes):
batch_size = len(pos_samples)
neg_batch_size = batch_size * self.k
neg_samples = np.tile(pos_samples, (self.k, 1))
values = np.random.randint(num_nodes, size=neg_batch_size)
choices = np.random.uniform(size=neg_batch_size)
subj = choices > 0.5
obj = choices <= 0.5
neg_samples[subj, 0] = values[subj]
neg_samples[obj, 2] = values[obj]
samples = np.concatenate((pos_samples, neg_samples))
# binary labels indicating positive and negative samples
labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32)
labels[:batch_size] = 1
return torch.from_numpy(samples), torch.from_numpy(labels)
class SubgraphIterator:
def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000):
self.g = g
self.num_rels = num_rels
self.sample_size = sample_size
self.num_epochs = num_epochs
self.pos_sampler = GlobalUniform(g, sample_size)
self.neg_sampler = NegativeSampler()
from link_utils import preprocess, SubgraphIterator, calc_mrr
from model import RGCN
def __len__(self):
return self.num_epochs
def __getitem__(self, i):
eids = self.pos_sampler.sample()
src, dst = self.g.find_edges(eids)
src, dst = src.numpy(), dst.numpy()
rel = self.g.edata[dgl.ETYPE][eids].numpy()
# relabel nodes to have consecutive node IDs
uniq_v, edges = np.unique((src, dst), return_inverse=True)
num_nodes = len(uniq_v)
# edges is the concatenation of src, dst with relabeled ID
src, dst = np.reshape(edges, (2, -1))
relabeled_data = np.stack((src, rel, dst)).transpose()
samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)
# use only half of the positive edges
chosen_ids = np.random.choice(np.arange(self.sample_size),
size=int(self.sample_size / 2),
replace=False)
src = src[chosen_ids]
dst = dst[chosen_ids]
rel = rel[chosen_ids]
src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
rel = np.concatenate((rel, rel + self.num_rels))
sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
sub_g.edata[dgl.ETYPE] = torch.from_numpy(rel)
sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
uniq_v = torch.from_numpy(uniq_v).view(-1).long()
return sub_g, uniq_v, samples, labels
class RGCN(nn.Module):
def __init__(self, num_nodes, h_dim, num_rels):
super().__init__()
# two-layer RGCN
self.emb = nn.Embedding(num_nodes, h_dim)
self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
num_bases=100, self_loop=True)
self.conv2 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
num_bases=100, self_loop=True)
self.dropout = nn.Dropout(0.2)
def forward(self, g, nids):
x = self.emb(nids)
h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
h = self.dropout(h)
h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
return self.dropout(h)
class LinkPredict(nn.Module):
def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100, dropout=0.2, reg_param=0.01):
super(LinkPredict, self).__init__()
self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, regularizer="bdd",
num_bases=num_bases, dropout=dropout, self_loop=True)
self.dropout = nn.Dropout(dropout)
def __init__(self, num_nodes, num_rels, h_dim = 500, reg_param=0.01):
super().__init__()
self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)
self.reg_param = reg_param
self.w_relation = nn.Parameter(th.Tensor(num_rels, h_dim))
self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
nn.init.xavier_uniform_(self.w_relation,
gain=nn.init.calculate_gain('relu'))
def calc_score(self, embedding, triplets):
# DistMult
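# score(s, r, o) = sum_d s_d * r_d * o_d, i.e. a bilinear score with a diagonal relation matrix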
s = embedding[triplets[:,0]]
r = self.w_relation[triplets[:,1]]
o = embedding[triplets[:,2]]
score = th.sum(s * r * o, dim=1)
score = torch.sum(s * r * o, dim=1)
return score
def forward(self, g, nids):
return self.dropout(self.rgcn(g, nids=nids))
return self.rgcn(g, nids)
def regularization_loss(self, embedding):
return th.mean(embedding.pow(2)) + th.mean(self.w_relation.pow(2))
return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))
def get_loss(self, embed, triplets, labels):
# each row in the triplets is a 3-tuple of (source, relation, destination)
@@ -48,36 +144,67 @@ class LinkPredict(nn.Module):
reg_loss = self.regularization_loss(embed)
return predict_loss + self.reg_param * reg_loss
def main(args):
data = FB15k237Dataset(reverse=False)
graph = data[0]
num_nodes = graph.num_nodes()
num_rels = data.num_rels
train_g, test_g = preprocess(graph, num_rels)
test_nids = th.arange(0, num_nodes)
test_mask = graph.edata['test_mask']
subg_iter = SubgraphIterator(train_g, num_rels, args.edge_sampler)
dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])
def filter(triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True):
"""Get candidate heads or tails to score"""
target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
# Add the ground truth node first
if filter_o:
candidate_nodes = [target_o]
else:
candidate_nodes = [target_s]
for e in range(num_nodes):
triplet = (target_s, target_r, e) if filter_o else (e, target_r, target_o)
# Do not consider a node if it leads to a real triplet
if triplet not in triplets_to_filter:
candidate_nodes.append(e)
return torch.LongTensor(candidate_nodes)
# Prepare data for metric computation
src, dst = graph.edges()
triplets = th.stack([src, graph.edata['etype'], dst], dim=1)
def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True):
"""Perturb subject or object in the triplets"""
num_nodes = emb.shape[0]
ranks = []
for idx in tqdm.tqdm(range(test_size), desc="Evaluate"):
target_s = s[idx]
target_r = r[idx]
target_o = o[idx]
candidate_nodes = filter(triplets_to_filter, target_s, target_r,
target_o, num_nodes, filter_o=filter_o)
if filter_o:
emb_s = emb[target_s]
emb_o = emb[candidate_nodes]
else:
emb_s = emb[candidate_nodes]
emb_o = emb[target_o]
target_idx = 0
emb_r = w[target_r]
emb_triplet = emb_s * emb_r * emb_o
scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
model = LinkPredict(num_nodes, num_rels)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
_, indices = torch.sort(scores, descending=True)
rank = int((indices == target_idx).nonzero())
ranks.append(rank)
return torch.LongTensor(ranks)
if args.gpu >= 0 and th.cuda.is_available():
device = th.device(args.gpu)
else:
device = th.device('cpu')
model = model.to(device)
def calc_mrr(emb, w, test_mask, triplets_to_filter, batch_size=100, filter=True):
with torch.no_grad():
test_triplets = triplets_to_filter[test_mask]
s, r, o = test_triplets[:,0], test_triplets[:,1], test_triplets[:,2]
test_size = len(s)
triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter.tolist()}
ranks_s = perturb_and_get_filtered_rank(emb, w, s, r, o, test_size,
triplets_to_filter, filter_o=False)
ranks_o = perturb_and_get_filtered_rank(emb, w, s, r, o,
test_size, triplets_to_filter)
ranks = torch.cat([ranks_s, ranks_o])
ranks += 1 # change to 1-indexed
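# MRR = mean of 1 / rank of the ground-truth entity over all head and tail perturbations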
mrr = torch.mean(1.0 / ranks.float()).item()
return mrr
def train(dataloader, test_g, test_nids, test_mask, triplets, device, model_state_file, model):
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
best_mrr = 0
model_state_file = 'model_state.pth'
for epoch, batch_data in enumerate(dataloader):
for epoch, batch_data in enumerate(dataloader): # single graph batch
model.train()
g, train_nids, edges, labels = batch_data
g = g.to(device)
train_nids = train_nids.to(device)
@@ -90,47 +217,55 @@ def main(args):
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # clip gradients
optimizer.step()
print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))
if (epoch + 1) % 500 == 0:
# perform validation on CPU because full graph is too large
model = model.cpu()
model.eval()
print("start eval")
embed = model(test_g, test_nids)
mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500, eval_p=args.eval_protocol)
batch_size=500)
# save best model
if best_mrr < mrr:
best_mrr = mrr
th.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
model = model.to(device)
print("Start testing:")
# use best model checkpoint
checkpoint = th.load(model_state_file)
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Training with DGL built-in RGCN module')
# load and preprocess dataset
data = FB15k237Dataset(reverse=False)
g = data[0]
num_nodes = g.num_nodes()
num_rels = data.num_rels
train_g = get_subset_g(g, g.edata['train_mask'], num_rels)
test_g = get_subset_g(g, g.edata['train_mask'], num_rels, bidirected=True)
test_g.edata['norm'] = dgl.norm_by_dst(test_g).unsqueeze(-1)
test_nids = torch.arange(0, num_nodes)
test_mask = g.edata['test_mask']
subg_iter = SubgraphIterator(train_g, num_rels) # uniform edge sampling
dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])
# Prepare data for metric computation
src, dst = g.edges()
triplets = torch.stack([src, g.edata['etype'], dst], dim=1)
# create RGCN model
model = LinkPredict(num_nodes, num_rels).to(device)
# train
model_state_file = 'model_state.pth'
train(dataloader, test_g, test_nids, test_mask, triplets, device, model_state_file, model)
# testing
print("Testing...")
checkpoint = torch.load(model_state_file)
model = model.cpu() # test on CPU
model.eval()
model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch']))
embed = model(test_g, test_nids)
calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500, eval_p=args.eval_protocol)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN for link prediction')
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--eval-protocol", type=str, default='filtered',
choices=['filtered', 'raw'],
help="Whether to use 'filtered' or 'raw' MRR for evaluation")
parser.add_argument("--edge-sampler", type=str, default='uniform',
choices=['uniform', 'neighbor'],
help="Type of edge sampler: 'uniform' or 'neighbor'"
"The original implementation uses neighbor sampler.")
args = parser.parse_args()
print(args)
main(args)
best_mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500)
print("Best MRR {:.4f} achieved using the epoch {:04d}".format(best_mrr, checkpoint['epoch']))
"""
Utility functions for link prediction
Most code is adapted from authors' implementation of RGCN link prediction:
https://github.com/MichSchli/RelationPrediction
"""
import numpy as np
import torch as th
import dgl
# Utility function for building training and testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
src, dst = g.edges()
sub_src = src[mask]
sub_dst = dst[mask]
sub_rel = g.edata['etype'][mask]
if bidirected:
sub_src, sub_dst = th.cat([sub_src, sub_dst]), th.cat([sub_dst, sub_src])
sub_rel = th.cat([sub_rel, sub_rel + num_rels])
sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
sub_g.edata[dgl.ETYPE] = sub_rel
return sub_g
def preprocess(g, num_rels):
# Get train graph
train_g = get_subset_g(g, g.edata['train_mask'], num_rels)
# Get test graph
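# the evaluation graph reuses only training edges plus their inverses (relation IDs offset by num_rels), so test edges never appear in message passing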
test_g = get_subset_g(g, g.edata['train_mask'], num_rels, bidirected=True)
test_g.edata['norm'] = dgl.norm_by_dst(test_g).unsqueeze(-1)
return train_g, test_g
class GlobalUniform:
def __init__(self, g, sample_size):
self.sample_size = sample_size
self.eids = np.arange(g.num_edges())
def sample(self):
return th.from_numpy(np.random.choice(self.eids, self.sample_size))
class NeighborExpand:
"""Sample a connected component by neighborhood expansion"""
def __init__(self, g, sample_size):
self.g = g
self.nids = np.arange(g.num_nodes())
self.sample_size = sample_size
def sample(self):
edges = th.zeros((self.sample_size), dtype=th.int64)
neighbor_counts = (self.g.in_degrees() + self.g.out_degrees()).numpy()
seen_edge = np.array([False] * self.g.num_edges())
seen_node = np.array([False] * self.g.num_nodes())
for i in range(self.sample_size):
if np.sum(seen_node) == 0:
node_weights = np.ones_like(neighbor_counts)
node_weights[np.where(neighbor_counts == 0)] = 0
else:
# Sample a visited node if applicable.
# This guarantees a connected component.
node_weights = neighbor_counts * seen_node
node_probs = node_weights / np.sum(node_weights)
chosen_node = np.random.choice(self.nids, p=node_probs)
# Sample a neighbor of the sampled node
u1, v1, eid1 = self.g.in_edges(chosen_node, form='all')
u2, v2, eid2 = self.g.out_edges(chosen_node, form='all')
u = th.cat([u1, u2])
v = th.cat([v1, v2])
eid = th.cat([eid1, eid2])
to_pick = True
while to_pick:
random_id = th.randint(high=eid.shape[0], size=(1,))
chosen_eid = eid[random_id]
to_pick = seen_edge[chosen_eid]
chosen_u = u[random_id]
chosen_v = v[random_id]
edges[i] = chosen_eid
seen_node[chosen_u] = True
seen_node[chosen_v] = True
seen_edge[chosen_eid] = True
neighbor_counts[chosen_u] -= 1
neighbor_counts[chosen_v] -= 1
return edges
class NegativeSampler:
def __init__(self, k=10):
self.k = k
def sample(self, pos_samples, num_nodes):
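# corrupt each positive triplet k times: with probability 0.5 replace the subject, otherwise the object, with a node drawn uniformly at random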
batch_size = len(pos_samples)
neg_batch_size = batch_size * self.k
neg_samples = np.tile(pos_samples, (self.k, 1))
values = np.random.randint(num_nodes, size=neg_batch_size)
choices = np.random.uniform(size=neg_batch_size)
subj = choices > 0.5
obj = choices <= 0.5
neg_samples[subj, 0] = values[subj]
neg_samples[obj, 2] = values[obj]
samples = np.concatenate((pos_samples, neg_samples))
# binary labels indicating positive and negative samples
labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32)
labels[:batch_size] = 1
return th.from_numpy(samples), th.from_numpy(labels)
class SubgraphIterator:
def __init__(self, g, num_rels, pos_sampler, sample_size=30000, num_epochs=6000):
self.g = g
self.num_rels = num_rels
self.sample_size = sample_size
self.num_epochs = num_epochs
if pos_sampler == 'neighbor':
self.pos_sampler = NeighborExpand(g, sample_size)
else:
self.pos_sampler = GlobalUniform(g, sample_size)
self.neg_sampler = NegativeSampler()
def __len__(self):
return self.num_epochs
def __getitem__(self, i):
eids = self.pos_sampler.sample()
src, dst = self.g.find_edges(eids)
src, dst = src.numpy(), dst.numpy()
rel = self.g.edata[dgl.ETYPE][eids].numpy()
# relabel nodes to have consecutive node IDs
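# np.unique with return_inverse=True gives the sorted unique node IDs plus, for every original endpoint, its index into that array (the compacted ID)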
uniq_v, edges = np.unique((src, dst), return_inverse=True)
num_nodes = len(uniq_v)
# edges is the concatenation of src, dst with relabeled ID
src, dst = np.reshape(edges, (2, -1))
relabeled_data = np.stack((src, rel, dst)).transpose()
samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)
# Use only half of the positive edges
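# only half of the sampled edges go into the message-passing graph; the remaining positives must be scored without being directly visible as edges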
chosen_ids = np.random.choice(np.arange(self.sample_size),
size=int(self.sample_size / 2),
replace=False)
src = src[chosen_ids]
dst = dst[chosen_ids]
rel = rel[chosen_ids]
src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
rel = np.concatenate((rel, rel + self.num_rels))
sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
sub_g.edata[dgl.ETYPE] = th.from_numpy(rel)
sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
uniq_v = th.from_numpy(uniq_v).view(-1).long()
return sub_g, uniq_v, samples, labels
# Utility functions for evaluations (raw)
def perturb_and_get_raw_rank(emb, w, a, r, b, test_size, batch_size=100):
""" Perturb one element in the triplets"""
n_batch = (test_size + batch_size - 1) // batch_size
ranks = []
emb = emb.transpose(0, 1) # size D x V
w = w.transpose(0, 1) # size D x R
for idx in range(n_batch):
print("batch {} / {}".format(idx, n_batch))
batch_start = idx * batch_size
batch_end = (idx + 1) * batch_size
batch_a = a[batch_start: batch_end]
batch_r = r[batch_start: batch_end]
emb_ar = emb[:,batch_a] * w[:,batch_r] # size D x E
emb_ar = emb_ar.unsqueeze(2) # size D x E x 1
emb_c = emb.unsqueeze(1) # size D x 1 x V
# out-prod and reduce sum
out_prod = th.bmm(emb_ar, emb_c) # size D x E x V
score = th.sum(out_prod, dim=0).sigmoid() # size E x V
target = b[batch_start: batch_end]
_, indices = th.sort(score, dim=1, descending=True)
indices = th.nonzero(indices == target.view(-1, 1), as_tuple=False)
ranks.append(indices[:, 1].view(-1))
return th.cat(ranks)
# Utility functions for evaluations (filtered)
def filter(triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True):
"""Get candidate heads or tails to score"""
target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
# Add the ground truth node first
if filter_o:
candidate_nodes = [target_o]
else:
candidate_nodes = [target_s]
for e in range(num_nodes):
triplet = (target_s, target_r, e) if filter_o else (e, target_r, target_o)
# Do not consider a node if it leads to a real triplet
if triplet not in triplets_to_filter:
candidate_nodes.append(e)
return th.LongTensor(candidate_nodes)
def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True):
"""Perturb subject or object in the triplets"""
num_nodes = emb.shape[0]
ranks = []
for idx in range(test_size):
if idx % 100 == 0:
print("test triplet {} / {}".format(idx, test_size))
target_s = s[idx]
target_r = r[idx]
target_o = o[idx]
candidate_nodes = filter(triplets_to_filter, target_s, target_r,
target_o, num_nodes, filter_o=filter_o)
if filter_o:
emb_s = emb[target_s]
emb_o = emb[candidate_nodes]
else:
emb_s = emb[candidate_nodes]
emb_o = emb[target_o]
target_idx = 0
emb_r = w[target_r]
emb_triplet = emb_s * emb_r * emb_o
scores = th.sigmoid(th.sum(emb_triplet, dim=1))
_, indices = th.sort(scores, descending=True)
rank = int((indices == target_idx).nonzero())
ranks.append(rank)
return th.LongTensor(ranks)
def _calc_mrr(emb, w, test_mask, triplets_to_filter, batch_size, filter=False):
with th.no_grad():
test_triplets = triplets_to_filter[test_mask]
s, r, o = test_triplets[:,0], test_triplets[:,1], test_triplets[:,2]
test_size = len(s)
if filter:
metric_name = 'MRR (filtered)'
triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter.tolist()}
ranks_s = perturb_and_get_filtered_rank(emb, w, s, r, o, test_size,
triplets_to_filter, filter_o=False)
ranks_o = perturb_and_get_filtered_rank(emb, w, s, r, o,
test_size, triplets_to_filter)
else:
metric_name = 'MRR (raw)'
ranks_s = perturb_and_get_raw_rank(emb, w, o, r, s, test_size, batch_size)
ranks_o = perturb_and_get_raw_rank(emb, w, s, r, o, test_size, batch_size)
ranks = th.cat([ranks_s, ranks_o])
ranks += 1 # change to 1-indexed
mrr = th.mean(1.0 / ranks.float()).item()
print("{}: {:.6f}".format(metric_name, mrr))
return mrr
# Main evaluation function
def calc_mrr(emb, w, test_mask, triplets, batch_size=100, eval_p="filtered"):
if eval_p == "filtered":
mrr = _calc_mrr(emb, w, test_mask, triplets, batch_size, filter=True)
else:
mrr = _calc_mrr(emb, w, test_mask, triplets, batch_size)
return mrr