# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Ops and utilities for neural networks.

For now, just an LSTM layer.
"""
import shapes
import tensorflow as tf
rnn = tf.load_op_library("../cc/rnn_ops.so")


def rnn_helper(inp,
               length,
               cell_type=None,
               direction="forward",
               name=None,
               *args,
               **kwargs):
  """Adds ops for a recurrent neural network layer.

  This function calls an actual implementation of a recurrent neural network
  based on `cell_type`.

  There are three modes depending on the value of `direction`:

    forward: Adds a forward RNN.
    backward: Adds a backward RNN.
    bidirectional: Adds both forward and backward RNNs and creates a
                   bidirectional RNN.

  Args:
    inp: A 3-D tensor of shape [`batch_size`, `max_length`, `feature_dim`].
    length: A 1-D tensor of shape [`batch_size`] and type int64. Each element
            represents the length of the corresponding sequence in `inp`.
    cell_type: Cell type of RNN. Currently can only be "lstm".
    direction: One of "forward", "backward", "bidirectional".
    name: Name of the op.
    *args: Other arguments to the layer.
    **kwargs: Keyword arguments to the layer.

  Returns:
    A 3-D tensor of shape [`batch_size`, `max_length`, `num_nodes`].
  """

  assert cell_type is not None
  rnn_func = None
  if cell_type == "lstm":
    rnn_func = lstm_layer
  assert rnn_func is not None
  assert direction in ["forward", "backward", "bidirectional"]

  with tf.variable_scope(name):
    if direction in ["forward", "bidirectional"]:
      forward = rnn_func(
          inp=inp,
          length=length,
          backward=False,
          name="forward",
          *args,
          **kwargs)
      if isinstance(forward, tuple):
        # lstm_layer returns a tuple (output, memory). We only need the first
        # element.
        forward = forward[0]
    if direction in ["backward", "bidirectional"]:
      backward = rnn_func(
          inp=inp,
          length=length,
          backward=True,
          name="backward",
          *args,
          **kwargs)
      if isinstance(backward, tuple):
        # lstm_layer returns a tuple (output, memory). We only need the first
        # element.
        backward = backward[0]
    if direction == "forward":
      out = forward
    elif direction == "backward":
      out = backward
    else:
      out = tf.concat(axis=2, values=[forward, backward])
  return out
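

# A minimal usage sketch for `rnn_helper`, assuming a batch of 8 padded
# sequences of length 40 with 64 input features and 128 LSTM nodes per
# direction; all shapes and the scope name are illustrative only.
def _example_rnn_helper_usage():
  """Builds a bidirectional LSTM over a batch of padded sequences."""
  inp = tf.placeholder(tf.float32, [8, 40, 64])  # [batch, max_length, feature]
  length = tf.placeholder(tf.int64, [8])  # true length of each sequence
  # Forward and backward outputs are concatenated along the feature axis, so
  # the result has shape [8, 40, 2 * 128].
  return rnn_helper(
      inp,
      length,
      cell_type="lstm",
      direction="bidirectional",
      num_nodes=128,
      name="rnn1")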


@tf.RegisterShape("VariableLSTM")
def _variable_lstm_shape(op):
  """Shape function for the VariableLSTM op."""
  input_shape = op.inputs[0].get_shape().with_rank(4)
  state_shape = op.inputs[1].get_shape().with_rank(2)
  memory_shape = op.inputs[2].get_shape().with_rank(2)
  w_m_m_shape = op.inputs[3].get_shape().with_rank(3)
  batch_size = input_shape[0].merge_with(state_shape[0])
  batch_size = batch_size.merge_with(memory_shape[0])
  seq_len = input_shape[1]
  gate_num = input_shape[2].merge_with(w_m_m_shape[1])
  output_dim = input_shape[3].merge_with(state_shape[1])
  output_dim = output_dim.merge_with(memory_shape[1])
  output_dim = output_dim.merge_with(w_m_m_shape[0])
  output_dim = output_dim.merge_with(w_m_m_shape[2])
  return [[batch_size, seq_len, output_dim],
          [batch_size, seq_len, gate_num, output_dim],
          [batch_size, seq_len, output_dim]]


@tf.RegisterGradient("VariableLSTM")
def _variable_lstm_grad(op, act_grad, gate_grad, mem_grad):
  """Gradient function for the VariableLSTM op."""
  initial_state = op.inputs[1]
  initial_memory = op.inputs[2]
  w_m_m = op.inputs[3]
  act = op.outputs[0]
  gate_raw_act = op.outputs[1]
  memory = op.outputs[2]
  return rnn.variable_lstm_grad(initial_state, initial_memory, w_m_m, act,
                                gate_raw_act, memory, act_grad, gate_grad,
                                mem_grad)


def lstm_layer(inp,
               length=None,
               state=None,
               memory=None,
               num_nodes=None,
               backward=False,
               clip=50.0,
               reg_func=tf.nn.l2_loss,
               weight_reg=False,
               weight_collection="LSTMWeights",
               bias_reg=False,
               stddev=None,
               seed=None,
               decode=False,
               use_native_weights=False,
               name=None):
  """Adds ops for an LSTM layer.

  This adds ops for the following operations:

    input => (forward-LSTM|backward-LSTM) => output

  The direction of the LSTM is determined by `backward`. If it is false, the
  forward LSTM is used; otherwise, the backward one is used.

  Args:
    inp: A 3-D tensor of shape [`batch_size`, `max_length`, `feature_dim`].
    length: A 1-D tensor of shape [`batch_size`] and type int64. Each element
            represents the length of the corresponding sequence in `inp`.
    state: If specified, uses it as the initial state.
    memory: If specified, uses it as the initial memory.
    num_nodes: The number of LSTM cells.
    backward: If true, reverses the `inp` before adding the ops. The output is
              also reversed so that the direction is the same as `inp`.
    clip: Value used to clip the cell values.
    reg_func: Function used for the weight regularization such as
              `tf.nn.l2_loss`.
    weight_reg: If true, regularize the filter weights with `reg_func`.
    weight_collection: Collection to add the weights to for regularization.
    bias_reg: If true, regularize the bias vector with `reg_func`.
    stddev: Standard deviation used to initialize the variables.
    seed: Seed used to initialize the variables.
    decode: If true, does not add ops which are not used for inference.
    use_native_weights: If true, uses weights in the same format as the native
                        implementations.
    name: Name of the op.

  Returns:
    A 3-D tensor of shape [`batch_size`, `max_length`, `num_nodes`].
  """
  with tf.variable_scope(name):
    if backward:
      if length is None:
        inp = tf.reverse(inp, [1])
      else:
        inp = tf.reverse_sequence(inp, length, 1, 0)

    num_prev = inp.get_shape()[2]
    if stddev:
      initializer = tf.truncated_normal_initializer(stddev=stddev, seed=seed)
    else:
      initializer = tf.uniform_unit_scaling_initializer(seed=seed)

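    # The input-to-gate weights (w_i_m) and the recurrent weights (w_m_m) are
    # created either by slicing the single packed
    # [num_prev + num_nodes, 4 * num_nodes] matrix used by the native LSTMCell
    # implementation, or as two separate variables.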
    if use_native_weights:
      with tf.variable_scope("LSTMCell"):
        w = tf.get_variable(
            "W_0",
            shape=[num_prev + num_nodes, 4 * num_nodes],
            initializer=initializer,
            dtype=tf.float32)
        w_i_m = tf.slice(w, [0, 0], [num_prev, 4 * num_nodes], name="w_i_m")
        w_m_m = tf.reshape(
            tf.slice(w, [num_prev, 0], [num_nodes, 4 * num_nodes]),
            [num_nodes, 4, num_nodes],
            name="w_m_m")
    else:
      w_i_m = tf.get_variable("w_i_m", [num_prev, 4 * num_nodes],
                              initializer=initializer)
      w_m_m = tf.get_variable("w_m_m", [num_nodes, 4, num_nodes],
                              initializer=initializer)

    if not decode and weight_reg:
      tf.add_to_collection(weight_collection, reg_func(w_i_m, name="w_i_m_reg"))
      tf.add_to_collection(weight_collection, reg_func(w_m_m, name="w_m_m_reg"))

    batch_size = shapes.tensor_dim(inp, dim=0)
    num_frames = shapes.tensor_dim(inp, dim=1)
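    # Flatten the batch and time dimensions so the input-to-gate projection
    # below can be computed as a single matmul over every frame at once.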
    prev = tf.reshape(inp, tf.stack([batch_size * num_frames, num_prev]))

    if use_native_weights:
      with tf.variable_scope("LSTMCell"):
        b = tf.get_variable(
            "B",
            shape=[4 * num_nodes],
            initializer=tf.zeros_initializer(),
            dtype=tf.float32)
      biases = tf.identity(b, name="biases")
    else:
      biases = tf.get_variable(
          "biases", [4 * num_nodes], initializer=tf.constant_initializer(0.0))
    if not decode and bias_reg:
      tf.add_to_collection(
          weight_collection, reg_func(
              biases, name="biases_reg"))
    prev = tf.nn.xw_plus_b(prev, w_i_m, biases)

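    # Reshape the projected inputs back to [batch, time, gate, node], the
    # layout expected by the VariableLSTM kernel.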
    prev = tf.reshape(prev, tf.stack([batch_size, num_frames, 4, num_nodes]))
    if state is None:
      state = tf.fill(tf.stack([batch_size, num_nodes]), 0.0)
    if memory is None:
      memory = tf.fill(tf.stack([batch_size, num_nodes]), 0.0)

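    # Run the recurrence. The custom op returns (activations, raw gate
    # activations, memory); the raw gate activations are only consumed by the
    # registered gradient function.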
    out, _, mem = rnn.variable_lstm(prev, state, memory, w_m_m, clip=clip)

    if backward:
      if length is None:
        out = tf.reverse(out, [1])
      else:
        out = tf.reverse_sequence(out, length, 1, 0)

  return out, mem
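

# A minimal usage sketch for calling `lstm_layer` directly, assuming the same
# illustrative shapes as above and a backward pass; the returned memory is
# discarded here.
def _example_lstm_layer_usage():
  """Builds a backward LSTM and returns only its output activations."""
  inp = tf.placeholder(tf.float32, [8, 40, 64])  # [batch, max_length, feature]
  length = tf.placeholder(tf.int64, [8])
  out, _ = lstm_layer(
      inp, length=length, num_nodes=128, backward=True, name="lstm_bw")
  return out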