Commit 30580a35 authored by Andrew M Dai's avatar Andrew M Dai Committed by Ryan Sepassi
Browse files

Increase minimum TF version for DEFINE_enum and rename variable mappings for...

Increase minimum TF version for DEFINE_enum and rename variable mappings for change to RNN variable names. (#4123)
parent fea6447a
...@@ -5,7 +5,7 @@ ______*](https://arxiv.org/abs/1801.07736) published at ICLR 2018. ...@@ -5,7 +5,7 @@ ______*](https://arxiv.org/abs/1801.07736) published at ICLR 2018.
## Requirements ## Requirements
* TensorFlow >= v1.3 * TensorFlow >= v1.5
## Instructions ## Instructions
......
...@@ -163,52 +163,48 @@ def rnn_zaremba(hparams, model): ...@@ -163,52 +163,48 @@ def rnn_zaremba(hparams, model):
if v.op.name == str(model) + '/rnn/embedding' if v.op.name == str(model) + '/rnn/embedding'
][0] ][0]
lstm_w_0 = [ lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name == str(model) +
if v.op.name == '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
str(model) + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'
][0] ][0]
lstm_b_0 = [ lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name == str(model) +
if v.op.name == '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
str(model) + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'
][0] ][0]
lstm_w_1 = [ lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name == str(model) +
if v.op.name == '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
str(model) + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'
][0] ][0]
lstm_b_1 = [ lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name == str(model) +
if v.op.name == '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
str(model) + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
][0] ][0]
# Dictionary mapping. # Dictionary mapping.
if model == 'gen': if model == 'gen':
variable_mapping = { variable_mapping = {
'Model/embedding': embedding, 'Model/embedding': embedding,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1, 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1, 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1,
'Model/softmax_w': softmax_w, 'Model/softmax_w': softmax_w,
'Model/softmax_b': softmax_b 'Model/softmax_b': softmax_b
} }
else: else:
if FLAGS.dis_share_embedding: if FLAGS.dis_share_embedding:
variable_mapping = { variable_mapping = {
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1, 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1
} }
else: else:
variable_mapping = { variable_mapping = {
'Model/embedding': embedding, 'Model/embedding': embedding,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1, 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1
} }
return variable_mapping return variable_mapping
...@@ -356,24 +352,20 @@ def gen_encoder_seq2seq(hparams): ...@@ -356,24 +352,20 @@ def gen_encoder_seq2seq(hparams):
if v.op.name == 'gen/encoder/rnn/embedding' if v.op.name == 'gen/encoder/rnn/embedding'
][0] ][0]
encoder_lstm_w_0 = [ encoder_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'
][0] ][0]
encoder_lstm_b_0 = [ encoder_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'
][0] ][0]
encoder_lstm_w_1 = [ encoder_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'
][0] ][0]
encoder_lstm_b_1 = [ encoder_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
][0] ][0]
if FLAGS.data_set == 'ptb': if FLAGS.data_set == 'ptb':
...@@ -385,24 +377,24 @@ def gen_encoder_seq2seq(hparams): ...@@ -385,24 +377,24 @@ def gen_encoder_seq2seq(hparams):
variable_mapping = { variable_mapping = {
str(model_str) + '/embedding': str(model_str) + '/embedding':
encoder_embedding, encoder_embedding,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
encoder_lstm_w_0, encoder_lstm_w_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
encoder_lstm_b_0, encoder_lstm_b_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
encoder_lstm_w_1, encoder_lstm_w_1,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
encoder_lstm_b_1 encoder_lstm_b_1
} }
else: else:
variable_mapping = { variable_mapping = {
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
encoder_lstm_w_0, encoder_lstm_w_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
encoder_lstm_b_0, encoder_lstm_b_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
encoder_lstm_w_1, encoder_lstm_w_1,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
encoder_lstm_b_1 encoder_lstm_b_1
} }
return variable_mapping return variable_mapping
...@@ -418,24 +410,20 @@ def gen_decoder_seq2seq(hparams): ...@@ -418,24 +410,20 @@ def gen_decoder_seq2seq(hparams):
if v.op.name == 'gen/decoder/rnn/embedding' if v.op.name == 'gen/decoder/rnn/embedding'
][0] ][0]
decoder_lstm_w_0 = [ decoder_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'
][0] ][0]
decoder_lstm_b_0 = [ decoder_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'
][0] ][0]
decoder_lstm_w_1 = [ decoder_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'
][0] ][0]
decoder_lstm_b_1 = [ decoder_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
][0] ][0]
decoder_softmax_b = [ decoder_softmax_b = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
...@@ -450,13 +438,13 @@ def gen_decoder_seq2seq(hparams): ...@@ -450,13 +438,13 @@ def gen_decoder_seq2seq(hparams):
variable_mapping = { variable_mapping = {
str(model_str) + '/embedding': str(model_str) + '/embedding':
decoder_embedding, decoder_embedding,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
decoder_lstm_w_0, decoder_lstm_w_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
decoder_lstm_b_0, decoder_lstm_b_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
decoder_lstm_w_1, decoder_lstm_w_1,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
decoder_lstm_b_1, decoder_lstm_b_1,
str(model_str) + '/softmax_b': str(model_str) + '/softmax_b':
decoder_softmax_b decoder_softmax_b
...@@ -487,34 +475,34 @@ def dis_fwd_bidirectional(hparams): ...@@ -487,34 +475,34 @@ def dis_fwd_bidirectional(hparams):
][0] ][0]
fw_lstm_w_0 = [ fw_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/weights' if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
][0] ][0]
fw_lstm_b_0 = [ fw_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/biases' if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
][0] ][0]
fw_lstm_w_1 = [ fw_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/weights' if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
][0] ][0]
fw_lstm_b_1 = [ fw_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/biases' if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
][0] ][0]
if FLAGS.dis_share_embedding: if FLAGS.dis_share_embedding:
variable_mapping = { variable_mapping = {
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': fw_lstm_w_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': fw_lstm_w_0,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': fw_lstm_b_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': fw_lstm_b_0,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': fw_lstm_w_1, 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': fw_lstm_w_1,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': fw_lstm_b_1 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': fw_lstm_b_1
} }
else: else:
variable_mapping = { variable_mapping = {
'Model/embedding': embedding, 'Model/embedding': embedding,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': fw_lstm_w_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': fw_lstm_w_0,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': fw_lstm_b_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': fw_lstm_b_0,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': fw_lstm_w_1, 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': fw_lstm_w_1,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': fw_lstm_b_1 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': fw_lstm_b_1
} }
return variable_mapping return variable_mapping
...@@ -537,26 +525,26 @@ def dis_bwd_bidirectional(hparams): ...@@ -537,26 +525,26 @@ def dis_bwd_bidirectional(hparams):
# Backward Discriminator Elements. # Backward Discriminator Elements.
bw_lstm_w_0 = [ bw_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/weights' if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
][0] ][0]
bw_lstm_b_0 = [ bw_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/biases' if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
][0] ][0]
bw_lstm_w_1 = [ bw_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/weights' if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
][0] ][0]
bw_lstm_b_1 = [ bw_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables()
if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/biases' if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
][0] ][0]
variable_mapping = { variable_mapping = {
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': bw_lstm_w_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': bw_lstm_w_0,
'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': bw_lstm_b_0, 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': bw_lstm_b_0,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': bw_lstm_w_1, 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': bw_lstm_w_1,
'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': bw_lstm_b_1 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': bw_lstm_b_1
} }
return variable_mapping return variable_mapping
...@@ -576,24 +564,20 @@ def dis_encoder_seq2seq(hparams): ...@@ -576,24 +564,20 @@ def dis_encoder_seq2seq(hparams):
## Encoder forward variables. ## Encoder forward variables.
encoder_lstm_w_0 = [ encoder_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'
][0] ][0]
encoder_lstm_b_0 = [ encoder_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'
][0] ][0]
encoder_lstm_w_1 = [ encoder_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'
][0] ][0]
encoder_lstm_b_1 = [ encoder_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
][0] ][0]
if FLAGS.data_set == 'ptb': if FLAGS.data_set == 'ptb':
...@@ -602,13 +586,13 @@ def dis_encoder_seq2seq(hparams): ...@@ -602,13 +586,13 @@ def dis_encoder_seq2seq(hparams):
model_str = 'model' model_str = 'model'
variable_mapping = { variable_mapping = {
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
encoder_lstm_w_0, encoder_lstm_w_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
encoder_lstm_b_0, encoder_lstm_b_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
encoder_lstm_w_1, encoder_lstm_w_1,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
encoder_lstm_b_1 encoder_lstm_b_1
} }
return variable_mapping return variable_mapping
...@@ -624,24 +608,20 @@ def dis_decoder_seq2seq(hparams): ...@@ -624,24 +608,20 @@ def dis_decoder_seq2seq(hparams):
if v.op.name == 'dis/decoder/rnn/embedding' if v.op.name == 'dis/decoder/rnn/embedding'
][0] ][0]
decoder_lstm_w_0 = [ decoder_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'
][0] ][0]
decoder_lstm_b_0 = [ decoder_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'
][0] ][0]
decoder_lstm_w_1 = [ decoder_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'
][0] ][0]
decoder_lstm_b_1 = [ decoder_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
][0] ][0]
if FLAGS.data_set == 'ptb': if FLAGS.data_set == 'ptb':
...@@ -653,24 +633,24 @@ def dis_decoder_seq2seq(hparams): ...@@ -653,24 +633,24 @@ def dis_decoder_seq2seq(hparams):
variable_mapping = { variable_mapping = {
str(model_str) + '/embedding': str(model_str) + '/embedding':
decoder_embedding, decoder_embedding,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
decoder_lstm_w_0, decoder_lstm_w_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
decoder_lstm_b_0, decoder_lstm_b_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
decoder_lstm_w_1, decoder_lstm_w_1,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
decoder_lstm_b_1 decoder_lstm_b_1
} }
else: else:
variable_mapping = { variable_mapping = {
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
decoder_lstm_w_0, decoder_lstm_w_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
decoder_lstm_b_0, decoder_lstm_b_0,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
decoder_lstm_w_1, decoder_lstm_w_1,
str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
decoder_lstm_b_1, decoder_lstm_b_1,
} }
return variable_mapping return variable_mapping
...@@ -688,24 +668,20 @@ def dis_seq2seq_vd(hparams): ...@@ -688,24 +668,20 @@ def dis_seq2seq_vd(hparams):
## Encoder variables. ## Encoder variables.
encoder_lstm_w_0 = [ encoder_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'
][0] ][0]
encoder_lstm_b_0 = [ encoder_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'
][0] ][0]
encoder_lstm_w_1 = [ encoder_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'
][0] ][0]
encoder_lstm_b_1 = [ encoder_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
][0] ][0]
## Attention. ## Attention.
...@@ -721,43 +697,39 @@ def dis_seq2seq_vd(hparams): ...@@ -721,43 +697,39 @@ def dis_seq2seq_vd(hparams):
## Decoder. ## Decoder.
decoder_lstm_w_0 = [ decoder_lstm_w_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'
][0] ][0]
decoder_lstm_b_0 = [ decoder_lstm_b_0 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'
][0] ][0]
decoder_lstm_w_1 = [ decoder_lstm_w_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'
][0] ][0]
decoder_lstm_b_1 = [ decoder_lstm_b_1 = [
v for v in tf.trainable_variables() v for v in tf.trainable_variables() if v.op.name ==
if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
][0] ][0]
# Standard variable mappings. # Standard variable mappings.
variable_mapping = { variable_mapping = {
'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights': 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
encoder_lstm_w_0, encoder_lstm_w_0,
'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases': 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
encoder_lstm_b_0, encoder_lstm_b_0,
'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights': 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
encoder_lstm_w_1, encoder_lstm_w_1,
'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases': 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
encoder_lstm_b_1, encoder_lstm_b_1,
'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights': 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
decoder_lstm_w_0, decoder_lstm_w_0,
'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases': 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
decoder_lstm_b_0, decoder_lstm_b_0,
'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights': 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
decoder_lstm_w_1, decoder_lstm_w_1,
'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases': 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
decoder_lstm_b_1 decoder_lstm_b_1
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment