Commit 02fb0581 authored by Chao Ma's avatar Chao Ma Committed by Da Zheng
Browse files

[KG] Add RESCAL model to DGL-KGE (#923)

* Add RESCAL model

* update

* update

* match acc

* update

* add README.md

* fix
parent 20439e1c
......@@ -18,6 +18,7 @@ DGL-KE includes the following knowledge graph embedding models:
- TransE
- DistMult
- ComplEx
- RESCAL
More popular knowledge graph embedding models will be added in the future.
......@@ -57,10 +58,10 @@ The speed is measured with 16 CPU cores and one Nvidia V100 GPU.
The speed on FB15k
| Models | TransE | DistMult | ComplEx |
|---------|--------|----------|---------|
|MAX_STEPS| 20000 | 100000 | 100000 |
|TIME | 411s | 690s | 806s |
| Models | TransE | DistMult | ComplEx | RESCAL |
|---------|--------|----------|---------|--------|
|MAX_STEPS| 20000 | 100000 | 100000 | 30000 |
|TIME | 411s | 690s | 806s | 1800s |
The accuracy on FB15k
......@@ -69,15 +70,16 @@ The accuracy on FB15k
| TransE | 69.12 | 0.656 | 0.567 | 0.718 | 0.802 |
| DistMult | 43.35 | 0.783 | 0.713 | 0.837 | 0.897 |
| ComplEx | 51.99 | 0.785 | 0.720 | 0.832 | 0.889 |
| RESCAL | 130.89| 0.668 | 0.597 | 0.720 | 0.800 |
In comparison, GraphVite uses 4 GPUs and takes 14 minutes. Thus, DGL-KE trains TransE on FB15k twice as fast as GraphVite while using far fewer resources. More performance information on GraphVite can be found [here](https://github.com/DeepGraphLearning/graphvite).
The speed on wn18
| Models | TransE | DistMult | ComplEx |
|---------|--------|----------|---------|
|MAX_STEPS| 40000 | 10000 | 20000 |
|TIME | 719s | 126s | 266s |
| Models | TransE | DistMult | ComplEx | RESCAL |
|---------|--------|----------|---------|--------|
|MAX_STEPS| 40000 | 10000 | 20000 | 20000 |
|TIME | 719s | 126s | 266s | 333s |
The accuracy on wn18
......@@ -86,6 +88,7 @@ The accuracy on wn18
| TransE | 321.35 | 0.760 | 0.652 | 0.850 | 0.940 |
| DistMult | 271.09 | 0.769 | 0.639 | 0.892 | 0.949 |
| ComplEx | 276.37 | 0.935 | 0.916 | 0.950 | 0.960 |
| RESCAL | 579.54 | 0.846 | 0.791 | 0.898 | 0.931 |
The speed on Freebase
......
......@@ -14,6 +14,10 @@ DGLBACKEND=pytorch python3 train.py --model TransE --dataset FB15k --batch_size
--neg_sample_size 256 --hidden_dim 2000 --gamma 24.0 --lr 0.01 --max_step 20000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv
DGLBACKEND=pytorch python3 train.py --model RESCAL --dataset FB15k --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 500 --gamma 24.0 --lr 0.03 --max_step 30000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv
# for wn18
DGLBACKEND=pytorch python3 train.py --model TransE --dataset wn18 --batch_size 1024 \
......@@ -29,6 +33,10 @@ DGLBACKEND=pytorch python3 train.py --model ComplEx --dataset wn18 --batch_size
--neg_sample_size 1024 --hidden_dim 500 --gamma 200.0 --lr 0.1 --max_step 20000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv --regularization_coef 0.00001
DGLBACKEND=pytorch python3 train.py --model RESCAL --dataset wn18 --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 250 --gamma 24.0 --lr 0.03 --max_step 20000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv
# for Freebase
DGLBACKEND=pytorch python3 train.py --model ComplEx --dataset Freebase --batch_size 1024 \
......
......@@ -52,6 +52,9 @@ class KEModel(object):
self.score_func = DistMultScore()
elif model_name == 'ComplEx':
self.score_func = ComplExScore()
elif model_name == 'RESCAL':
self.score_func = RESCALScore(relation_dim, entity_dim)
self.head_neg_score = self.score_func.create_neg(True)
self.tail_neg_score = self.score_func.create_neg(False)
......
......@@ -144,3 +144,55 @@ class ComplExScore(nn.Block):
tails = nd.transpose(tails, axes=(0, 2, 1))
return nd.linalg_gemm2(tmp, tails)
return fn
class RESCALScore(nn.Block):
    """RESCAL bilinear scoring function (MXNet backend).

    Scores a triple (h, r, t) as h^T M_r t, where M_r is a relation-specific
    matrix of shape (relation_dim, entity_dim) stored flattened in the
    relation embedding vector.
    """
    def __init__(self, relation_dim, entity_dim):
        super(RESCALScore, self).__init__()
        # Relation embeddings arrive flattened as (relation_dim * entity_dim)
        # vectors; these sizes are needed to reshape them into matrices.
        self.relation_dim = relation_dim
        self.entity_dim = entity_dim

    def edge_func(self, edges):
        """Compute the RESCAL score for each edge; returns {'score': ...}."""
        head = edges.src['emb']
        tail = edges.dst['emb'].expand_dims(2)
        rel = edges.data['emb']
        rel = rel.reshape(-1, self.relation_dim, self.entity_dim)
        # squeeze(axis=-1): a bare squeeze() removes *every* singleton axis,
        # which corrupts the result when the edge batch (or relation_dim) is 1.
        score = head * mx.nd.batch_dot(rel, tail).squeeze(axis=-1)
        # TODO: check if use self.gamma
        return {'score': mx.nd.sum(score, -1)}

    def reset_parameters(self):
        # No extra parameters beyond the embeddings managed by the model.
        pass

    def save(self, path, name):
        # Nothing to persist; embeddings are saved by the enclosing model.
        pass

    def load(self, path, name):
        pass

    def forward(self, g):
        # Writes a 'score' field on every edge of g.
        g.apply_edges(lambda edges: self.edge_func(edges))

    def create_neg(self, neg_head):
        """Return a batched scoring function for negative sampling.

        When ``neg_head`` is True the returned fn scores each positive
        (relation, tail) pair in a chunk against all negative head entities,
        otherwise each (head, relation) pair against all negative tails.
        Output shape: (num_chunks, chunk_size, neg_sample_size).
        """
        if neg_head:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = heads.shape[1]
                heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
                heads = mx.nd.transpose(heads, axes=(0, 2, 1))
                tails = tails.expand_dims(2)
                relations = relations.reshape(-1, self.relation_dim, self.entity_dim)
                # Project tails through their relation matrices (M_r t).
                # squeeze(axis=-1) keeps the batch axis even when it is size 1.
                tmp = mx.nd.batch_dot(relations, tails).squeeze(axis=-1)
                tmp = tmp.reshape(num_chunks, chunk_size, hidden_dim)
                return nd.linalg_gemm2(tmp, heads)
            return fn
        else:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = heads.shape[1]
                tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
                tails = mx.nd.transpose(tails, axes=(0, 2, 1))
                heads = heads.expand_dims(2)
                relations = relations.reshape(-1, self.relation_dim, self.entity_dim)
                # Project heads through their relation matrices (M_r h).
                tmp = mx.nd.batch_dot(relations, heads).squeeze(axis=-1)
                tmp = tmp.reshape(num_chunks, chunk_size, hidden_dim)
                return nd.linalg_gemm2(tmp, tails)
            return fn
......@@ -175,3 +175,27 @@ class RESCALScore(nn.Module):
def forward(self, g):
g.apply_edges(lambda edges: self.edge_func(edges))
def create_neg(self, neg_head):
if neg_head:
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
hidden_dim = heads.shape[1]
heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
heads = th.transpose(heads, 1, 2)
tails = tails.unsqueeze(-1)
relations = relations.view(-1, self.relation_dim, self.entity_dim)
tmp = th.matmul(relations, tails).squeeze(-1)
tmp = tmp.reshape(num_chunks, chunk_size, hidden_dim)
return th.bmm(tmp, heads)
return fn
else:
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
hidden_dim = heads.shape[1]
tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
tails = th.transpose(tails, 1, 2)
heads = heads.unsqueeze(-1)
relations = relations.view(-1, self.relation_dim, self.entity_dim)
tmp = th.matmul(relations, heads).squeeze(-1)
tmp = tmp.reshape(num_chunks, chunk_size, hidden_dim)
return th.bmm(tmp, tails)
return fn
......@@ -13,12 +13,14 @@ else:
from models.general_models import KEModel
from dataloader.sampler import create_neg_subgraph
def generate_rand_graph(n):
def generate_rand_graph(n, func_name):
arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
g = dgl.DGLGraph(arr, readonly=True)
num_rels = 10
entity_emb = F.uniform((g.number_of_nodes(), 10), F.float32, F.cpu(), 0, 1)
rel_emb = F.uniform((num_rels, 10), F.float32, F.cpu(), 0, 1)
if func_name == 'RESCAL':
rel_emb = F.uniform((num_rels, 10*10), F.float32, F.cpu(), 0, 1)
g.ndata['id'] = F.arange(0, g.number_of_nodes())
rel_ids = np.random.randint(0, num_rels, g.number_of_edges(), dtype=np.int64)
g.edata['id'] = F.tensor(rel_ids, F.int64)
......@@ -26,7 +28,8 @@ def generate_rand_graph(n):
ke_score_funcs = {'TransE': TransEScore(12.0),
'DistMult': DistMultScore(),
'ComplEx': ComplExScore()}
'ComplEx': ComplExScore(),
'RESCAL': RESCALScore(10, 10)}
class BaseKEModel:
def __init__(self, score_func, entity_emb, rel_emb):
......@@ -72,7 +75,7 @@ class BaseKEModel:
def check_score_func(func_name):
batch_size = 10
neg_sample_size = 10
g, entity_emb, rel_emb = generate_rand_graph(100)
g, entity_emb, rel_emb = generate_rand_graph(100, func_name)
hidden_dim = entity_emb.shape[1]
ke_score_func = ke_score_funcs[func_name]
model = BaseKEModel(ke_score_func, entity_emb, rel_emb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment