Unverified commit e61138be authored by KounianhuaDu, committed by GitHub

[Example] CompGCN (#2768)



* compgcn

* readme

* readme

* update

* readme
Co-authored-by: zhjwy9343 <6593865@qq.com>
parent cfe6e70b
@@ -85,6 +85,7 @@ The folder contains example implementations of selected research papers related
| [Directional Message Passing for Molecular Graphs](#dimenet) | | | :heavy_check_mark: | | |
| [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |
| [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |
| [Composition-based Multi-Relational Graph Convolutional Networks](#compgcn)| | :heavy_check_mark: | | | |
| [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |
## 2020
@@ -124,6 +125,10 @@ The folder contains example implementations of selected research papers related
- Example code: [Pytorch](../examples/pytorch/tgn)
- Tags: over-smoothing, node classification
- <a name="compgcn"></a> Vashishth, Shikhar, et al. Composition-based Multi-Relational Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1911.03082).
- Example code: [Pytorch](../examples/pytorch/compGCN)
- Tags: multi-relational graphs, graph neural network
## 2019
# DGL Implementation of the CompGCN Paper
This DGL example implements the GNN model proposed in the paper [Composition-based Multi-Relational Graph Convolutional Networks](https://arxiv.org/abs/1911.03082) (CompGCN).
The authors' original implementation is available [here](https://github.com/malllabiisc/CompGCN).
Example implementor
----------------------
This example was implemented by [zhjwy9343](https://github.com/zhjwy9343) and [KounianhuaDu](https://github.com/KounianhuaDu) at the AWS Shanghai AI Lab.
Dependencies
----------------------
- pytorch 1.7.1
- dgl 0.6.0
- numpy 1.19.4
- ordered_set 4.0.2
Dataset
---------------------------------------
The datasets used for link prediction are FB15k-237, constructed from Freebase, and WN18RR, constructed from WordNet. Their statistics are summarized as follows (see the parsing sketch after the statistics for the expected file format):
**FB15k-237**
- Nodes: 14541
- Relation types: 237
- Reversed relation types: 237
- Train: 272115
- Valid: 17535
- Test: 20466
**WN18RR**
- Nodes: 40943
- Relation types: 11
- Reversed relation types: 11
- Train: 86835
- Valid: 3034
- Test: 3134
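Each split (`train.txt`, `valid.txt`, `test.txt`) is a plain-text file with one tab-separated triple per line; `data_loader.py` lowercases the tokens and maps them to integer ids. A minimal sketch of that parsing step (the path below is illustrative):
```python
# Minimal parsing sketch; see data_loader.py for the real code.
for line in open('./FB15k-237/train.txt'):
    sub, rel, obj = map(str.lower, line.strip().split('\t'))
    # sub, rel, obj are later mapped to integer ids via ent2id / rel2id,
    # and a reversed relation rel + '_reverse' is added for every relation
```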
How to run
--------------------------------
To download the data, run
```bash
sh get_fb15k-237.sh
```
```bash
sh get_wn18rr.sh
```
Then, for FB15k-237, run
```bash
python main.py --score_func conve --opn ccorr --gpu 0 --data FB15k-237
```
For WN18RR, run
```bash
python main.py --score_func conve --opn ccorr --gpu 0 --data wn18rr
```
To run on CPU instead of GPU, pass `--gpu -1`.
Performance
-------------------------
**Link Prediction Results**
| Metric | FB15k-237 (paper / ours, DGL) | WN18RR (paper / ours, DGL) |
|---------|-------------------------|-------------------------|
| MRR | 0.355 / 0.349 | 0.479 / 0.471 |
| MR | 197 / 208 | 3533 / 3550 |
| Hits@10 | 0.535 / 0.526 | 0.546 / 0.532 |
| Hits@3 | 0.390 / 0.381 | 0.494 / 0.480 |
| Hits@1 | 0.264 / 0.260 | 0.443 / 0.438 |
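The numbers above are filtered ranking metrics. The sketch below is a rough, illustrative version of how `predict` in `main.py` turns the per-entity scores into ranks (the function name and tensor names are ours, not part of the example code):
```python
import torch as th

def filtered_ranks(pred, obj, label):
    # pred: (batch, num_ent) scores, obj: gold object ids, label: multi-hot of all known true objects
    b_range = th.arange(pred.size(0), device=pred.device)
    target = pred[b_range, obj]
    # mask every other known true object so it cannot outrank the target ("filtered" setting)
    pred = th.where(label.bool(), th.full_like(pred, -1e7), pred)
    pred[b_range, obj] = target
    ranks = 1 + th.argsort(th.argsort(pred, dim=1, descending=True), dim=1)[b_range, obj]
    return ranks.float()  # MRR = (1 / ranks).mean(), Hits@k = (ranks <= k).float().mean()
```
The multi-hot `label` tensors used for this masking are produced by `data_loader.py` below.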
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import dgl
from collections import defaultdict as ddict
from ordered_set import OrderedSet
class TrainDataset(Dataset):
"""
Training Dataset class.
Parameters
----------
triples: The triples used for training the model
num_ent: Number of entities in the knowledge graph
lbl_smooth: Label smoothing
Returns
-------
A training Dataset class instance used by DataLoader
"""
def __init__(self, triples, num_ent, lbl_smooth):
self.triples = triples
self.num_ent = num_ent
self.lbl_smooth = lbl_smooth
self.entities = np.arange(self.num_ent, dtype=np.int32)
def __len__(self):
return len(self.triples)
def __getitem__(self, idx):
ele = self.triples[idx]
triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label'])
trp_label = self.get_label(label)
#label smoothing
if self.lbl_smooth != 0.0:
trp_label = (1.0 - self.lbl_smooth) * trp_label + (1.0 / self.num_ent)
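#e.g. with lbl_smooth=0.1 and num_ent=14541 (FB15k-237), positive entries become
#0.9 + 1/14541 ~= 0.9001 and negative entries become 1/14541 ~= 6.9e-5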
return triple, trp_label
@staticmethod
def collate_fn(data):
triples = []
labels = []
for triple, label in data:
triples.append(triple)
labels.append(label)
triple = torch.stack(triples, dim=0)
trp_label = torch.stack(labels, dim=0)
return triple, trp_label
#for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
def get_label(self, label):
y = np.zeros([self.num_ent], dtype=np.float32)
for e2 in label:
y[e2] = 1.0
return torch.FloatTensor(y)
class TestDataset(Dataset):
"""
Evaluation Dataset class.
Parameters
----------
triples: The triples used for evaluating the model
num_ent: Number of entities in the knowledge graph
Returns
-------
An evaluation Dataset class instance used by DataLoader for model evaluation
"""
def __init__(self, triples, num_ent):
self.triples = triples
self.num_ent = num_ent
def __len__(self):
return len(self.triples)
def __getitem__(self, idx):
ele = self.triples[idx]
triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label'])
label = self.get_label(label)
return triple, label
@staticmethod
def collate_fn(data):
triples = []
labels = []
for triple, label in data:
triples.append(triple)
labels.append(label)
triple = torch.stack(triples, dim=0)
label = torch.stack(labels, dim=0)
return triple, label
#for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0
def get_label(self, label):
y = np.zeros([self.num_ent], dtype=np.float32)
for e2 in label:
y[e2] = 1.0
return torch.FloatTensor(y)
class Data(object):
def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
"""
Reads in the raw triples and converts them into a standard format.
Parameters
----------
dataset: The name of the dataset
lbl_smooth: Label smoothing
num_workers: Number of workers of dataloaders
batch_size: Batch size of dataloaders
Returns
-------
self.ent2id: Entity to unique identifier mapping
self.rel2id: Relation to unique identifier mapping
self.id2ent: Inverse mapping of self.ent2id
self.id2rel: Inverse mapping of self.rel2id
self.num_ent: Number of entities in the knowledge graph
self.num_rel: Number of relations in the knowledge graph
self.g: The dgl graph constructed from the edges in the training set and all the entities in the knowledge graph
self.data['train']: Stores the triples corresponding to training dataset
self.data['valid']: Stores the triples corresponding to validation dataset
self.data['test']: Stores the triples corresponding to test dataset
self.data_iter: The dataloader for different data splits
"""
self.dataset = dataset
self.lbl_smooth = lbl_smooth
self.num_workers = num_workers
self.batch_size = batch_size
#read in raw data and get mappings
ent_set, rel_set = OrderedSet(), OrderedSet()
for split in ['train', 'test', 'valid']:
for line in open('./{}/{}.txt'.format(self.dataset, split)):
sub, rel, obj = map(str.lower, line.strip().split('\t'))
ent_set.add(sub)
rel_set.add(rel)
ent_set.add(obj)
self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
self.rel2id.update({rel+'_reverse': idx+len(self.rel2id) for idx, rel in enumerate(rel_set)})
self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}
self.num_ent = len(self.ent2id)
self.num_rel = len(self.rel2id) // 2
#read in ids of subjects, relations, and objects for train/test/valid
self.data = ddict(list) #stores the triples
sr2o = ddict(set) #the key of sr2o is (subject, relation) and the value is the set of all objects reachable via that (subject, relation) pair
src=[]
dst=[]
rels = []
inver_src = []
inver_dst = []
inver_rels = []
for split in ['train', 'test', 'valid']:
for line in open('./{}/{}.txt'.format(self.dataset, split)):
sub, rel, obj = map(str.lower, line.strip().split('\t'))
sub_id, rel_id, obj_id = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
self.data[split].append((sub_id, rel_id, obj_id))
if split == 'train':
sr2o[(sub_id, rel_id)].add(obj_id)
sr2o[(obj_id, rel_id+self.num_rel)].add(sub_id) #append the reversed edges
src.append(sub_id)
dst.append(obj_id)
rels.append(rel_id)
inver_src.append(obj_id)
inver_dst.append(sub_id)
inver_rels.append(rel_id+self.num_rel)
#construct dgl graph
src = src + inver_src
dst = dst + inver_dst
rels = rels + inver_rels
self.g = dgl.graph((src, dst), num_nodes=self.num_ent)
self.g.edata['etype'] = torch.Tensor(rels).long()
#identify in and out edges
in_edges_mask = [True] * (self.g.num_edges()//2) + [False] * (self.g.num_edges()//2)
out_edges_mask = [False] * (self.g.num_edges()//2) + [True] * (self.g.num_edges()//2)
self.g.edata['in_edges_mask'] = torch.Tensor(in_edges_mask)
self.g.edata['out_edges_mask'] = torch.Tensor(out_edges_mask)
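#the first num_edges/2 edges are the original triples and the second half are their
#reverses (relation ids shifted by num_rel); 'in_edges_mask' marks the former and
#'out_edges_mask' the latter, so CompGraphConv can apply W_I and W_O separately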
#Prepare train/valid/test data
self.data = dict(self.data)
self.sr2o = {k: list(v) for k, v in sr2o.items()} #store only the train data
for split in ['test', 'valid']:
for sub, rel, obj in self.data[split]:
sr2o[(sub, rel)].add(obj)
sr2o[(obj, rel+self.num_rel)].add(sub)
self.sr2o_all = {k: list(v) for k, v in sr2o.items()} #store all the data
self.triples = ddict(list)
for (sub, rel), obj in self.sr2o.items():
self.triples['train'].append({'triple':(sub, rel, -1), 'label': self.sr2o[(sub, rel)]})
for split in ['test', 'valid']:
for sub, rel, obj in self.data[split]:
rel_inv = rel + self.num_rel
self.triples['{}_{}'.format(split, 'tail')].append({'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
self.triples['{}_{}'.format(split, 'head')].append({'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})
self.triples = dict(self.triples)
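#self.triples['train'] holds (sub, rel, -1) queries labeled with every known training object;
#'{split}_tail' holds (sub, rel, obj) and '{split}_head' holds (obj, rel_inv, sub), both
#labeled with all true answers so that evaluation can use the filtered setting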
def get_train_data_loader(split, batch_size, shuffle=True):
return DataLoader(
TrainDataset(self.triples[split], self.num_ent, self.lbl_smooth),
batch_size = batch_size,
shuffle = shuffle,
num_workers = max(0, self.num_workers),
collate_fn = TrainDataset.collate_fn
)
def get_test_data_loader(split, batch_size, shuffle=True):
return DataLoader(
TestDataset(self.triples[split], self.num_ent),
batch_size = batch_size,
shuffle = shuffle,
num_workers = max(0, self.num_workers),
collate_fn = TestDataset.collate_fn
)
#train/valid/test dataloaders
self.data_iter = {
'train': get_train_data_loader('train', self.batch_size),
'valid_head': get_test_data_loader('valid_head', self.batch_size),
'valid_tail': get_test_data_loader('valid_tail', self.batch_size),
'test_head': get_test_data_loader('test_head', self.batch_size),
'test_tail': get_test_data_loader('test_tail', self.batch_size),
}
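#minimal usage sketch (illustrative; see main.py for the full pipeline):
#  data = Data('FB15k-237', lbl_smooth=0.1, num_workers=10, batch_size=128)
#  graph, loaders = data.g, data.data_iter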
wget https://dgl-data.s3.cn-north-1.amazonaws.com.cn/dataset/FB15k-237.zip
unzip FB15k-237.zip
wget https://dgl-data.s3.cn-north-1.amazonaws.com.cn/dataset/wn18rr.zip
unzip wn18rr.zip
import argparse
import torch as th
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import dgl.function as fn
from utils import in_out_norm
from models import CompGCN_ConvE
from data_loader import Data
import numpy as np
from time import time
#predict the tail for (head, rel, -1) or head for (-1, rel, tail)
def predict(model, graph, device, data_iter, split='valid', mode='tail'):
model.eval()
with th.no_grad():
results = {}
train_iter = iter(data_iter['{}_{}'.format(split, mode)])
for step, batch in enumerate(train_iter):
triple, label = batch[0].to(device), batch[1].to(device)
sub, rel, obj, label = triple[:, 0], triple[:, 1], triple[:, 2], label
pred = model(graph, sub, rel)
b_range = th.arange(pred.size()[0], device = device)
target_pred = pred[b_range, obj]
pred = th.where(label.byte(), -th.ones_like(pred) * 10000000, pred)
pred[b_range, obj] = target_pred
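#this is the "filtered" evaluation setting: every other known true object is pushed to a
#large negative score so it cannot outrank the target triple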
#compute metrics
ranks = 1 + th.argsort(th.argsort(pred, dim=1, descending=True), dim =1, descending=False)[b_range, obj]
ranks = ranks.float()
results['count'] = th.numel(ranks) + results.get('count', 0.0)
results['mr'] = th.sum(ranks).item() + results.get('mr', 0.0)
results['mrr'] = th.sum(1.0/ranks).item() + results.get('mrr', 0.0)
for k in [1,3,10]:
results['hits@{}'.format(k)] = th.numel(ranks[ranks <= (k)]) + results.get('hits@{}'.format(k), 0.0)
return results
#evaluation function, evaluate the head and tail prediction and then combine the results
def evaluate(model, graph, device, data_iter, split='valid'):
#predict for head and tail
left_results = predict(model, graph, device, data_iter, split, mode='tail')
right_results = predict(model, graph, device, data_iter, split, mode='head')
results = {}
count = float(left_results['count'])
#combine the head and tail prediction results
#Metrics: MRR, MR, and Hit@k
results['left_mr'] = round(left_results['mr']/count, 5)
results['left_mrr'] = round(left_results['mrr']/count, 5)
results['right_mr'] = round(right_results['mr']/count, 5)
results['right_mrr'] = round(right_results['mrr']/count, 5)
results['mr'] = round((left_results['mr'] + right_results['mr']) /(2*count), 5)
results['mrr'] = round((left_results['mrr'] + right_results['mrr']) /(2*count), 5)
for k in [1,3,10]:
results['left_hits@{}'.format(k)] = round(left_results['hits@{}'.format(k)]/count, 5)
results['right_hits@{}'.format(k)] = round(right_results['hits@{}'.format(k)]/count, 5)
results['hits@{}'.format(k)] = round((left_results['hits@{}'.format(k)] + right_results['hits@{}'.format(k)])/(2*count), 5)
return results
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# check cuda
if args.gpu >= 0 and th.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
else:
device = 'cpu'
#construct graph, split in/out edges and prepare train/validation/test data_loader
data = Data(args.dataset, args.lbl_smooth, args.num_workers, args.batch_size)
data_iter = data.data_iter #train/validation/test data_loader
graph = data.g.to(device)
num_rel = th.max(graph.edata['etype']).item() + 1
#Compute in/out edge norms and store in edata
graph = in_out_norm(graph)
# Step 2: Create model =================================================================== #
compgcn_model=CompGCN_ConvE(num_bases=args.num_bases,
num_rel=num_rel,
num_ent=graph.num_nodes(),
in_dim=args.init_dim,
layer_size=args.layer_size,
comp_fn=args.opn,
batchnorm=True,
dropout=args.dropout,
layer_dropout=args.layer_dropout,
num_filt=args.num_filt,
hid_drop=args.hid_drop,
feat_drop=args.feat_drop,
ker_sz=args.ker_sz,
k_w=args.k_w,
k_h=args.k_h
)
compgcn_model = compgcn_model.to(device)
# Step 3: Create training components ===================================================== #
loss_fn = th.nn.BCELoss()
optimizer = optim.Adam(compgcn_model.parameters(), lr=args.lr, weight_decay=args.l2)
# Step 4: training epochs =============================================================== #
best_mrr = 0.0
kill_cnt = 0
for epoch in range(args.max_epochs):
# Training and validation using a full graph
compgcn_model.train()
train_loss=[]
t0 = time()
for step, batch in enumerate(data_iter['train']):
triple, label = batch[0].to(device), batch[1].to(device)
sub, rel, obj, label = triple[:, 0], triple[:, 1], triple[:, 2], label
logits = compgcn_model(graph, sub, rel)
# compute loss
tr_loss = loss_fn(logits, label)
train_loss.append(tr_loss.item())
# backward
optimizer.zero_grad()
tr_loss.backward()
optimizer.step()
train_loss = np.sum(train_loss)
t1 = time()
val_results = evaluate(compgcn_model, graph, device, data_iter, split='valid')
t2 = time()
#validate
if val_results['mrr']>best_mrr:
best_mrr = val_results['mrr']
best_epoch = epoch
th.save(compgcn_model.state_dict(), 'comp_link'+'_'+args.dataset)
kill_cnt = 0
print("saving model...")
else:
kill_cnt += 1
if kill_cnt > 25:
print('early stop.')
break
print("In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}"\
.format(epoch, train_loss, val_results['mrr'], t1-t0, t2-t1))
#test use the best model
compgcn_model.eval()
compgcn_model.load_state_dict(th.load('comp_link'+'_'+args.dataset))
test_results = evaluate(compgcn_model, graph, device, data_iter, split='test')
print("Test MRR: {:.5}\n, MR: {:.10}\n, H@10: {:.5}\n, H@3: {:.5}\n, H@1: {:.5}\n"\
.format(test_results['mrr'], test_results['mr'], test_results['hits@10'], test_results['hits@3'], test_results['hits@1']))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Parser For Arguments', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data', dest='dataset', default='FB15k-237', help='Dataset to use, default: FB15k-237')
parser.add_argument('--model', dest='model', default='compgcn', help='Model Name')
parser.add_argument('--score_func', dest='score_func', default='conve', help='Score Function for Link prediction')
parser.add_argument('--opn', dest='opn', default='ccorr', help='Composition Operation to be used in CompGCN')
parser.add_argument('--batch', dest='batch_size', default=128, type=int, help='Batch size')
parser.add_argument('--gpu', type=int, default='0', help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0')
parser.add_argument('--epoch', dest='max_epochs', type=int, default=500, help='Number of epochs')
parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
parser.add_argument('--lr', type=float, default=0.001, help='Starting Learning Rate')
parser.add_argument('--lbl_smooth', dest='lbl_smooth', type=float, default=0.1, help='Label Smoothing')
parser.add_argument('--num_workers', type=int, default=10, help='Number of processes to construct batches')
parser.add_argument('--seed', dest='seed', default=41504, type=int, help='Seed for randomization')
parser.add_argument('--num_bases', dest='num_bases', default=-1, type=int, help='Number of basis relation vectors to use')
parser.add_argument('--init_dim', dest='init_dim', default=100, type=int, help='Initial dimension size for entities and relations')
parser.add_argument('--layer_size', nargs='?', default='[200]', help='List of output size for each compGCN layer')
parser.add_argument('--gcn_drop', dest='dropout', default=0.1, type=float, help='Dropout to use in GCN Layer')
parser.add_argument('--layer_dropout', nargs='?', default='[0.3]', help='List of dropout value after each compGCN layer')
# ConvE specific hyperparameters
parser.add_argument('--hid_drop', dest='hid_drop', default=0.3, type=float, help='ConvE: Hidden dropout')
parser.add_argument('--feat_drop', dest='feat_drop', default=0.3, type=float, help='ConvE: Feature Dropout')
parser.add_argument('--k_w', dest='k_w', default=10, type=int, help='ConvE: k_w')
parser.add_argument('--k_h', dest='k_h', default=20, type=int, help='ConvE: k_h')
parser.add_argument('--num_filt', dest='num_filt', default=200, type=int, help='ConvE: Number of filters in convolution')
parser.add_argument('--ker_sz', dest='ker_sz', default=7, type=int, help='ConvE: Kernel size to use')
args = parser.parse_args()
np.random.seed(args.seed)
th.manual_seed(args.seed)
print(args)
args.layer_size = eval(args.layer_size)
args.layer_dropout = eval(args.layer_dropout)
main(args)
import torch as th
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import dgl
import dgl.function as fn
from utils import ccorr
class CompGraphConv(nn.Module):
"""One layer of CompGCN."""
def __init__(self,
in_dim,
out_dim,
comp_fn='sub',
batchnorm=True,
dropout=0.1):
super(CompGraphConv, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.comp_fn = comp_fn
self.activation = th.tanh
self.batchnorm = batchnorm
# define dropout layer
self.dropout = nn.Dropout(dropout)
# define batch norm layer
if self.batchnorm:
self.bn = nn.BatchNorm1d(out_dim)
# define in/out/loop transform layer
self.W_O = nn.Linear(self.in_dim, self.out_dim)
self.W_I = nn.Linear(self.in_dim, self.out_dim)
self.W_S = nn.Linear(self.in_dim, self.out_dim)
# define relation transform layer
self.W_R = nn.Linear(self.in_dim, self.out_dim)
#self loop embedding
self.loop_rel = nn.Parameter(th.Tensor(1, self.in_dim))
nn.init.xavier_normal_(self.loop_rel)
def forward(self, g, n_in_feats, r_feats):
with g.local_scope():
# Assign values to source nodes. In a homogeneous graph, this is equal to
# assigning them to all nodes.
g.srcdata['h'] = n_in_feats
#append loop_rel embedding to r_feats
r_feats = th.cat((r_feats, self.loop_rel), 0)
# Assign features to all edges with the corresponding relation embeddings
g.edata['h'] = r_feats[g.edata['etype']] * g.edata['norm']
# Compute composition function in 4 steps
# Step 1: compute composition by edge in the edge direction, and store results in edges.
if self.comp_fn == 'sub':
g.apply_edges(fn.u_sub_e('h', 'h', out='comp_h'))
elif self.comp_fn == 'mul':
g.apply_edges(fn.u_mul_e('h', 'h', out='comp_h'))
elif self.comp_fn == 'ccorr':
g.apply_edges(lambda edges: {'comp_h': ccorr(edges.src['h'], edges.data['h'])})
else:
raise Exception('Only supports sub, mul, and ccorr')
# Step 2: use extracted edge direction to compute in and out edges
comp_h = g.edata['comp_h']
in_edges_idx = th.nonzero(g.edata['in_edges_mask'], as_tuple=False).squeeze()
out_edges_idx = th.nonzero(g.edata['out_edges_mask'], as_tuple=False).squeeze()
comp_h_O = self.W_O(comp_h[out_edges_idx])
comp_h_I = self.W_I(comp_h[in_edges_idx])
new_comp_h = th.zeros(comp_h.shape[0], self.out_dim).to(comp_h.device)
new_comp_h[out_edges_idx] = comp_h_O
new_comp_h[in_edges_idx] = comp_h_I
g.edata['new_comp_h'] = new_comp_h
# Step 3: sum comp results to both src and dst nodes
g.update_all(fn.copy_e('new_comp_h', 'm'), fn.sum('m', 'comp_edge'))
# Step 4: add results of self-loop
if self.comp_fn == 'sub':
comp_h_s = n_in_feats - r_feats[-1]
elif self.comp_fn == 'mul':
comp_h_s = n_in_feats * r_feats[-1]
elif self.comp_fn == 'ccorr':
comp_h_s = ccorr(n_in_feats, r_feats[-1])
else:
raise Exception('Only supports sub, mul, and ccorr')
# Sum all of the comp results as output of nodes and dropout
n_out_feats = (self.W_S(comp_h_s) + self.dropout(g.ndata['comp_edge'])) * (1/3)
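#the 1/3 factor averages the three message directions: messages through W_I, messages
#through W_O, and the self-loop term W_S(comp_h_s)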
# Compute relation output
r_out_feats = self.W_R(r_feats)
# Batch norm
if self.batchnorm:
n_out_feats = self.bn(n_out_feats)
# Activation function
if self.activation is not None:
n_out_feats = self.activation(n_out_feats)
return n_out_feats, r_out_feats[:-1]
class CompGCN(nn.Module):
def __init__(self,
num_bases,
num_rel,
num_ent,
in_dim=100,
layer_size=[200],
comp_fn='sub',
batchnorm=True,
dropout=0.1,
layer_dropout=[0.3]):
super(CompGCN, self).__init__()
self.num_bases = num_bases
self.num_rel = num_rel
self.num_ent = num_ent
self.in_dim = in_dim
self.layer_size = layer_size
self.comp_fn = comp_fn
self.batchnorm = batchnorm
self.dropout = dropout
self.layer_dropout = layer_dropout
self.num_layer = len(layer_size)
#CompGCN layers
self.layers = nn.ModuleList()
self.layers.append(
CompGraphConv(self.in_dim, self.layer_size[0], comp_fn = self.comp_fn, batchnorm=self.batchnorm, dropout=self.dropout)
)
for i in range(self.num_layer-1):
self.layers.append(
CompGraphConv(self.layer_size[i], self.layer_size[i+1], comp_fn = self.comp_fn, batchnorm=self.batchnorm, dropout=self.dropout)
)
#Initial relation embeddings
if self.num_bases > 0:
self.basis = nn.Parameter(th.Tensor(self.num_bases, self.in_dim))
self.weights = nn.Parameter(th.Tensor(self.num_rel, self.num_bases))
nn.init.xavier_normal_(self.basis)
nn.init.xavier_normal_(self.weights)
else:
self.rel_embds = nn.Parameter(th.Tensor(self.num_rel, self.in_dim))
nn.init.xavier_normal_(self.rel_embds)
#Node embeddings
self.n_embds = nn.Parameter(th.Tensor(self.num_ent, self.in_dim))
nn.init.xavier_normal_(self.n_embds)
#Dropout after compGCN layers
self.dropouts = nn.ModuleList()
for i in range(self.num_layer):
self.dropouts.append(
nn.Dropout(self.layer_dropout[i])
)
def forward(self, graph):
#node and relation features
n_feats = self.n_embds
if self.num_bases > 0:
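#basis decomposition: each initial relation embedding is a learned linear combination
#of num_bases shared basis vectors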
r_embds = th.mm(self.weights, self.basis)
r_feats = r_embds
else:
r_feats = self.rel_embds
for layer, dropout in zip(self.layers, self.dropouts):
n_feats, r_feats = layer(graph, n_feats, r_feats)
n_feats = dropout(n_feats)
return n_feats, r_feats
#Use convE as the score function
class CompGCN_ConvE(nn.Module):
def __init__(self,
num_bases,
num_rel,
num_ent,
in_dim,
layer_size,
comp_fn='sub',
batchnorm=True,
dropout=0.1,
layer_dropout=[0.3],
num_filt=200,
hid_drop=0.3,
feat_drop=0.3,
ker_sz=5,
k_w=5,
k_h=5
):
super(CompGCN_ConvE, self).__init__()
self.embed_dim = layer_size[-1]
self.hid_drop=hid_drop
self.feat_drop=feat_drop
self.ker_sz=ker_sz
self.k_w=k_w
self.k_h=k_h
self.num_filt=num_filt
#compGCN model to get sub/rel embs
self.compGCN_Model = CompGCN(num_bases, num_rel, num_ent, in_dim, layer_size, comp_fn, batchnorm, dropout, layer_dropout)
#batchnorms to the combined (sub+rel) emb
self.bn0 = th.nn.BatchNorm2d(1)
self.bn1 = th.nn.BatchNorm2d(self.num_filt)
self.bn2 = th.nn.BatchNorm1d(self.embed_dim)
#dropouts and conv module to the combined (sub+rel) emb
self.hidden_drop = th.nn.Dropout(self.hid_drop)
self.feature_drop = th.nn.Dropout(self.feat_drop)
self.m_conv1 = th.nn.Conv2d(1, out_channels=self.num_filt, kernel_size=(self.ker_sz, self.ker_sz), stride=1, padding=0, bias=False)
flat_sz_h = int(2 * self.k_w) - self.ker_sz + 1
flat_sz_w = self.k_h - self.ker_sz + 1
self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt
self.fc = th.nn.Linear(self.flat_sz, self.embed_dim)
#bias to the score
self.bias = nn.Parameter(th.zeros(num_ent))
#combine entity embeddings and relation embeddings
def concat(self, e1_embed, rel_embed):
e1_embed = e1_embed.view(-1, 1, self.embed_dim)
rel_embed = rel_embed.view(-1, 1, self.embed_dim)
stack_inp = th.cat([e1_embed, rel_embed], 1)
stack_inp = th.transpose(stack_inp, 2, 1).reshape((-1, 1, 2 * self.k_w, self.k_h))
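#the reshape to (2*k_w, k_h) assumes embed_dim == k_w * k_h
#(e.g. 10 * 20 = 200 with the default arguments in main.py)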
return stack_inp
def forward(self, graph, sub, rel):
#get sub_emb and rel_emb via compGCN
n_feats, r_feats = self.compGCN_Model(graph)
sub_emb = n_feats[sub, :]
rel_emb = r_feats[rel, :]
#combine the sub_emb and rel_emb
stk_inp = self.concat(sub_emb, rel_emb)
#use convE to score the combined emb
x = self.bn0(stk_inp)
x = self.m_conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.feature_drop(x)
x = x.view(-1, self.flat_sz)
x = self.fc(x)
x = self.hidden_drop(x)
x = self.bn2(x)
x = F.relu(x)
#compute score
x = th.mm(x, n_feats.transpose(1,0))
#add in bias
x += self.bias.expand_as(x)
score = th.sigmoid(x)
return score
# This file is based on the CompGCN author's implementation
# <https://github.com/malllabiisc/CompGCN/blob/master/helper.py>.
# It implements circular correlation in the ccorr function and an additional in_out_norm function for edge norm computation.
import torch as th
import dgl
def com_mult(a, b):
r1, i1 = a[..., 0], a[..., 1]
r2, i2 = b[..., 0], b[..., 1]
return th.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim = -1)
def conj(a):
a[..., 1] = -a[..., 1]
return a
def ccorr(a, b):
"""
Compute circular correlation of two tensors.
Parameters
----------
a: Tensor, 1D or 2D
b: Tensor, 1D or 2D
Notes
-----
Input a and b should have the same dimensions. And this operation supports broadcasting.
Returns
-------
Tensor, having the same dimension as the input a.
"""
return th.irfft(com_mult(conj(th.rfft(a, 1)), th.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],))
#identify in/out edges, compute edge norm for each and store in edata
def in_out_norm(graph):
src, dst, EID = graph.edges(form='all')
graph.edata['norm'] = th.ones(EID.shape[0]).to(graph.device)
in_edges_idx = th.nonzero(graph.edata['in_edges_mask'], as_tuple=False).squeeze()
out_edges_idx = th.nonzero(graph.edata['out_edges_mask'], as_tuple=False).squeeze()
for idx in [in_edges_idx, out_edges_idx]:
u, v = src[idx], dst[idx]
deg = th.zeros(graph.num_nodes()).to(graph.device)
n_idx, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
deg[n_idx]=count.float()
deg_inv = deg.pow(-0.5) # D^{-0.5}
deg_inv[deg_inv == float('inf')] = 0
norm = deg_inv[u] * deg_inv[v]
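#norm_{uv} = deg(u)^{-1/2} * deg(v)^{-1/2}, where degrees are in-degrees counted within
#the current edge subset (original edges or reversed edges)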
graph.edata['norm'][idx] = norm
graph.edata['norm'] = graph.edata['norm'].unsqueeze(1)
return graph