"magic_pdf/vscode:/vscode.git/clone" did not exist on "8a179269595417cd26a3ecf98ebc8310340552ff"
Commit b9665e9b authored by Menglong Zhu's avatar Menglong Zhu Committed by dreamdragon
Browse files

Allowing the option to produce LSTM outputs with the bottleneck feature map concatenated.

PiperOrigin-RevId: 213873057
parent d0c1b9da
......@@ -45,19 +45,22 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
forget_bias=1.0,
activation=tf.tanh,
flattened_state=False,
output_bottleneck=False,
visualize_gates=True):
"""Initializes the basic LSTM cell.
Args:
filter_size: collection, conv filter size
output_size: collection, the width/height dimensions of the cell/output
filter_size: collection, conv filter size.
output_size: collection, the width/height dimensions of the cell/output.
num_units: int, The number of channels in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
activation: Activation function of the inner states.
flattened_state: if True, state tensor will be flattened and stored as
a 2-d tensor. Use for exporting the model to tfmini
a 2-d tensor. Use for exporting the model to tfmini.
output_bottleneck: if True, the cell bottleneck will be concatenated
to the cell output.
visualize_gates: if True, add histogram summaries of all gates
and outputs to tensorboard
and outputs to tensorboard.
"""
self._filter_size = list(filter_size)
self._output_size = list(output_size)
......@@ -66,6 +69,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
self._activation = activation
self._viz_gates = visualize_gates
self._flattened_state = flattened_state
self._output_bottleneck = output_bottleneck
self._param_count = self._num_units
for dim in self._output_size:
self._param_count *= dim
......@@ -99,7 +103,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
with tf.variable_scope(scope):
c, h = state
# unflatten state if neccesary
# unflatten state if necessary
if self._flattened_state:
c = tf.reshape(c, [-1] + self.output_size)
h = tf.reshape(h, [-1] + self.output_size)
......@@ -140,13 +144,16 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
slim.summaries.add_histogram_summary(new_h, 'cell_output')
slim.summaries.add_histogram_summary(new_c, 'cell_state')
output = new_h
if self._output_bottleneck:
output = tf.concat([new_h, bottleneck], axis=3)
# reflatten state to store it
if self._flattened_state:
new_c = tf.reshape(new_c, [-1, self._param_count])
new_h = tf.reshape(new_h, [-1, self._param_count])
return new_h, tf.contrib.rnn.LSTMStateTuple(
new_c, new_h if self._flattened_state else new_h)
return output, tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
def init_state(self, state_name, batch_size, dtype, learned_state=False):
"""Creates an initial state compatible with this cell.
......
......@@ -66,10 +66,33 @@ class BottleneckConvLstmCellsTest(tf.test.TestCase):
init_state = cell.init_state(
state_name, batch_size, dtype, learned_state)
output, state_tuple = cell(inputs, init_state)
self.assertAllEqual([4, 1500], output.shape.as_list())
self.assertAllEqual([4, 10, 10, 15], output.shape.as_list())
self.assertAllEqual([4, 1500], state_tuple[0].shape.as_list())
self.assertAllEqual([4, 1500], state_tuple[1].shape.as_list())
def test_run_lstm_cell_with_output_bottleneck(self):
filter_size = [3, 3]
output_dim = 10
output_size = [output_dim] * 2
num_units = 15
state_name = 'lstm_state'
batch_size = 4
dtype = tf.float32
learned_state = False
inputs = tf.zeros([batch_size, output_dim, output_dim, 3], dtype=tf.float32)
cell = lstm_cells.BottleneckConvLSTMCell(
filter_size=filter_size,
output_size=output_size,
num_units=num_units,
output_bottleneck=True)
init_state = cell.init_state(
state_name, batch_size, dtype, learned_state)
output, state_tuple = cell(inputs, init_state)
self.assertAllEqual([4, 10, 10, 30], output.shape.as_list())
self.assertAllEqual([4, 10, 10, 15], state_tuple[0].shape.as_list())
self.assertAllEqual([4, 10, 10, 15], state_tuple[1].shape.as_list())
def test_get_init_state(self):
filter_size = [3, 3]
output_dim = 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment