Commit 73ae53ac authored by Neal Wu's avatar Neal Wu
Browse files

Replace old tf.nn modules with 1.0-compatible versions

parent b18162b3
......@@ -166,7 +166,7 @@ class Seq2SeqAttentionModel(object):
hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
state_is_tuple=False)
(emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
(emb_encoder_inputs, fw_state, _) = tf.contrib.rnn.static_bidirectional_rnn(
cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
sequence_length=article_lens)
encoder_outputs = emb_encoder_inputs
......@@ -200,7 +200,7 @@ class Seq2SeqAttentionModel(object):
# During decoding, follow up _dec_in_state are fed from beam_search.
# dec_out_state are stored by beam_search for next step feeding.
initial_state_attention = (hps.mode == 'decode')
decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
decoder_outputs, self._dec_out_state = tf.contrib.legacy_seq2seq.attention_decoder(
emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
cell, num_heads=1, loop_function=loop_function,
initial_state_attention=initial_state_attention)
......@@ -234,7 +234,7 @@ class Seq2SeqAttentionModel(object):
self._loss = seq2seq_lib.sampled_sequence_loss(
decoder_outputs, targets, loss_weights, sampled_loss_func)
else:
self._loss = tf.nn.seq2seq.sequence_loss(
self._loss = tf.contrib.legacy_seq2seq.sequence_loss(
model_outputs, targets, loss_weights)
tf.summary.scalar('loss', tf.minimum(12.0, self._loss))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment