Commit 9b023de8 authored by Xin Pan's avatar Xin Pan Committed by GitHub
Browse files

Merge pull request #601 from panyx0718/master

Explicitly set state_is_tuple=False.
parents 1662e278 5e875226
...@@ -160,10 +160,12 @@ class Seq2SeqAttentionModel(object): ...@@ -160,10 +160,12 @@ class Seq2SeqAttentionModel(object):
self._next_device()): self._next_device()):
cell_fw = tf.nn.rnn_cell.LSTMCell( cell_fw = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden, hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123)) initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
state_is_tuple=False)
cell_bw = tf.nn.rnn_cell.LSTMCell( cell_bw = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden, hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113)) initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
state_is_tuple=False)
(emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn( (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32, cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
sequence_length=article_lens) sequence_length=article_lens)
...@@ -188,7 +190,8 @@ class Seq2SeqAttentionModel(object): ...@@ -188,7 +190,8 @@ class Seq2SeqAttentionModel(object):
cell = tf.nn.rnn_cell.LSTMCell( cell = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden, hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113)) initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
state_is_tuple=False)
encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden]) encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
for x in encoder_outputs] for x in encoder_outputs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment