"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "fba4f42e3bc24b7b2c6cad09b6db653ac73dc6b7"
Commit dca0e376 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by Da Zheng
Browse files

[KG][Score Func] Update TransE with L2 distance support. (#1059)

* Add L2 distance score for TransE

* Update README.md

* Use linalg.gemm to speedup mx l2 dist

* Fix
parent d6dfaa9b
......@@ -15,7 +15,7 @@ takes a couple of hours on Freebase, which has hundreds of millions of edges.
DGL-KE includes the following knowledge graph embedding models:
- TransE
- TransE (TransE_l1 with L1 distance and TransE_l2 with L2 distance)
- DistMult
- ComplEx
- RESCAL
......@@ -62,41 +62,43 @@ The speed is measured with 16 CPU cores and one Nvidia V100 GPU.
The speed on FB15k
| Models | TransE | DistMult | ComplEx | RESCAL | TransR | RotatE |
|---------|--------|----------|---------|--------|--------|--------|
|MAX_STEPS| 20000 | 100000 | 100000 | 30000 | 100000 | 100000 |
|TIME | 411s | 690s | 806s | 1800s | 7627s | 4327s |
| Models | TransE_l1 | TransE_l2 | DistMult | ComplEx | RESCAL | TransR | RotatE |
|---------|-----------|-----------|----------|---------|--------|--------|--------|
|MAX_STEPS| 20000 | 30000 |100000 | 100000 | 30000 | 100000 | 100000 |
|TIME | 411s | 329s |690s | 806s | 1800s | 7627s | 4327s |
The accuracy on FB15k
| Models | MR | MRR | HITS@1 | HITS@3 | HITS@10 |
|----------|-------|-------|--------|--------|---------|
| TransE | 69.12 | 0.656 | 0.567 | 0.718 | 0.802 |
| DistMult | 43.35 | 0.783 | 0.713 | 0.837 | 0.897 |
| ComplEx | 51.99 | 0.785 | 0.720 | 0.832 | 0.889 |
| RESCAL | 130.89| 0.668 | 0.597 | 0.720 | 0.800 |
| TransR | 138.7 | 0.501 | 0.274 | 0.704 | 0.801 |
| RotatE | 39.6 | 0.725 | 0.628 | 0.802 | 0.875 |
| Models | MR | MRR | HITS@1 | HITS@3 | HITS@10 |
|-----------|-------|-------|--------|--------|---------|
| TransE_l1 | 69.12 | 0.656 | 0.567 | 0.718 | 0.802 |
| TransE_l2 | 35.86 | 0.570 | 0.400 | 0.708 | 0.834 |
| DistMult | 43.35 | 0.783 | 0.713 | 0.837 | 0.897 |
| ComplEx | 51.99 | 0.785 | 0.720 | 0.832 | 0.889 |
| RESCAL | 130.89| 0.668 | 0.597 | 0.720 | 0.800 |
| TransR | 138.7 | 0.501 | 0.274 | 0.704 | 0.801 |
| RotatE | 39.6 | 0.725 | 0.628 | 0.802 | 0.875 |
In comparison, GraphVite uses 4 GPUs and takes 14 minutes. Thus, DGL-KE trains TransE on FB15k twice as fast as GraphVite while using much few resources. More performance information on GraphVite can be found [here](https://github.com/DeepGraphLearning/graphvite).
The speed on wn18
| Models | TransE | DistMult | ComplEx | RESCAL | TransR | RotatE |
|---------|--------|----------|---------|--------|--------|--------|
|MAX_STEPS| 40000 | 10000 | 20000 | 20000 | 20000 | 20000 |
|TIME | 719s | 126s | 266s | 333s | 1547s | 786s |
| Models | TransE_l1 | TransE_l2 | DistMult | ComplEx | RESCAL | TransR | RotatE |
|---------|-----------|-----------|----------|---------|--------|--------|--------|
|MAX_STEPS| 40000 | 20000 | 10000 | 20000 | 20000 | 20000 | 20000 |
|TIME | 719s | 254s | 126s | 266s | 333s | 1547s | 786s |
The accuracy on wn18
| Models | MR | MRR | HITS@1 | HITS@3 | HITS@10 |
|----------|--------|-------|--------|--------|---------|
| TransE | 321.35 | 0.760 | 0.652 | 0.850 | 0.940 |
| DistMult | 271.09 | 0.769 | 0.639 | 0.892 | 0.949 |
| ComplEx | 276.37 | 0.935 | 0.916 | 0.950 | 0.960 |
| RESCAL | 579.54 | 0.846 | 0.791 | 0.898 | 0.931 |
| TransR | 615.56 | 0.606 | 0.378 | 0.826 | 0.890 |
| RotatE | 367.64 | 0.931 | 0.924 | 0.935 | 0.944 |
| Models | MR | MRR | HITS@1 | HITS@3 | HITS@10 |
|-----------|--------|-------|--------|--------|---------|
| TransE_l1 | 321.35 | 0.760 | 0.652 | 0.850 | 0.940 |
| TransE_l2 | 181.57 | 0.570 | 0.322 | 0.802 | 0.944 |
| DistMult | 271.09 | 0.769 | 0.639 | 0.892 | 0.949 |
| ComplEx | 276.37 | 0.935 | 0.916 | 0.950 | 0.960 |
| RESCAL | 579.54 | 0.846 | 0.791 | 0.898 | 0.931 |
| TransR | 615.56 | 0.606 | 0.378 | 0.826 | 0.890 |
| RotatE | 367.64 | 0.931 | 0.924 | 0.935 | 0.944 |
The speed on Freebase
......
......@@ -10,10 +10,14 @@ DGLBACKEND=pytorch python3 train.py --model ComplEx --dataset FB15k --batch_size
--neg_sample_size 256 --hidden_dim 2000 --gamma 500.0 --lr 0.2 --max_step 100000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv
DGLBACKEND=pytorch python3 train.py --model TransE --dataset FB15k --batch_size 1024 \
DGLBACKEND=pytorch python3 train.py --model TransE_l1 --dataset FB15k --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 2000 --gamma 24.0 --lr 0.01 --max_step 20000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv
DGLBACKEND=pytorch python3 train.py --model TransE_l2 --dataset FB15k --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 2000 --gamma 12.0 --lr 0.1 --max_step 30000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv --regularization_coef=2e-7
DGLBACKEND=pytorch python3 train.py --model RESCAL --dataset FB15k --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 500 --gamma 24.0 --lr 0.03 --max_step 30000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv
......@@ -28,11 +32,15 @@ DGLBACKEND=pytorch python3 train.py --model RotatE --dataset FB15k --batch_size
# for wn18
DGLBACKEND=pytorch python3 train.py --model TransE --dataset wn18 --batch_size 1024 \
DGLBACKEND=pytorch python3 train.py --model TransE_l1 --dataset wn18 --batch_size 1024 \
--neg_sample_size 512 --hidden_dim 500 --gamma 12.0 --adversarial_temperature 0.5 \
--lr 0.01 --max_step 40000 --batch_size_eval 16 --gpu 0 --valid --test -adv \
--regularization_coef 0.00001
DGLBACKEND=pytorch python3 train.py --model TransE_l2 --dataset wn18 --batch_size 1024 \
--neg_sample_size 512 --hidden_dim 500 --gamma 6.0 --lr 0.1 --max_step 20000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv --regularization_coef 0.0000001
DGLBACKEND=pytorch python3 train.py --model DistMult --dataset wn18 --batch_size 1024 \
--neg_sample_size 1024 --hidden_dim 1000 --gamma 200.0 --lr 0.1 --max_step 10000 \
--batch_size_eval 16 --gpu 0 --valid --test -adv --regularization_coef 0.00001
......
......@@ -46,8 +46,10 @@ class KEModel(object):
rel_dim = relation_dim
self.relation_emb = ExternalEmbedding(args, n_relations, rel_dim, device)
if model_name == 'TransE':
self.score_func = TransEScore(gamma)
if model_name == 'TransE' or model_name == 'TransE_l2':
self.score_func = TransEScore(gamma, 'l2')
elif model_name == 'TransE_l1':
self.score_func = TransEScore(gamma, 'l1')
elif model_name == 'TransR':
projection_emb = ExternalEmbedding(args, n_relations, entity_dim * relation_dim,
F.cpu() if args.mix_cpu_gpu else device)
......
......@@ -4,17 +4,39 @@ from mxnet import gluon
from mxnet.gluon import nn
from mxnet import ndarray as nd
def batched_l2_dist(a, b):
a_squared = nd.power(nd.norm(a, axis=-1), 2)
b_squared = nd.power(nd.norm(b, axis=-1), 2)
squared_res = nd.add(nd.linalg_gemm(
a, nd.transpose(b, axes=(0, 2, 1)), nd.broadcast_axes(nd.expand_dims(b_squared, axis=-2), axis=1, size=a.shape[1]), alpha=-2
), nd.expand_dims(a_squared, axis=-1))
res = nd.sqrt(nd.clip(squared_res, 1e-30, np.finfo(np.float32).max))
return res
def batched_l1_dist(a, b):
a = nd.expand_dims(a, axis=-2)
b = nd.expand_dims(b, axis=-3)
res = nd.norm(a - b, ord=1, axis=-1)
return res
class TransEScore(nn.Block):
def __init__(self, gamma):
def __init__(self, gamma, dist_func='l2'):
super(TransEScore, self).__init__()
self.gamma = gamma
if dist_func == 'l1':
self.neg_dist_func = batched_l1_dist
self.dist_ord = 1
else: # default use l2
self.neg_dist_func = batched_l2_dist
self.dist_ord = 2
def edge_func(self, edges):
head = edges.src['emb']
tail = edges.dst['emb']
rel = edges.data['emb']
score = head + rel - tail
return {'score': self.gamma - nd.norm(score, ord=1, axis=-1)}
return {'score': self.gamma - nd.norm(score, ord=self.dist_ord, axis=-1)}
def prepare(self, g, gpu_id, trace=False):
pass
......@@ -44,18 +66,18 @@ class TransEScore(nn.Block):
if neg_head:
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
hidden_dim = heads.shape[1]
heads = heads.reshape(num_chunks, 1, neg_sample_size, hidden_dim)
heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
tails = tails - relations
tails = tails.reshape(num_chunks,chunk_size, 1, hidden_dim)
return gamma - nd.norm(heads - tails, ord=1, axis=-1)
tails = tails.reshape(num_chunks, chunk_size, hidden_dim)
return gamma - self.neg_dist_func(tails, heads)
return fn
else:
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
hidden_dim = heads.shape[1]
heads = heads + relations
heads = heads.reshape(num_chunks, chunk_size, 1, hidden_dim)
tails = tails.reshape(num_chunks, 1, neg_sample_size, hidden_dim)
return gamma - nd.norm(heads - tails, ord=1, axis=-1)
heads = heads.reshape(num_chunks, chunk_size, hidden_dim)
tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
return gamma - self.neg_dist_func(heads, tails)
return fn
class TransRScore(nn.Block):
......
......@@ -4,17 +4,37 @@ import torch.nn.functional as functional
import torch.nn.init as INIT
import numpy as np
def batched_l2_dist(a, b):
a_squared = a.norm(dim=-1).pow(2)
b_squared = b.norm(dim=-1).pow(2)
squared_res = th.baddbmm(
b_squared.unsqueeze(-2), a, b.transpose(-2, -1), alpha=-2
).add_(a_squared.unsqueeze(-1))
res = squared_res.clamp_min_(1e-30).sqrt_()
return res
def batched_l1_dist(a, b):
res = th.cdist(a, b, p=1)
return res
class TransEScore(nn.Module):
def __init__(self, gamma):
def __init__(self, gamma, dist_func='l2'):
super(TransEScore, self).__init__()
self.gamma = gamma
if dist_func == 'l1':
self.neg_dist_func = batched_l1_dist
self.dist_ord = 1
else: # default use l2
self.neg_dist_func = batched_l2_dist
self.dist_ord = 2
def edge_func(self, edges):
head = edges.src['emb']
tail = edges.dst['emb']
rel = edges.data['emb']
score = head + rel - tail
return {'score': self.gamma - th.norm(score, p=1, dim=-1)}
return {'score': self.gamma - th.norm(score, p=self.dist_ord, dim=-1)}
def prepare(self, g, gpu_id, trace=False):
pass
......@@ -47,7 +67,7 @@ class TransEScore(nn.Module):
heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
tails = tails - relations
tails = tails.reshape(num_chunks, chunk_size, hidden_dim)
return gamma - th.cdist(tails, heads, p=1)
return gamma - self.neg_dist_func(tails, heads)
return fn
else:
def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
......@@ -55,7 +75,7 @@ class TransEScore(nn.Module):
heads = heads + relations
heads = heads.reshape(num_chunks, chunk_size, hidden_dim)
tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
return gamma - th.cdist(heads, tails, p=1)
return gamma - self.neg_dist_func(heads, tails)
return fn
class TransRScore(nn.Module):
......
......@@ -50,6 +50,10 @@ def generate_rand_graph(n, func_name):
return g, entity_emb, rel_emb, (12.0, projection_emb, 10, 10)
elif (func_name == 'TransE'):
return g, entity_emb, rel_emb, (12.0)
elif (func_name == 'TransE_l1'):
return g, entity_emb, rel_emb, (12.0, 'l1')
elif (func_name == 'TransE_l2'):
return g, entity_emb, rel_emb, (12.0, 'l2')
elif (func_name == 'RESCAL'):
return g, entity_emb, rel_emb, (10, 10)
elif (func_name == 'RotatE'):
......@@ -58,6 +62,8 @@ def generate_rand_graph(n, func_name):
return g, entity_emb, rel_emb, None
ke_score_funcs = {'TransE': TransEScore,
'TransE_l1': TransEScore,
'TransE_l2': TransEScore,
'DistMult': DistMultScore,
'ComplEx': ComplExScore,
'RESCAL': RESCALScore,
......@@ -149,6 +155,8 @@ def check_score_func(func_name):
def test_score_func_transe():
check_score_func('TransE')
check_score_func('TransE_l1')
check_score_func('TransE_l2')
def test_score_func_distmult():
check_score_func('DistMult')
......
......@@ -23,8 +23,8 @@ class ArgParser(argparse.ArgumentParser):
super(ArgParser, self).__init__()
self.add_argument('--model_name', default='TransE',
choices=['TransE', 'TransH', 'TransR', 'TransD',
'RESCAL', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE'],
choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
help='model to use')
self.add_argument('--data_path', type=str, default='data',
help='root path of all dataset')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment