"server/vscode:/vscode.git/clone" did not exist on "15511edc01a0725d374840f0e77d085eb5821483"
test_score.py 7.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# -*- coding: utf-8 -*-
#
# test_score.py
#
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

20
21
22
23
24
25
26
import os
import scipy as sp
import dgl
import numpy as np
import dgl.backend as F
import dgl

# Select the model implementations matching the DGL backend: PyTorch unless
# DGLBACKEND is "mxnet" (case-insensitive). In either case the framework RNG
# and NumPy's RNG are seeded with 42 before the score functions are imported.
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() != 'mxnet':
    import torch as th
    th.manual_seed(42)
    np.random.seed(42)

    # Brings TransEScore, DistMultScore, ... into the module namespace.
    from models.pytorch.score_fun import *
    from models.pytorch.tensor_models import ExternalEmbedding
else:
    import mxnet as mx
    mx.random.seed(42)
    np.random.seed(42)

    # Brings TransEScore, DistMultScore, ... into the module namespace.
    from models.mxnet.score_fun import *
    from models.mxnet.tensor_models import ExternalEmbedding
from models.general_models import KEModel
from dataloader.sampler import create_neg_subgraph

45
46
47
48
49
50
class dotdict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    Attribute reads fall back to ``dict.get``, so an unknown attribute
    evaluates to ``None`` instead of raising ``AttributeError``. Deleting an
    unknown attribute raises ``KeyError`` (from ``dict.__delitem__``).
    """

    def __getattr__(self, name):
        # Only consulted after normal attribute lookup has failed.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
def generate_rand_graph(n, func_name):
    """Build a random graph and embeddings for testing score function ``func_name``.

    Parameters
    ----------
    n : int
        Number of nodes. Edges are the non-zeros of a random ``n x n`` sparse
        matrix with density 0.1.
    func_name : str
        Model name, one of the keys of ``ke_score_funcs``.

    Returns
    -------
    tuple
        ``(g, entity_emb, rel_emb, score_args)``. ``g`` is a read-only
        ``DGLGraph`` whose ``ndata['id']`` holds the node ids and whose
        ``edata['id']`` holds a random relation id in ``[0, 10)`` per edge.
        ``score_args`` are the score-function constructor arguments: ``None``
        (no arguments), a bare float (a single positional argument) or a
        tuple (positional arguments).
    """
    # ``import scipy as sp`` alone does not load the ``sparse`` subpackage on
    # older SciPy releases; import it explicitly instead of relying on some
    # other module (e.g. dgl) having already done so.
    import scipy.sparse

    arr = (scipy.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
    g = dgl.DGLGraph(arr, readonly=True)
    num_rels = 10

    # RotatE uses entity embeddings twice as wide as its relation embeddings
    # (presumably real/imaginary halves); every other model uses width 10.
    # Each table is drawn exactly once instead of being drawn and overwritten.
    entity_dim = 20 if func_name == 'RotatE' else 10
    entity_emb = F.uniform((g.number_of_nodes(), entity_dim), F.float32, F.cpu(), 0, 1)
    if func_name == 'RESCAL':
        # Width 10 * 10 per relation (presumably a flattened 10 x 10 matrix).
        rel_emb = F.uniform((num_rels, 10 * 10), F.float32, F.cpu(), 0, 1)
    else:
        rel_emb = F.uniform((num_rels, 10), F.float32, F.cpu(), -1, 1)

    g.ndata['id'] = F.arange(0, g.number_of_nodes())
    rel_ids = np.random.randint(0, num_rels, g.number_of_edges(), dtype=np.int64)
    g.edata['id'] = F.tensor(rel_ids, F.int64)

    if func_name == 'TransR':
        # TransR additionally takes a projection embedding built as
        # ExternalEmbedding(args, 10, 10 * 10, cpu) — presumably one flattened
        # projection matrix per relation; confirm against ExternalEmbedding.
        args = dotdict({'gpu': -1, 'lr': 0.1})
        projection_emb = ExternalEmbedding(args, 10, 10 * 10, F.cpu())
        return g, entity_emb, rel_emb, (12.0, projection_emb, 10, 10)

    # Constructor arguments for the remaining models; unlisted names take none.
    score_args = {
        'TransE': 12.0,
        'TransE_l1': (12.0, 'l1'),
        'TransE_l2': (12.0, 'l2'),
        'RESCAL': (10, 10),
        'RotatE': (12.0, 1.0),
    }
    return g, entity_emb, rel_emb, score_args.get(func_name)

# Score-function class under test for each model name. The three TransE
# entries share TransEScore; generate_rand_graph supplies the constructor
# arguments that distinguish them.
ke_score_funcs = dict(
    TransE=TransEScore,
    TransE_l1=TransEScore,
    TransE_l2=TransEScore,
    DistMult=DistMultScore,
    ComplEx=ComplExScore,
    RESCAL=RESCALScore,
    TransR=TransRScore,
    RotatE=RotatEScore,
)

class BaseKEModel:
    """Minimal model that scores graphs with a single score-function object.

    Embeddings are looked up by the ``'id'`` node/edge data of each graph it
    is given. It exposes two scoring paths: ``predict_score`` scores every
    edge of a graph, and ``predict_neg_score`` scores corrupted triples via
    the score function's chunked negative-sampling helpers.

    Parameters
    ----------
    score_func :
        Instantiated score-function object (a value of ``ke_score_funcs``).
    entity_emb :
        Entity embedding table, indexed by node id.
    rel_emb :
        Relation embedding table, indexed by edge (relation) id.
    """

    def __init__(self, score_func, entity_emb, rel_emb):
        self.score_func = score_func
        # Scorers for corrupted heads (flag True) and corrupted tails (False).
        self.head_neg_score = score_func.create_neg(True)
        self.tail_neg_score = score_func.create_neg(False)
        # Matching preprocessing hooks applied before the scorers above.
        self.head_neg_prepare = score_func.create_neg_prepare(True)
        self.tail_neg_prepare = score_func.create_neg_prepare(False)
        self.entity_emb = entity_emb
        self.rel_emb = rel_emb
        # Let the score function initialise any model-specific state.
        score_func.reset_parameters()

    def _attach_embeddings(self, graph):
        """Write looked-up embeddings into ``graph.ndata['emb']`` / ``graph.edata['emb']``."""
        graph.ndata['emb'] = self.entity_emb[graph.ndata['id']]
        graph.edata['emb'] = self.rel_emb[graph.edata['id']]

    def predict_score(self, g):
        """Score every edge of ``g``; returns ``g.edata['score']``."""
        self._attach_embeddings(g)
        # NOTE(review): -1 / False look like "gpu_id" / "trace"; confirm
        # against the score function's prepare() signature.
        self.score_func.prepare(g, -1, False)
        self.score_func(g)
        return g.edata['score']

    def predict_neg_score(self, pos_g, neg_g):
        """Score the corrupted triples described by ``neg_g`` against ``pos_g``.

        When ``neg_g.neg_head`` is true, the entities selected by
        ``neg_g.head_nid`` are scored with the relations and tails of
        ``pos_g``'s edges; otherwise the entities selected by
        ``neg_g.tail_nid`` are scored with ``pos_g``'s heads and relations.
        Chunking parameters (``num_chunks``, ``chunk_size``,
        ``neg_sample_size``) are read from ``neg_g``. Both graphs also get
        their ``'emb'`` node/edge data overwritten.
        """
        self._attach_embeddings(pos_g)
        self._attach_embeddings(neg_g)
        num_chunks = neg_g.num_chunks
        chunk_size = neg_g.chunk_size
        neg_sample_size = neg_g.neg_sample_size

        rel = pos_g.edata['emb']
        src_ids, dst_ids = pos_g.all_edges(order='eid')

        if neg_g.neg_head:
            corrupted = self.entity_emb[neg_g.ndata['id'][neg_g.head_nid]]
            tail = pos_g.ndata['emb'][dst_ids]
            corrupted, tail = self.head_neg_prepare(
                pos_g.edata['id'], num_chunks, corrupted, tail, -1, False)
            return self.head_neg_score(corrupted, rel, tail,
                                       num_chunks, chunk_size, neg_sample_size)

        corrupted = self.entity_emb[neg_g.ndata['id'][neg_g.tail_nid]]
        head = pos_g.ndata['emb'][src_ids]
        head, corrupted = self.tail_neg_prepare(
            pos_g.edata['id'], num_chunks, head, corrupted, -1, False)
        return self.tail_neg_score(head, rel, corrupted,
                                   num_chunks, chunk_size, neg_sample_size)

def check_score_func(func_name):
    """Check that direct and negative-sampling scoring agree for ``func_name``.

    Builds a random 100-node graph, samples edge batches in ``'chunk-head'``
    mode, and asserts that scoring every edge of the negative graph directly
    (``BaseKEModel.predict_score``) matches the chunked negative-sampling
    path (``BaseKEModel.predict_neg_score``).

    Raises
    ------
    AssertionError
        If the two scores differ beyond ``rtol=atol=1e-5``, or if the sampler
        produced no batch at all (which would make the check vacuous).
    """
    batch_size = 10
    neg_sample_size = 10
    g, entity_emb, rel_emb, args = generate_rand_graph(100, func_name)

    score_cls = ke_score_funcs[func_name]
    # generate_rand_graph returns None, a tuple of positional arguments, or a
    # single bare argument for the score-function constructor.
    if args is None:
        score_func = score_cls()
    elif isinstance(args, tuple):
        score_func = score_cls(*args)
    else:
        score_func = score_cls(args)
    model = BaseKEModel(score_func, entity_emb, rel_emb)

    sampler = dgl.contrib.sampling.EdgeSampler(g, batch_size=batch_size,
                                               neg_sample_size=neg_sample_size,
                                               negative_mode='chunk-head',
                                               num_workers=1,
                                               shuffle=False,
                                               exclude_positive=False,
                                               return_false_neg=False)

    num_batches = 0
    for pos_g, neg_g in sampler:
        # NOTE(review): the two True flags presumably request a chunked,
        # head-corrupted negative graph (matching 'chunk-head'); confirm
        # against create_neg_subgraph's signature.
        neg_g = create_neg_subgraph(pos_g,
                                    neg_g,
                                    neg_sample_size,
                                    neg_sample_size,
                                    True,
                                    True,
                                    g.number_of_nodes())
        pos_g.copy_from_parent()
        neg_g.copy_from_parent()
        # Direct scoring must run first: it is the call that fills
        # neg_g's embeddings in the original ordering.
        score1 = F.reshape(model.predict_score(neg_g), (batch_size, -1))
        score2 = F.reshape(model.predict_neg_score(pos_g, neg_g), (batch_size, -1))
        np.testing.assert_allclose(F.asnumpy(score1), F.asnumpy(score2),
                                   rtol=1e-5, atol=1e-5)
        num_batches += 1

    # Guard against a silently passing test when the sampler yields nothing.
    assert num_batches > 0, 'EdgeSampler produced no batches for %s' % func_name
def test_score_func_transe():
    """TransE in each of the three configurations registered in ke_score_funcs."""
    for variant in ('TransE', 'TransE_l1', 'TransE_l2'):
        check_score_func(variant)

def test_score_func_distmult():
    """DistMult: direct and negative-sampling scores must agree."""
    check_score_func(func_name='DistMult')

def test_score_func_complex():
    """ComplEx: direct and negative-sampling scores must agree."""
    check_score_func(func_name='ComplEx')

def test_score_func_rescal():
    """RESCAL: direct and negative-sampling scores must agree."""
    check_score_func(func_name='RESCAL')

def test_score_func_transr():
    """TransR: direct and negative-sampling scores must agree."""
    check_score_func(func_name='TransR')
def test_score_func_rotate():
    """RotatE: direct and negative-sampling scores must agree."""
    check_score_func(func_name='RotatE')
if __name__ == '__main__':
    # Run every score-function check, in definition order.
    for run_check in (test_score_func_transe,
                      test_score_func_distmult,
                      test_score_func_complex,
                      test_score_func_rescal,
                      test_score_func_transr,
                      test_score_func_rotate):
        run_check()