"tests/vscode:/vscode.git/clone" did not exist on "425a715e35479338c06b2a68eb3a95790c1db3c5"
Commit 0d8916f4 authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Revert 0.12 changes but keep the argument swap

parent a38bf8d7
...@@ -100,7 +100,7 @@ class Seq2SeqModel(object): ...@@ -100,7 +100,7 @@ class Seq2SeqModel(object):
b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype) b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
output_projection = (w, b) output_projection = (w, b)
def sampled_loss(inputs,labels): def sampled_loss(inputs, labels):
labels = tf.reshape(labels, [-1, 1]) labels = tf.reshape(labels, [-1, 1])
# We need to compute the sampled_softmax_loss using 32bit floats to # We need to compute the sampled_softmax_loss using 32bit floats to
# avoid numerical instabilities. # avoid numerical instabilities.
...@@ -120,17 +120,17 @@ class Seq2SeqModel(object): ...@@ -120,17 +120,17 @@ class Seq2SeqModel(object):
# Create the internal multi-layer cell for our RNN. # Create the internal multi-layer cell for our RNN.
def single_cell(): def single_cell():
return tf.nn.rnn_cell.GRUCell(size) return tf.contrib.rnn.GRUCell(size)
if use_lstm: if use_lstm:
def single_cell(): def single_cell():
return tf.nn.rnn_cell.BasicLSTMCell(size) return tf.contrib.rnn.BasicLSTMCell(size)
cell = single_cell() cell = single_cell()
if num_layers > 1: if num_layers > 1:
cell = tf.nn.rnn_cell.MultiRNNCell([single_cell() for _ in range(num_layers)]) cell = tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layers)])
# The seq2seq function: we use embedding for the input and attention. # The seq2seq function: we use embedding for the input and attention.
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode): def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
return tf.nn.seq2seq.embedding_attention_seq2seq( return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
encoder_inputs, encoder_inputs,
decoder_inputs, decoder_inputs,
cell, cell,
...@@ -160,7 +160,7 @@ class Seq2SeqModel(object): ...@@ -160,7 +160,7 @@ class Seq2SeqModel(object):
# Training outputs and losses. # Training outputs and losses.
if forward_only: if forward_only:
self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets( self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
self.encoder_inputs, self.decoder_inputs, targets, self.encoder_inputs, self.decoder_inputs, targets,
self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True), self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
softmax_loss_function=softmax_loss_function) softmax_loss_function=softmax_loss_function)
...@@ -172,7 +172,7 @@ class Seq2SeqModel(object): ...@@ -172,7 +172,7 @@ class Seq2SeqModel(object):
for output in self.outputs[b] for output in self.outputs[b]
] ]
else: else:
self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets( self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
self.encoder_inputs, self.decoder_inputs, targets, self.encoder_inputs, self.decoder_inputs, targets,
self.target_weights, buckets, self.target_weights, buckets,
lambda x, y: seq2seq_f(x, y, False), lambda x, y: seq2seq_f(x, y, False),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment