Commit 73ae53ac authored by Neal Wu's avatar Neal Wu
Browse files

Replace old tf.nn modules with 1.0-compatible versions

parent b18162b3
...@@ -166,7 +166,7 @@ class Seq2SeqAttentionModel(object): ...@@ -166,7 +166,7 @@ class Seq2SeqAttentionModel(object):
hps.num_hidden, hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113), initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
state_is_tuple=False) state_is_tuple=False)
(emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn( (emb_encoder_inputs, fw_state, _) = tf.contrib.rnn.static_bidirectional_rnn(
cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32, cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
sequence_length=article_lens) sequence_length=article_lens)
encoder_outputs = emb_encoder_inputs encoder_outputs = emb_encoder_inputs
...@@ -200,7 +200,7 @@ class Seq2SeqAttentionModel(object): ...@@ -200,7 +200,7 @@ class Seq2SeqAttentionModel(object):
# During decoding, follow up _dec_in_state are fed from beam_search. # During decoding, follow up _dec_in_state are fed from beam_search.
# dec_out_state are stored by beam_search for next step feeding. # dec_out_state are stored by beam_search for next step feeding.
initial_state_attention = (hps.mode == 'decode') initial_state_attention = (hps.mode == 'decode')
decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder( decoder_outputs, self._dec_out_state = tf.contrib.legacy_seq2seq.attention_decoder(
emb_decoder_inputs, self._dec_in_state, self._enc_top_states, emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
cell, num_heads=1, loop_function=loop_function, cell, num_heads=1, loop_function=loop_function,
initial_state_attention=initial_state_attention) initial_state_attention=initial_state_attention)
...@@ -234,7 +234,7 @@ class Seq2SeqAttentionModel(object): ...@@ -234,7 +234,7 @@ class Seq2SeqAttentionModel(object):
self._loss = seq2seq_lib.sampled_sequence_loss( self._loss = seq2seq_lib.sampled_sequence_loss(
decoder_outputs, targets, loss_weights, sampled_loss_func) decoder_outputs, targets, loss_weights, sampled_loss_func)
else: else:
self._loss = tf.nn.seq2seq.sequence_loss( self._loss = tf.contrib.legacy_seq2seq.sequence_loss(
model_outputs, targets, loss_weights) model_outputs, targets, loss_weights)
tf.summary.scalar('loss', tf.minimum(12.0, self._loss)) tf.summary.scalar('loss', tf.minimum(12.0, self._loss))
......
Markdown is supported
0% loaded, or attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment