Commit 8b17a5c1 authored by MilkshakeForReal, committed by xiang song(charlie.song)

[Model] add RotatE to dgl-kg (#964)

Add RotatE support for KGE (apps/kg)
Performance Result:
Dataset FB15k:
Result from Paper:
MR: 40
MRR: 0.797
HIT@1: 74.6
HIT@3: 83.0
HIT@10: 88.4

Our Impl:
MR: 39.6
MRR: 0.725
HIT@1: 62.8
HIT@3: 80.2
HIT@10: 87.5
parent 1915fdb9
...@@ -19,6 +19,7 @@ Contributors
* [Aymen Waheb](https://github.com/aymenwah): APPNP in PyTorch
* [Chengqiang Lu](https://github.com/geekinglcq): MGCN, SchNet and MPNN in PyTorch
* [Gongze Cao](https://github.com/Zardinality): Cluster GCN
* [Yicheng Wu](https://github.com/MilkshakeForReal): RotatE in PyTorch
Other improvements
* [Brett Koonce](https://github.com/brettkoonce)
......
...@@ -20,6 +20,7 @@ DGL-KE includes the following knowledge graph embedding models:
- ComplEx
- RESCAL
- TransR
- RotatE
More popular models will be added in the future.
...@@ -61,10 +62,10 @@ The speed is measured with 16 CPU cores and one Nvidia V100 GPU.
The speed on FB15k
| Models | TransE | DistMult | ComplEx | RESCAL | TransR | RotatE |
|---------|--------|----------|---------|--------|--------|--------|
|MAX_STEPS| 20000 | 100000 | 100000 | 30000 | 100000 | 100000 |
|TIME | 411s | 690s | 806s | 1800s | 7627s | 4327s |
The accuracy on FB15k
...@@ -75,15 +76,16 @@ The accuracy on FB15k
| ComplEx | 51.99 | 0.785 | 0.720 | 0.832 | 0.889 |
| RESCAL | 130.89| 0.668 | 0.597 | 0.720 | 0.800 |
| TransR | 138.7 | 0.501 | 0.274 | 0.704 | 0.801 |
| RotatE | 39.6 | 0.725 | 0.628 | 0.802 | 0.875 |
In comparison, GraphVite uses 4 GPUs and takes 14 minutes. Thus, DGL-KE trains TransE on FB15k twice as fast as GraphVite while using far fewer resources. More performance information on GraphVite can be found [here](https://github.com/DeepGraphLearning/graphvite).
The speed on wn18
| Models | TransE | DistMult | ComplEx | RESCAL | TransR | RotatE |
|---------|--------|----------|---------|--------|--------|--------|
|MAX_STEPS| 40000 | 10000 | 20000 | 20000 | 20000 | 20000 |
|TIME | 719s | 126s | 266s | 333s | 1547s | 786s |
The accuracy on wn18
...@@ -94,6 +96,7 @@ The accuracy on wn18
| ComplEx | 276.37 | 0.935 | 0.916 | 0.950 | 0.960 |
| RESCAL | 579.54 | 0.846 | 0.791 | 0.898 | 0.931 |
| TransR | 615.56 | 0.606 | 0.378 | 0.826 | 0.890 |
| RotatE | 367.64 | 0.931 | 0.924 | 0.935 | 0.944 |
The speed on Freebase
......
...@@ -22,6 +22,10 @@ DGLBACKEND=pytorch python3 train.py --model TransR --dataset FB15k --batch_size
    --neg_sample_size 256 --hidden_dim 500 --gamma 24.0 --lr 0.01 --max_step 30000 \
    --batch_size_eval 16 --gpu 0 --valid --test -adv
DGLBACKEND=pytorch python3 train.py --model RotatE --dataset FB15k --batch_size 1024 \
    --neg_sample_size 256 --hidden_dim 400 --gamma 12.0 --lr 0.01 --max_step 30000 \
    --batch_size_eval 16 --gpu 0 --valid --test -adv -de --regularization_coef=1e-4
# for wn18
DGLBACKEND=pytorch python3 train.py --model TransE --dataset wn18 --batch_size 1024 \
...@@ -45,6 +49,10 @@ DGLBACKEND=pytorch python3 train.py --model TransR --dataset wn18 --batch_size 1
    --neg_sample_size 256 --hidden_dim 500 --gamma 16.0 --lr 0.1 --max_step 30000 \
    --batch_size_eval 16 --gpu 0 --valid --test -adv
DGLBACKEND=pytorch python3 train.py --model RotatE --dataset wn18 --batch_size 1024 \
    --neg_sample_size 256 --hidden_dim 400 --gamma 12.0 --lr 0.02 --max_step 20000 \
    --batch_size_eval 16 --gpu 0 --valid --test -adv -de
# for Freebase
DGLBACKEND=pytorch python3 train.py --model ComplEx --dataset Freebase --batch_size 1024 \
......
...@@ -58,6 +58,8 @@ class KEModel(object):
            self.score_func = ComplExScore()
        elif model_name == 'RESCAL':
            self.score_func = RESCALScore(relation_dim, entity_dim)
        elif model_name == 'RotatE':
            self.score_func = RotatEScore(gamma, self.emb_init)
        self.head_neg_score = self.score_func.create_neg(True)
        self.tail_neg_score = self.score_func.create_neg(False)
......
import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
...@@ -359,3 +360,92 @@ class RESCALScore(nn.Block):
            tmp = tmp.reshape(num_chunks, chunk_size, hidden_dim)
            return nd.linalg_gemm2(tmp, tails)
        return fn
class RotatEScore(nn.Block):
    def __init__(self, gamma, emb_init, eps=1e-10):
        super(RotatEScore, self).__init__()
        self.gamma = gamma
        self.emb_init = emb_init
        self.eps = eps

    def edge_func(self, edges):
        real_head, img_head = nd.split(edges.src['emb'], num_outputs=2, axis=-1)
        real_tail, img_tail = nd.split(edges.dst['emb'], num_outputs=2, axis=-1)
        phase_rel = edges.data['emb'] / (self.emb_init / np.pi)
        re_rel, im_rel = nd.cos(phase_rel), nd.sin(phase_rel)
        real_score = real_head * re_rel - img_head * im_rel
        img_score = real_head * im_rel + img_head * re_rel
        real_score = real_score - real_tail
        img_score = img_score - img_tail
        # sqrt((x * x).sum() + eps), with eps for numerical stability
        score = mx.nd.sqrt(real_score * real_score + img_score * img_score + self.eps).sum(-1)
        return {'score': self.gamma - score}

    def prepare(self, g, gpu_id, trace=False):
        pass

    def create_neg_prepare(self, neg_head):
        def fn(rel_id, num_chunks, head, tail, gpu_id, trace=False):
            return head, tail
        return fn

    def update(self):
        pass

    def reset_parameters(self):
        pass

    def save(self, path, name):
        pass

    def load(self, path, name):
        pass

    def forward(self, g):
        g.apply_edges(lambda edges: self.edge_func(edges))

    def create_neg(self, neg_head):
        gamma = self.gamma
        emb_init = self.emb_init
        eps = self.eps
        if neg_head:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = heads.shape[1]
                emb_real, emb_img = nd.split(tails, num_outputs=2, axis=-1)
                phase_rel = relations / (emb_init / np.pi)
                rel_real, rel_img = nd.cos(phase_rel), nd.sin(phase_rel)
                # rotate the tail by the conjugate relation: t * exp(-i * phase)
                real = emb_real * rel_real + emb_img * rel_img
                img = -emb_real * rel_img + emb_img * rel_real
                emb_complex = nd.concat(real, img, dim=-1)
                tmp = emb_complex.reshape(num_chunks, chunk_size, 1, hidden_dim)
                heads = heads.reshape(num_chunks, 1, neg_sample_size, hidden_dim)
                score = tmp - heads
                score_real, score_img = nd.split(score, num_outputs=2, axis=-1)
                score = mx.nd.sqrt(score_real * score_real + score_img * score_img + eps).sum(-1)
                return gamma - score
            return fn
        else:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = heads.shape[1]
                emb_real, emb_img = nd.split(heads, num_outputs=2, axis=-1)
                phase_rel = relations / (emb_init / np.pi)
                rel_real, rel_img = nd.cos(phase_rel), nd.sin(phase_rel)
                # rotate the head by the relation: h * exp(i * phase)
                real = emb_real * rel_real - emb_img * rel_img
                img = emb_real * rel_img + emb_img * rel_real
                emb_complex = nd.concat(real, img, dim=-1)
                tmp = emb_complex.reshape(num_chunks, chunk_size, 1, hidden_dim)
                tails = tails.reshape(num_chunks, 1, neg_sample_size, hidden_dim)
                score = tmp - tails
                score_real, score_img = nd.split(score, num_outputs=2, axis=-1)
                score = mx.nd.sqrt(score_real * score_real + score_img * score_img + eps).sum(-1)
                return gamma - score
            return fn
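The edge score computed above can be exercised outside of DGL and MXNet. Below is a minimal NumPy sketch of the same formula, gamma - ||h ∘ r - t||, where ∘ is element-wise complex rotation; the `rotate_score` helper and its inputs are illustrative only, and the eps stabilizer used in the class is omitted:

```python
import numpy as np

def rotate_score(head, rel_phase, tail, gamma):
    """RotatE score: gamma - || h o r - t ||, where o rotates each complex
    coordinate of the head by the relation phase. head/tail store the real
    half followed by the imaginary half."""
    re_h, im_h = np.split(head, 2, axis=-1)
    re_t, im_t = np.split(tail, 2, axis=-1)
    re_r, im_r = np.cos(rel_phase), np.sin(rel_phase)
    re_s = re_h * re_r - im_h * im_r - re_t   # real part of h*r - t
    im_s = re_h * im_r + im_h * re_r - im_t   # imaginary part of h*r - t
    return gamma - np.sqrt(re_s ** 2 + im_s ** 2).sum(-1)

# With a zero phase the rotation is the identity, so a triple whose head
# equals its tail gets the maximum score, gamma.
h = np.array([0.3, -0.5, 0.2, 0.7])
assert np.isclose(rotate_score(h, np.zeros(2), h, 12.0), 12.0)
```

Because rotation preserves each coordinate's complex modulus, the distance to a zero tail is independent of the phase, which is the property the negative-sampling code below exploits.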
...@@ -2,8 +2,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as functional
import torch.nn.init as INIT
import numpy as np
from .tensor_models import ExternalEmbedding

class TransEScore(nn.Module):
    def __init__(self, gamma):
...@@ -333,3 +332,90 @@ class RESCALScore(nn.Module):
            tmp = tmp.reshape(num_chunks, chunk_size, hidden_dim)
            return th.bmm(tmp, tails)
        return fn
class RotatEScore(nn.Module):
    def __init__(self, gamma, emb_init):
        super(RotatEScore, self).__init__()
        self.gamma = gamma
        self.emb_init = emb_init

    def edge_func(self, edges):
        re_head, im_head = th.chunk(edges.src['emb'], 2, dim=-1)
        re_tail, im_tail = th.chunk(edges.dst['emb'], 2, dim=-1)
        phase_rel = edges.data['emb'] / (self.emb_init / np.pi)
        re_rel, im_rel = th.cos(phase_rel), th.sin(phase_rel)
        re_score = re_head * re_rel - im_head * im_rel
        im_score = re_head * im_rel + im_head * re_rel
        re_score = re_score - re_tail
        im_score = im_score - im_tail
        # complex modulus of (h * r - t) per coordinate, summed over dims
        score = th.stack([re_score, im_score], dim=0)
        score = score.norm(dim=0)
        return {'score': self.gamma - score.sum(-1)}

    def update(self):
        pass

    def reset_parameters(self):
        pass

    def save(self, path, name):
        pass

    def load(self, path, name):
        pass

    def forward(self, g):
        g.apply_edges(lambda edges: self.edge_func(edges))

    def create_neg_prepare(self, neg_head):
        def fn(rel_id, num_chunks, head, tail, gpu_id, trace=False):
            return head, tail
        return fn

    def prepare(self, g, gpu_id, trace=False):
        pass

    def create_neg(self, neg_head):
        gamma = self.gamma
        emb_init = self.emb_init
        if neg_head:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = heads.shape[1]
                emb_real = tails[..., :hidden_dim // 2]
                emb_imag = tails[..., hidden_dim // 2:]
                phase_rel = relations / (emb_init / np.pi)
                rel_real, rel_imag = th.cos(phase_rel), th.sin(phase_rel)
                # rotate the tail by the conjugate relation: t * exp(-i * phase)
                real = emb_real * rel_real + emb_imag * rel_imag
                imag = -emb_real * rel_imag + emb_imag * rel_real
                emb_complex = th.cat((real, imag), dim=-1)
                tmp = emb_complex.reshape(num_chunks, chunk_size, 1, hidden_dim)
                heads = heads.reshape(num_chunks, 1, neg_sample_size, hidden_dim)
                score = tmp - heads
                score = th.stack([score[..., :hidden_dim // 2],
                                  score[..., hidden_dim // 2:]], dim=-1).norm(dim=-1)
                return gamma - score.sum(-1)
            return fn
        else:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                hidden_dim = heads.shape[1]
                emb_real = heads[..., :hidden_dim // 2]
                emb_imag = heads[..., hidden_dim // 2:]
                phase_rel = relations / (emb_init / np.pi)
                rel_real, rel_imag = th.cos(phase_rel), th.sin(phase_rel)
                # rotate the head by the relation: h * exp(i * phase)
                real = emb_real * rel_real - emb_imag * rel_imag
                imag = emb_real * rel_imag + emb_imag * rel_real
                emb_complex = th.cat((real, imag), dim=-1)
                tmp = emb_complex.reshape(num_chunks, chunk_size, 1, hidden_dim)
                tails = tails.reshape(num_chunks, 1, neg_sample_size, hidden_dim)
                score = tmp - tails
                score = th.stack([score[..., :hidden_dim // 2],
                                  score[..., hidden_dim // 2:]], dim=-1).norm(dim=-1)
                return gamma - score.sum(-1)
            return fn
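The `neg_head` branch above rotates the candidate tail by the conjugate phase instead of rotating every corrupted head, so one rotation per positive triple serves all of its negative samples. A small standalone PyTorch sketch (the `rotate` helper is hypothetical, not part of the patch) of why the two distances agree, since element-wise rotation preserves the complex modulus:

```python
import torch as th

th.manual_seed(0)
dim = 4  # must be even: real half followed by imaginary half
head = th.randn(dim)
tail = th.randn(dim)
phase = th.randn(dim // 2)

def rotate(emb, phase, sign=1.0):
    # Element-wise complex rotation of emb by sign * phase.
    re, im = th.chunk(emb, 2, dim=-1)
    re_r, im_r = th.cos(sign * phase), th.sin(sign * phase)
    return th.cat([re * re_r - im * im_r, re * im_r + im * re_r], dim=-1)

# || h o r - t || == || h - t o conj(r) ||: multiplying both terms by
# exp(-i * phase) leaves each coordinate's modulus unchanged.
d_fwd = (rotate(head, phase) - tail).norm()
d_bwd = (head - rotate(tail, phase, sign=-1.0)).norm()
assert th.allclose(d_fwd, d_bwd, atol=1e-5)
```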
...@@ -34,13 +34,14 @@ def generate_rand_graph(n, func_name):
    g = dgl.DGLGraph(arr, readonly=True)
    num_rels = 10
    entity_emb = F.uniform((g.number_of_nodes(), 10), F.float32, F.cpu(), 0, 1)
    if func_name == 'RotatE':
        entity_emb = F.uniform((g.number_of_nodes(), 20), F.float32, F.cpu(), 0, 1)
    rel_emb = F.uniform((num_rels, 10), F.float32, F.cpu(), -1, 1)
    if func_name == 'RESCAL':
        rel_emb = F.uniform((num_rels, 10*10), F.float32, F.cpu(), 0, 1)
    g.ndata['id'] = F.arange(0, g.number_of_nodes())
    rel_ids = np.random.randint(0, num_rels, g.number_of_edges(), dtype=np.int64)
    g.edata['id'] = F.tensor(rel_ids, F.int64)
    # TransR has an additional projection_emb
    if (func_name == 'TransR'):
        args = {'gpu':-1, 'lr':0.1}
...@@ -51,6 +52,8 @@ def generate_rand_graph(n, func_name):
        return g, entity_emb, rel_emb, (12.0)
    elif (func_name == 'RESCAL'):
        return g, entity_emb, rel_emb, (10, 10)
    elif (func_name == 'RotatE'):
        return g, entity_emb, rel_emb, (12.0, 1.0)
    else:
        return g, entity_emb, rel_emb, None
...@@ -58,7 +61,8 @@ ke_score_funcs = {'TransE': TransEScore,
                  'DistMult': DistMultScore,
                  'ComplEx': ComplExScore,
                  'RESCAL': RESCALScore,
                  'TransR': TransRScore,
                  'RotatE': RotatEScore}

class BaseKEModel:
    def __init__(self, score_func, entity_emb, rel_emb):
...@@ -158,9 +162,13 @@ def test_score_func_rescal():
def test_score_func_transr():
    check_score_func('TransR')

def test_score_func_rotate():
    check_score_func('RotatE')

if __name__ == '__main__':
    test_score_func_transe()
    test_score_func_distmult()
    test_score_func_complex()
    test_score_func_rescal()
    test_score_func_transr()
    test_score_func_rotate()