Commit b9665e9b authored by Menglong Zhu, committed by dreamdragon

Add the option to produce LSTM outputs with the bottleneck feature map concatenated to the cell output.

PiperOrigin-RevId: 213873057
parent d0c1b9da
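For orientation, a minimal usage sketch of the new flag (not part of the commit); the import path and the TF 1.x / tf.contrib APIs are assumptions taken from the test file in this diff:

import tensorflow as tf
from lstm_object_detection.lstm import lstm_cells  # assumed module path

batch_size, spatial, num_units = 4, 10, 15
inputs = tf.zeros([batch_size, spatial, spatial, 3], dtype=tf.float32)

cell = lstm_cells.BottleneckConvLSTMCell(
    filter_size=[3, 3],
    output_size=[spatial, spatial],
    num_units=num_units,
    output_bottleneck=True)  # new option introduced by this commit

init_state = cell.init_state('lstm_state', batch_size, tf.float32,
                             learned_state=False)
output, state_tuple = cell(inputs, init_state)
# The bottleneck feature map (num_units channels) is concatenated to the cell
# output, so the channel dimension doubles: output is [4, 10, 10, 30] while
# the state tensors stay [4, 10, 10, 15].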
@@ -45,19 +45,22 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
forget_bias=1.0,
activation=tf.tanh,
flattened_state=False,
output_bottleneck=False,
visualize_gates=True):
"""Initializes the basic LSTM cell.
Args:
filter_size: collection, conv filter size
output_size: collection, the width/height dimensions of the cell/output
filter_size: collection, conv filter size.
output_size: collection, the width/height dimensions of the cell/output.
num_units: int, The number of channels in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
activation: Activation function of the inner states.
flattened_state: if True, state tensor will be flattened and stored as
a 2-d tensor. Use for exporting the model to tfmini
a 2-d tensor. Use for exporting the model to tfmini.
output_bottleneck: if True, the cell bottleneck will be concatenated
to the cell output.
visualize_gates: if True, add histogram summaries of all gates
and outputs to tensorboard
and outputs to tensorboard.
"""
self._filter_size = list(filter_size)
self._output_size = list(output_size)
@@ -66,6 +69,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
self._activation = activation
self._viz_gates = visualize_gates
self._flattened_state = flattened_state
self._output_bottleneck = output_bottleneck
self._param_count = self._num_units
for dim in self._output_size:
self._param_count *= dim
@@ -99,7 +103,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
with tf.variable_scope(scope):
c, h = state
# unflatten state if neccesary
# unflatten state if necessary
if self._flattened_state:
c = tf.reshape(c, [-1] + self.output_size)
h = tf.reshape(h, [-1] + self.output_size)
@@ -140,13 +144,16 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
slim.summaries.add_histogram_summary(new_h, 'cell_output')
slim.summaries.add_histogram_summary(new_c, 'cell_state')
output = new_h
if self._output_bottleneck:
output = tf.concat([new_h, bottleneck], axis=3)
# reflatten state to store it
if self._flattened_state:
new_c = tf.reshape(new_c, [-1, self._param_count])
new_h = tf.reshape(new_h, [-1, self._param_count])
return new_h, tf.contrib.rnn.LSTMStateTuple(
new_c, new_h if self._flattened_state else new_h)
return output, tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
def init_state(self, state_name, batch_size, dtype, learned_state=False):
"""Creates an initial state compatible with this cell.
......
@@ -66,10 +66,33 @@ class BottleneckConvLstmCellsTest(tf.test.TestCase):
init_state = cell.init_state(
state_name, batch_size, dtype, learned_state)
output, state_tuple = cell(inputs, init_state)
self.assertAllEqual([4, 1500], output.shape.as_list())
self.assertAllEqual([4, 10, 10, 15], output.shape.as_list())
self.assertAllEqual([4, 1500], state_tuple[0].shape.as_list())
self.assertAllEqual([4, 1500], state_tuple[1].shape.as_list())
def test_run_lstm_cell_with_output_bottleneck(self):
filter_size = [3, 3]
output_dim = 10
output_size = [output_dim] * 2
num_units = 15
state_name = 'lstm_state'
batch_size = 4
dtype = tf.float32
learned_state = False
inputs = tf.zeros([batch_size, output_dim, output_dim, 3], dtype=tf.float32)
cell = lstm_cells.BottleneckConvLSTMCell(
filter_size=filter_size,
output_size=output_size,
num_units=num_units,
output_bottleneck=True)
init_state = cell.init_state(
state_name, batch_size, dtype, learned_state)
output, state_tuple = cell(inputs, init_state)
self.assertAllEqual([4, 10, 10, 30], output.shape.as_list())
self.assertAllEqual([4, 10, 10, 15], state_tuple[0].shape.as_list())
self.assertAllEqual([4, 10, 10, 15], state_tuple[1].shape.as_list())
def test_get_init_state(self):
filter_size = [3, 3]
output_dim = 10
......
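As a side note, a small sketch (not part of the commit) of the shape arithmetic the tests assert: with output_size=[10, 10] and num_units=15, a flattened state holds 10 * 10 * 15 = 1500 values per example, and concatenating the bottleneck (also num_units channels) onto the cell output doubles the channel dimension from 15 to 30.

import tensorflow as tf

batch_size, height, width, num_units = 4, 10, 10, 15
new_h = tf.zeros([batch_size, height, width, num_units])
new_c = tf.zeros([batch_size, height, width, num_units])
bottleneck = tf.zeros([batch_size, height, width, num_units])

# flattened_state=True: the state is stored as a 2-d tensor of param_count values
param_count = height * width * num_units         # 10 * 10 * 15 = 1500
flat_c = tf.reshape(new_c, [-1, param_count])    # shape [4, 1500]

# output_bottleneck=True: the bottleneck is concatenated on the channel axis
output = tf.concat([new_h, bottleneck], axis=3)  # shape [4, 10, 10, 30]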