graph_to_tf.py 13.5 KB
Newer Older
Deshui Yu's avatar
Deshui Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import tensorflow as tf
from rnn import XGRUCell
from util import dropout
from graph import LayerType


def normalize(inputs,
              epsilon=1e-8,
              scope="ln"):
    '''Layer-normalizes `inputs` along the last axis.

    Args:
      inputs: A tensor. The size of its last (static) dimension gives the
        shape of the learned scale and offset.
      epsilon: A small float added to the variance before the square root
        is taken, so the division never hits zero.
      scope: Name handed to `tf.variable_scope`.

    Returns:
      `gamma * (inputs - mean) / (variance + epsilon) ** 0.5 + beta`, with
      mean and variance taken over the last axis.
    '''
    with tf.variable_scope(scope):
        feature_shape = inputs.get_shape()[-1:]

        # Statistics over the last axis; keep_dims keeps them broadcastable.
        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)

        # Offset starts at zero and scale at one. Both are created with
        # tf.Variable (not tf.get_variable), offset first.
        beta = tf.Variable(tf.zeros(feature_shape))
        gamma = tf.Variable(tf.ones(feature_shape))

        centered = inputs - mean
        spread = (variance + epsilon) ** (.5)
        outputs = gamma * (centered / spread) + beta

    return outputs


def multihead_attention(queries,
                        keys,
                        scope="multihead_attention",
                        num_units=None,
                        num_heads=4,
                        dropout_rate=0,
                        is_training=True,
                        causality=False):
    '''Applies multihead attention.

    Each head gets its own query/key/value dense projections (with ReLU) of
    width `num_units // num_heads`. The heads are stacked along the batch
    axis for the attention computation and concatenated back along the
    feature axis afterwards.

    Side effect: the pre-softmax attention logits are stored in the
    module-level global `look5`.

    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q].
      keys: A 3d tensor with shape of [N, T_k, C_k].
      scope: Scope name for `variable_scope`; also passed on to `normalize`.
      num_units: An int. Total attention size over all heads. Defaults to
        the last dimension of `queries` when None.
      num_heads: An int. Number of heads.
      dropout_rate: A floating point number, forwarded to `dropout`.
      is_training: Boolean, forwarded to `dropout`.
      causality: Boolean. If true, units that reference the future are masked.

    Returns:
      A 3d tensor with shape of (N, T_q, C), where
      C = num_heads * (num_units // num_heads).
    '''
    global look5
    with tf.variable_scope(scope):
        # Set the fall back option for num_units
        if num_units is None:
            num_units = queries.get_shape().as_list()[-1]

        # Width of a single head. Floor division keeps this an int under
        # Python 3; `/` would hand a float to tf.layers.dense as `units`.
        head_units = num_units // num_heads

        # Per-head projections. The Q, K, V creation order inside the loop
        # determines the auto-generated dense layer names, so keep it.
        Q_ = []
        K_ = []
        V_ = []
        for _ in range(num_heads):
            Q = tf.layers.dense(queries, head_units,
                                activation=tf.nn.relu)  # (N, T_q, C/h)
            K = tf.layers.dense(keys, head_units,
                                activation=tf.nn.relu)  # (N, T_k, C/h)
            V = tf.layers.dense(keys, head_units,
                                activation=tf.nn.relu)  # (N, T_k, C/h)
            Q_.append(Q)
            K_.append(K)
            V_.append(V)

        # Stack the heads along the batch axis
        Q_ = tf.concat(Q_, axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(K_, axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(V_, axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale by sqrt of the per-head width
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key Masking: key positions whose features sum to zero are treated
        # as padding and get a very large negative logit.
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1),
                            [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

        # Note: -2 ** 32 + 1 parses as -(2 ** 32) + 1.
        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(key_masks, 0), paddings,
                           outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.contrib.linalg.LinearOperatorTriL(
                diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0),
                            [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings,
                               outputs)  # (h*N, T_q, T_k)

        # Activation (the logits are exposed through the global first)
        look5 = outputs
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query Masking: zero the attention rows of all-zero query positions
        query_masks = tf.sign(
            tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
        query_masks = tf.tile(tf.expand_dims(
            query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs *= query_masks  # (h*N, T_q, T_k)

        # Dropouts
        outputs = dropout(outputs, dropout_rate, is_training)

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape: heads go back from the batch axis to the feature axis
        outputs = tf.concat(tf.split(outputs, num_heads,
                                     axis=0), axis=2)  # (N, T_q, C)

        # Residual connection, only when the query width already matches
        if queries.get_shape().as_list()[-1] == num_units:
            outputs += queries

        # Normalize
        outputs = normalize(outputs, scope=scope)  # (N, T_q, C)

    return outputs


def positional_encoding(inputs,
                        num_units=None,
                        zero_pad=True,
                        scale=True,
                        scope="positional_encoding",
                        reuse=None):
    '''
    Return a sinusoidal positional embedding shaped like `inputs`.

    Args:
      inputs: A tensor with at least three dimensions, read as (N, T, C);
        only its dynamic shape is used.
      num_units: Currently ignored; it is overwritten with the size of the
        third dimension of `inputs`.
      zero_pad: Boolean. If true, the table row for position 0 is replaced
        with zeros.
      scale: Boolean. If true, the result is multiplied by sqrt(C).
      scope: Scope name for `variable_scope`.
      reuse: Forwarded to `variable_scope`.

    Returns:
      A float32 tensor of shape (N, T, C).
    '''
    dyn_shape = tf.shape(inputs)
    batch_size = dyn_shape[0]
    seq_len = dyn_shape[1]
    # NOTE: this discards whatever the caller passed as `num_units`.
    num_units = dyn_shape[2]
    with tf.variable_scope(scope, reuse=reuse):
        # Position indices 0..T-1 repeated for every batch row: (N, T)
        position_ind = tf.tile(
            tf.expand_dims(tf.range(seq_len), 0), [batch_size, 1])

        # Positions as a column (T, 1) and per-channel rates as a row (1, C).
        positions = tf.expand_dims(
            tf.cast(tf.range(seq_len), tf.float32), axis=1)
        rates = tf.expand_dims(
            tf.cast(10000 ** -(2 * tf.range(num_units) / num_units),
                    tf.float32),
            axis=0)

        # 0/1 channel selectors: even-indexed channels take the sine,
        # odd-indexed channels take the cosine.
        even_channels = tf.cast((tf.range(num_units) + 1) % 2, tf.float32)
        odd_channels = tf.cast((tf.range(num_units) % 2), tf.float32)

        angles = tf.multiply(positions, rates)  # (T, C)
        sin_part = tf.sin(angles) * \
            tf.multiply(tf.ones_like(positions), even_channels)
        cos_part = tf.cos(angles) * \
            tf.multiply(tf.ones_like(positions), odd_channels)
        lookup_table = sin_part + cos_part

        if zero_pad:
            # Swap the first row (position 0) for an all-zero row.
            zero_row = tf.zeros(shape=[1, num_units])
            lookup_table = tf.concat((zero_row, lookup_table[1:, :]), 0)

        outputs = tf.nn.embedding_lookup(lookup_table, position_ind)

        if scale:
            outputs = outputs * tf.sqrt(tf.cast(num_units, tf.float32))

        return outputs


def feedforward(inputs,
                num_units,
                scope="multihead_attention"):
    '''Point-wise feed forward net with residual connection and layer norm.

    Args:
      inputs: A 3d tensor with shape of [N, T, C].
      num_units: A sequence of two integers: the filter counts of the inner
        and the readout convolution. The second must be compatible with the
        last dimension of `inputs` for the residual addition.
      scope: Scope name for `variable_scope`.

    Returns:
      The normalized sum of the readout layer and `inputs`.
    '''
    inner_filters, readout_filters = num_units[0], num_units[1]
    with tf.variable_scope(scope):
        # Inner layer: kernel-size-1 convolution with ReLU
        hidden = tf.layers.conv1d(inputs=inputs,
                                  filters=inner_filters,
                                  kernel_size=1,
                                  activation=tf.nn.relu,
                                  use_bias=True)

        # Readout layer: linear kernel-size-1 convolution
        readout = tf.layers.conv1d(inputs=hidden,
                                   filters=readout_filters,
                                   kernel_size=1,
                                   activation=None,
                                   use_bias=True)

        # Residual connection followed by layer normalization
        outputs = normalize(readout + inputs)

    return outputs


def rnn(input_states, sequence_lengths, dropout_rate, is_training, num_units):
    '''Runs a bidirectional XGRU over `input_states`.

    Args:
      input_states: A 3d tensor; its first two axes are swapped before and
        after the recurrent pass, which runs with `time_major=True`
        (so the input is presumably batch-major, [N, T, C]).
      sequence_lengths: Per-example lengths, forwarded to
        `tf.nn.bidirectional_dynamic_rnn`.
      dropout_rate: Forwarded to `dropout`.
      is_training: Forwarded to `dropout`.
      num_units: Size of each direction's `XGRUCell`.

    Returns:
      The forward and backward outputs concatenated on the last axis, after
      dropout, with the first two axes swapped back.
    '''
    num_layers = 1
    collected = []
    # Switch to time-major layout for the dynamic RNN.
    layer_input = tf.transpose(input_states, perm=[1, 0, 2])
    for layer_idx in range(num_layers):
        layer_input = dropout(layer_input, dropout_rate, is_training)
        with tf.variable_scope('layer_' + str(layer_idx)):
            forward_cell = XGRUCell(num_units)
            backward_cell = XGRUCell(num_units)
            (fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=forward_cell,
                cell_bw=backward_cell,
                dtype=tf.float32,
                sequence_length=sequence_lengths,
                inputs=layer_input,
                time_major=True)

        # Join both directions on the feature axis; this also feeds the
        # next layer if the layer count is ever raised.
        layer_input = tf.concat([fw_out, bw_out], 2)
        collected.append(layer_input)

    merged = tf.concat(collected, axis=2)
    merged = dropout(merged, dropout_rate, is_training)
    return tf.transpose(merged, perm=[1, 0, 2])


def graph_to_network(input1,
                     input2,
                     input1_lengths,
                     input2_lengths,
                     graph,
                     dropout_rate,
                     is_training,
                     num_heads=1,
                     rnn_units=256):
    '''Builds the TensorFlow network described by `graph`.

    Layer ids 0 and 1 are seeded from `input1` and `input2`; the remaining
    layers are built in the order returned by `graph.is_topology()`.
    Entries equal to '|' and layers of input type are skipped.

    Args:
      input1: Tensor for layer 0. It is scaled by sqrt of its last static
        dimension and a positional encoding is added.
      input2: Tensor for layer 1. It is scaled by the same factor, without
        positional encoding.
      input1_lengths: Sequence lengths associated with layer 0.
      input2_lengths: Sequence lengths associated with layer 1.
      graph: Object exposing `is_topology()` and `layers`, where each layer
        has `graph_type` and `input`.
      dropout_rate: Forwarded to `dropout` and the sub-layers.
      is_training: Forwarded to `dropout` and the sub-layers.
      num_heads: Number of attention heads.
      rnn_units: Units per RNN direction; attention and output layers use
        `rnn_units * 2`.

    Returns:
      The tensors stored under layer ids 2 and 3.
    '''
    topology = graph.is_topology()
    layers = {}
    layers_sequence_lengths = {}
    num_units = input1.get_shape().as_list()[-1]

    scaled_first = input1*tf.sqrt(tf.cast(num_units, tf.float32)) + \
        positional_encoding(input1, scale=False, zero_pad=False)
    scaled_second = input2*tf.sqrt(tf.cast(num_units, tf.float32))
    layers[0] = dropout(scaled_first, dropout_rate, is_training)
    layers[1] = dropout(scaled_second, dropout_rate, is_training)
    layers_sequence_lengths[0] = input1_lengths
    layers_sequence_lengths[1] = input2_lengths

    for topo_i in topology:
        if topo_i == '|':
            continue
        node = graph.layers[topo_i]
        kind = node.graph_type
        if kind == LayerType.input.value:
            continue

        # Attention and self-attention differ only in the outer scope name
        # and in where the keys come from.
        if kind == LayerType.attention.value:
            attn_scope = 'attation_%d' % topo_i
            key_id = node.input[1]
        elif kind == LayerType.self_attention.value:
            attn_scope = 'self-attation_%d' % topo_i
            key_id = node.input[0]
        else:
            attn_scope = None

        if attn_scope is not None:
            with tf.variable_scope(attn_scope):
                layer = multihead_attention(
                    layers[node.input[0]],
                    layers[key_id],
                    scope="multihead_attention%d" % topo_i,
                    dropout_rate=dropout_rate,
                    is_training=is_training,
                    num_heads=num_heads,
                    num_units=rnn_units * 2)
                layer = feedforward(
                    layer,
                    scope="feedforward%d" % topo_i,
                    num_units=[rnn_units * 2 * 4, rnn_units * 2])
            layers[topo_i] = layer
            layers_sequence_lengths[topo_i] = \
                layers_sequence_lengths[node.input[0]]
        elif kind == LayerType.rnn.value:
            with tf.variable_scope('rnn_%d' % topo_i):
                layer = rnn(layers[node.input[0]],
                            layers_sequence_lengths[node.input[0]],
                            dropout_rate,
                            is_training,
                            rnn_units)
            layers[topo_i] = layer
            layers_sequence_lengths[topo_i] = \
                layers_sequence_lengths[node.input[0]]
        elif kind == LayerType.output.value:
            result = layers[node.input[0]]
            layers[topo_i] = result
            # Project to rnn_units * 2 features when the width differs.
            if result.get_shape().as_list()[-1] != rnn_units * 1 * 2:
                with tf.variable_scope('add_dense'):
                    layers[topo_i] = tf.layers.dense(
                        result, units=rnn_units*2)

    return layers[2], layers[3]