Commit 4b5a0801 authored by Andrew M. Dai's avatar Andrew M. Dai
Browse files

Switch from static_rnn to dynamic_rnn. This significantly reduces CPU/GPU

memory and fixes the OOM errors.

PiperOrigin-RevId: 172374935
parent ed54e8dd
......@@ -95,7 +95,7 @@ class Embedding(K.layers.Layer):
class LSTM(object):
"""LSTM layer using static_rnn.
"""LSTM layer using dynamic_rnn.
Exposes variables in `trainable_weights` property.
"""
......@@ -119,15 +119,11 @@ class LSTM(object):
])
# shape(x) = (batch_size, num_timesteps, embedding_dim)
# Convert into a time-major list for static_rnn
x = tf.unstack(tf.transpose(x, perm=[1, 0, 2]))
lstm_out, next_state = tf.contrib.rnn.static_rnn(
lstm_out, next_state = tf.nn.dynamic_rnn(
cell, x, initial_state=initial_state, sequence_length=seq_length)
lstm_out = tf.stack(lstm_out)
# shape(lstm_out) = (timesteps, batch_size, cell_size)
lstm_out = tf.transpose(lstm_out, [1, 0, 2])
# shape(lstm_out) = (batch_size, timesteps, cell_size)
if self.keep_prob < 1.:
lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment