# -*- coding: utf-8 -*-
#
# tensor_models.py
#
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Sparse embeddings for knowledge graphs (MXNet backend).
"""
import os
import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd

from .score_fun import *
from .. import *

def logsigmoid(val):
    max_elem = nd.maximum(0., -val)
    z = nd.exp(-max_elem) + nd.exp(-val - max_elem)
    return -(max_elem + nd.log(z))
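
# The function above is the numerically stable form of the log-sigmoid:
#   log(sigmoid(v)) = -(max(0, -v) + log(exp(-max(0, -v)) + exp(-v - max(0, -v))))
# which avoids overflowing exp() for large negative inputs.
# Illustrative check (not executed at import time):
#   logsigmoid(nd.array([0.]))  # ~ -0.6931, i.e. log(0.5)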

# Pick the training device: the first GPU listed in args.gpu, or the CPU if it is -1.
get_device = lambda args: mx.gpu(args.gpu[0]) if args.gpu[0] >= 0 else mx.cpu()
# Sum of element-wise |x|**p, used as the (unnormalized) Lp regularization term.
norm = lambda x, p: nd.sum(nd.abs(x) ** p)

# Detach a scalar NDArray from the autograd graph and return it as a Python number.
get_scalar = lambda x: x.detach().asscalar()

# Reshape an NDArray to the 2-D shape (x, y).
reshape = lambda arr, x, y: arr.reshape(x, y)

# Move an NDArray onto the given GPU.
cuda = lambda arr, gpu: arr.as_in_context(mx.gpu(gpu))

class ExternalEmbedding:
    """Sparse Embedding for Knowledge Graph
    It is used to store both entity embeddings and relation embeddings.

    Parameters
    ----------
    args :
        Global configs.
    num : int
        Number of embeddings.
    dim : int
        Embedding dimension size.
    ctx : mx.Context
        Device context to store the embedding.
    """
    def __init__(self, args, num, dim, ctx):
        self.gpu = args.gpu
        self.args = args
        self.trace = []

        self.emb = nd.empty((num, dim), dtype=np.float32, ctx=ctx)
        self.state_sum = nd.zeros((self.emb.shape[0]), dtype=np.float32, ctx=ctx)
        self.state_step = 0

    def init(self, emb_init):
        """Initializing the embeddings.

        Parameters
        ----------
        emb_init : float
            The intial embedding range should be [-emb_init, emb_init].
        """
        nd.random.uniform(-emb_init, emb_init,
                          shape=self.emb.shape, dtype=self.emb.dtype,
                          ctx=self.emb.context, out=self.emb)

    def share_memory(self):
        # TODO(zhengda) fix this later
        pass

    def __call__(self, idx, gpu_id=-1, trace=True):
        """ Return sliced tensor.

        Parameters
        ----------
        idx : nd.NDArray
            Slicing index.
        gpu_id : int
            Which GPU to put the sliced data on.
        trace : bool
            If True, trace the computation. This is required in training.
            If False, do not trace the computation.
            Default: True
        """
        if self.emb.context != idx.context:
            idx = idx.as_in_context(self.emb.context)
        data = nd.take(self.emb, idx)
        if gpu_id >= 0:
            data = data.as_in_context(mx.gpu(gpu_id))
        data.attach_grad()
        if trace:
            self.trace.append((idx, data))
        return data

    def update(self, gpu_id=-1):
        """ Update embeddings in a sparse manner
        Sparse embeddings are updated in mini batches. We maintain gradient states for
        each embedding so they can be updated separately.

        Parameters
        ----------
        gpu_id : int
            Which GPU to use to accelerate the calculation. If -1 is provided, the CPU is used.
        """
        self.state_step += 1
        for idx, data in self.trace:
            grad = data.grad

            clr = self.args.lr
            #clr = self.args.lr / (1 + (self.state_step - 1) * group['lr_decay'])

            # the update is non-linear so indices must be unique
            grad_indices = idx
            grad_values = grad

            grad_sum = (grad_values * grad_values).mean(1)
            ctx = self.state_sum.context
            if ctx != grad_indices.context:
                grad_indices = grad_indices.as_in_context(ctx)
            if ctx != grad_sum.context:
                grad_sum = grad_sum.as_in_context(ctx)
            self.state_sum[grad_indices] += grad_sum
            std = self.state_sum[grad_indices]  # _sparse_mask
            if gpu_id >= 0:
                std = std.as_in_context(mx.gpu(gpu_id))
            std_values = nd.expand_dims(nd.sqrt(std) + 1e-10, 1)
            tmp = (-clr * grad_values / std_values)
            if tmp.context != ctx:
                tmp = tmp.as_in_context(ctx)
            # TODO(zhengda) the overhead is here.
            self.emb[grad_indices] = mx.nd.take(self.emb, grad_indices) + tmp
        self.trace = []

    def curr_emb(self):
        """Return embeddings in trace.
        """
        data = [data for _, data in self.trace]
        return nd.concat(*data, dim=0)

    def save(self, path, name):
        """Save embeddings.

        Parameters
        ----------
        path : str
            Directory to save the embedding.
        name : str
            Embedding name.
        """
        emb_fname = os.path.join(path, name+'.npy')
        np.save(emb_fname, self.emb.asnumpy())

    def load(self, path, name):
        """Load embeddings.

        Parameters
        ----------
        path : str
            Directory to load the embedding.
        name : str
            Embedding name.
        """
        emb_fname = os.path.join(path, name+'.npy')
        self.emb = nd.array(np.load(emb_fname))
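
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes an `args` object exposing `gpu` (list of ints) and `lr` (float),
# which is how ExternalEmbedding reads its configuration above. Because this
# module uses relative imports, run it in its package context (e.g. via
# `python -m <package>.tensor_models`) or adapt the imports before trying it.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(gpu=[-1], lr=0.1)           # assumed config fields
    ctx = get_device(args)                                 # -1 -> mx.cpu()

    entity_emb = ExternalEmbedding(args, num=1000, dim=16, ctx=ctx)
    entity_emb.init(emb_init=1.0)                          # uniform in [-1, 1]

    idx = nd.array([0, 3, 7], dtype=np.int64, ctx=ctx)
    batch = entity_emb(idx, trace=True)                    # rows 0, 3, 7; recorded in trace

    with mx.autograd.record():
        loss = -logsigmoid(batch).sum()                    # toy objective
    loss.backward()

    entity_emb.update()                                    # sparse update of rows 0, 3, 7
    # entity_emb.save('.', 'entity_demo')                  # would write ./entity_demo.npy
    # entity_emb.load('.', 'entity_demo')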