Unverified Commit 7d16fc45 authored by Andrew M Dai, committed by GitHub

Merge pull request #3486 from a-dai/master

Added new MaskGAN model.
parents c5c6eaf2 87fae3f7
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from regularization import variational_dropout
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams,
sequence,
is_training,
reuse=None,
initial_state=None):
"""Define the Discriminator graph."""
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_fwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
cell_bwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
if initial_state:
state_fwd = [[tf.identity(x) for x in inner_initial_state]
for inner_initial_state in initial_state]
state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
else:
state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
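# tf.floor turns this into a 0./1. Bernoulli(keep_prob) mask; dividing by
# keep_prob rescales the kept units so the expected activation is unchanged
# (inverted dropout).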
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob,
2 * hparams.dis_rnn_size)
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
rnn_inputs = tf.unstack(rnn_inputs, axis=1)
with tf.variable_scope('rnn') as vs:
outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)
if is_training:
outputs *= output_mask
# Prediction is linear output for Discriminator.
predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
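# `outputs` from static_bidirectional_rnn is time-major (one
# [batch_size, 2 * dis_rnn_size] slice per step), so the linear layer yields a
# [sequence_length, batch_size, 1] tensor; transpose it to batch-major below.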
predictions = tf.transpose(predictions, [1, 0, 2])
if FLAGS.baseline_method == 'critic':
with tf.variable_scope('critic', reuse=reuse) as critic_scope:
values = tf.contrib.layers.linear(outputs, 1, scope=critic_scope)
values = tf.transpose(values, [1, 0, 2])
return tf.squeeze(predictions, axis=2), tf.squeeze(values, axis=2)
else:
return tf.squeeze(predictions, axis=2), None
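# Illustrative usage (a sketch, not part of the model; assumes `hparams` holds
# the MaskGAN hyperparameters and `sequence` is an int Tensor of shape
# [FLAGS.batch_size, FLAGS.sequence_length]):
#   predictions, values = discriminator(hparams, sequence, is_training=True)
#   # predictions: [batch_size, sequence_length] unnormalized real/fake scores.
#   # values: per-timestep critic estimates when
#   #   FLAGS.baseline_method == 'critic', otherwise None.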
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the bidirectional Discriminator graph."""
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell_fwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
cell_bwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
rnn_inputs = tf.unstack(rnn_inputs, axis=1)
with tf.variable_scope('rnn') as vs:
outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)
# Prediction is linear output for Discriminator.
predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
predictions = tf.transpose(predictions, [1, 0, 2])
return tf.squeeze(predictions, axis=2)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple CNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
del is_training
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
"If you wish to share Discriminator/Generator embeddings, they must be"
" same dimension.")
with tf.variable_scope("gen/rnn", reuse=True):
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.gen_rnn_size])
dis_filter_sizes = [3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
with tf.variable_scope("dis", reuse=reuse):
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.dis_rnn_size])
cnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
# Create a convolution layer for each filter size
conv_outputs = []
for filter_size in dis_filter_sizes:
with tf.variable_scope("conv-%s" % filter_size):
# Convolution Layer
filter_shape = [
filter_size, hparams.dis_rnn_size, hparams.dis_num_filters
]
W = tf.get_variable(
name="W", initializer=tf.truncated_normal(filter_shape, stddev=0.1))
b = tf.get_variable(
name="b",
initializer=tf.constant(0.1, shape=[hparams.dis_num_filters]))
conv = tf.nn.conv1d(
cnn_inputs, W, stride=1, padding="SAME", name="conv")
# Apply nonlinearity
h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
conv_outputs.append(h)
# Combine the convolutional features from all filter sizes.
dis_num_filters_total = hparams.dis_num_filters * len(dis_filter_sizes)
h_conv = tf.concat(conv_outputs, axis=2)
h_conv_flat = tf.reshape(h_conv, [-1, dis_num_filters_total])
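# h_conv is [batch_size, sequence_length, dis_num_filters_total]; folding the
# time dimension into the batch dimension lets the dense layers below score
# every position independently.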
# Add dropout
with tf.variable_scope("dropout"):
h_drop = tf.nn.dropout(h_conv_flat, FLAGS.keep_prob)
with tf.variable_scope("fully_connected"):
fc = tf.contrib.layers.fully_connected(
h_drop, num_outputs=dis_num_filters_total // 2)
# Final (unnormalized) scores and predictions
with tf.variable_scope("output"):
W = tf.get_variable(
"W",
shape=[dis_num_filters_total // 2, 1],
initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(name="b", initializer=tf.constant(0.1, shape=[1]))
predictions = tf.nn.xw_plus_b(fc, W, b, name="predictions")
predictions = tf.reshape(
predictions, shape=[FLAGS.batch_size, FLAGS.sequence_length])
return predictions
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critic model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from regularization import variational_dropout
FLAGS = tf.app.flags.FLAGS
def critic_seq2seq_vd_derivative(hparams, sequence, is_training, reuse=None):
"""Define the Critic graph which is derived from the seq2seq_vd
Discriminator. This will be initialized with the same parameters as the
language model and will share the forward RNN components with the
Discriminator. This estimates the V(s_t), where the state
s_t = x_0,...,x_t-1.
"""
assert FLAGS.discriminator_model == 'seq2seq_vd'
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
else:
with tf.variable_scope('dis/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
with tf.variable_scope(
'dis/decoder/rnn/multi_rnn_cell', reuse=True) as dis_scope:
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=True)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_critic = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
with tf.variable_scope('critic', reuse=reuse):
state_dis = cell_critic.zero_state(FLAGS.batch_size, tf.float32)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)
with tf.variable_scope('rnn') as vs:
values = []
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
if t == 0:
rnn_in = tf.zeros_like(rnn_inputs[:, 0])
else:
rnn_in = rnn_inputs[:, t - 1]
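# The input at step t is the token from step t - 1 (zeros at t = 0), so the
# RNN state summarizes x_0, ..., x_{t-1} and the output below estimates V(s_t)
# before x_t is seen.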
rnn_out, state_dis = cell_critic(rnn_in, state_dis, scope=dis_scope)
if is_training:
rnn_out *= output_mask
# Prediction is linear output for Discriminator.
value = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
values.append(value)
values = tf.stack(values, axis=1)
return tf.squeeze(values, axis=2)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import Counter
# Dependency imports
import numpy as np
from scipy.special import expit
import tensorflow as tf
from model_utils import helper
from model_utils import n_gram
FLAGS = tf.app.flags.FLAGS
def print_and_log_losses(log, step, is_present_rate, avg_dis_loss,
avg_gen_loss):
"""Prints and logs losses to the log file.
Args:
log: GFile for logs.
step: Global step.
is_present_rate: Current masking rate.
avg_dis_loss: List of Discriminator losses.
avg_gen_loss: List of Generator losses.
"""
print('global_step: %d' % step)
print(' is_present_rate: %.3f' % is_present_rate)
print(' D train loss: %.5f' % np.mean(avg_dis_loss))
print(' G train loss: %.5f' % np.mean(avg_gen_loss))
log.write('\nglobal_step: %d\n' % step)
log.write((' is_present_rate: %.3f\n' % is_present_rate))
log.write(' D train loss: %.5f\n' % np.mean(avg_dis_loss))
log.write(' G train loss: %.5f\n' % np.mean(avg_gen_loss))
def print_and_log(log, id_to_word, sequence_eval, max_num_to_print=5):
"""Helper function for printing and logging evaluated sequences."""
indices_arr = np.asarray(sequence_eval)
samples = helper.convert_to_human_readable(id_to_word, indices_arr,
max_num_to_print)
for i, sample in enumerate(samples):
print('Sample', i, '. ', sample)
log.write('\nSample ' + str(i) + '. ' + sample)
log.write('\n')
print('\n')
log.flush()
return samples
def zip_seq_pred_crossent(id_to_word, sequences, predictions, cross_entropy):
"""Zip together the sequences, predictions, cross entropy."""
indices = np.asarray(sequences)
batch_of_metrics = []
for ind_batch, pred_batch, crossent_batch in zip(indices, predictions,
cross_entropy):
metrics = []
for index, pred, crossent in zip(ind_batch, pred_batch, crossent_batch):
metrics.append([str(id_to_word[index]), pred, crossent])
batch_of_metrics.append(metrics)
return batch_of_metrics
def zip_metrics(indices, *args):
"""Zip together the indices matrices with the provided metrics matrices."""
batch_of_metrics = []
for metrics_batch in zip(indices, *args):
metrics = []
for m in zip(*metrics_batch):
metrics.append(m)
batch_of_metrics.append(metrics)
return batch_of_metrics
def print_formatted(present, id_to_word, log, batch_of_tuples):
"""Print and log metrics."""
num_cols = len(batch_of_tuples[0][0])
repeat_float_format = '{:<12.3f} '
repeat_str_format = '{:<13}'
format_str = ''.join(
['[{:<1}] {:<20}',
str(repeat_float_format * (num_cols - 1))])
# TODO(liamfedus): Generalize the logging. This is sloppy.
header_format_str = ''.join(
['[{:<1}] {:<20}',
str(repeat_str_format * (num_cols - 1))])
header_str = header_format_str.format('p', 'Word', 'p(real)', 'log-perp',
'log(p(a))', 'r', 'R=V*(s)', 'b=V(s)',
'A(a,s)')
for i, batch in enumerate(batch_of_tuples):
print(' Sample: %d' % i)
log.write(' Sample %d.\n' % i)
print(' ', header_str)
log.write(' ' + str(header_str) + '\n')
for j, t in enumerate(batch):
t = list(t)
t[0] = id_to_word[t[0]]
buffer_str = format_str.format(int(present[i][j]), *t)
print(' ', buffer_str)
log.write(' ' + str(buffer_str) + '\n')
log.flush()
def generate_RL_logs(sess, model, log, id_to_word, feed):
"""Generate complete logs while running with REINFORCE."""
# Impute Sequences.
[
p,
fake_sequence_eval,
fake_predictions_eval,
_,
fake_cross_entropy_losses_eval,
_,
fake_log_probs_eval,
fake_rewards_eval,
fake_baselines_eval,
cumulative_rewards_eval,
fake_advantages_eval,
] = sess.run(
[
model.present,
model.fake_sequence,
model.fake_predictions,
model.real_predictions,
model.fake_cross_entropy_losses,
model.fake_logits,
model.fake_log_probs,
model.fake_rewards,
model.fake_baselines,
model.cumulative_rewards,
model.fake_advantages,
],
feed_dict=feed)
indices = np.asarray(fake_sequence_eval)
# Convert Discriminator linear layer to probability.
fake_prob_eval = expit(fake_predictions_eval)
# Add metrics.
fake_tuples = zip_metrics(indices, fake_prob_eval,
fake_cross_entropy_losses_eval, fake_log_probs_eval,
fake_rewards_eval, cumulative_rewards_eval,
fake_baselines_eval, fake_advantages_eval)
# real_tuples = zip_metrics(indices, )
# Print forward sequences.
tuples_to_print = fake_tuples[:FLAGS.max_num_to_print]
print_formatted(p, id_to_word, log, tuples_to_print)
print('Samples')
log.write('Samples\n')
samples = print_and_log(log, id_to_word, fake_sequence_eval,
FLAGS.max_num_to_print)
return samples
def generate_logs(sess, model, log, id_to_word, feed):
"""Impute Sequences using the model for a particular feed and send it to
logs."""
# Impute Sequences.
[
p, sequence_eval, fake_predictions_eval, fake_cross_entropy_losses_eval,
fake_logits_eval
] = sess.run(
[
model.present, model.fake_sequence, model.fake_predictions,
model.fake_cross_entropy_losses, model.fake_logits
],
feed_dict=feed)
# Convert Discriminator linear layer to probability.
fake_prob_eval = expit(fake_predictions_eval)
# Forward Masked Tuples.
fake_tuples = zip_seq_pred_crossent(id_to_word, sequence_eval, fake_prob_eval,
fake_cross_entropy_losses_eval)
tuples_to_print = fake_tuples[:FLAGS.max_num_to_print]
if FLAGS.print_verbose:
print('fake_logits_eval')
print(fake_logits_eval)
for i, batch in enumerate(tuples_to_print):
print(' Sample %d.' % i)
log.write(' Sample %d.\n' % i)
for j, pred in enumerate(batch):
buffer_str = ('[{:<1}] {:<20} {:<7.3f} {:<7.3f}').format(
int(p[i][j]), pred[0], pred[1], pred[2])
print(' ', buffer_str)
log.write(' ' + str(buffer_str) + '\n')
log.flush()
print('Samples')
log.write('Samples\n')
samples = print_and_log(log, id_to_word, sequence_eval,
FLAGS.max_num_to_print)
return samples
def create_merged_ngram_dictionaries(indices, n):
"""Generate a single dictionary for the full batch.
Args:
indices: List of lists of indices.
n: Degree of n-grams.
Returns:
Dictionary of hashed(n-gram tuples) to counts in the batch of indices.
"""
ngram_dicts = []
for ind in indices:
ngrams = n_gram.find_all_ngrams(ind, n=n)
ngram_counts = n_gram.construct_ngrams_dict(ngrams)
ngram_dicts.append(ngram_counts)
merged_gen_dict = Counter()
for ngram_dict in ngram_dicts:
merged_gen_dict += Counter(ngram_dict)
return merged_gen_dict
def sequence_ngram_evaluation(sess, sequence, log, feed, data_ngram_count, n):
"""Calculates the percent of ngrams produced in the sequence is present in
data_ngram_count.
Args:
sess: tf.Session.
sequence: Sequence Tensor from the MaskGAN model.
log: gFile log.
feed: Feed to evaluate.
data_ngram_count: Dictionary of hashed(n-gram tuples) to counts in the
data_set.
Returns:
avg_percent_captured: Percent of produced ngrams that appear in the
data_ngram_count.
"""
del log
# Impute sequence.
[sequence_eval] = sess.run([sequence], feed_dict=feed)
indices = sequence_eval
# Retrieve the counts across the batch of indices.
gen_ngram_counts = create_merged_ngram_dictionaries(
indices, n=n)
return n_gram.percent_unique_ngrams_in_train(data_ngram_count,
gen_ngram_counts)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple FNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
del is_training
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
"If you wish to share Discriminator/Generator embeddings, they must be"
" same dimension.")
with tf.variable_scope("gen/rnn", reuse=True):
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope("dis", reuse=reuse):
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.dis_rnn_size])
embeddings = tf.nn.embedding_lookup(embedding, sequence)
# Input matrices.
W = tf.get_variable(
"W",
initializer=tf.truncated_normal(
shape=[3 * hparams.dis_embedding_dim, hparams.dis_hidden_dim],
stddev=0.1))
b = tf.get_variable(
"b", initializer=tf.constant(0.1, shape=[hparams.dis_hidden_dim]))
# Output matrices.
W_out = tf.get_variable(
"W_out",
initializer=tf.truncated_normal(
shape=[hparams.dis_hidden_dim, 1], stddev=0.1))
b_out = tf.get_variable("b_out", initializer=tf.constant(0.1, shape=[1]))
predictions = []
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
inp = embeddings[:, t]
if t > 0:
past_inp = tf.unstack(embeddings[:, 0:t], axis=1)
avg_past_inp = tf.add_n(past_inp) / len(past_inp)
else:
avg_past_inp = tf.zeros_like(inp)
if t < FLAGS.sequence_length:
future_inp = tf.unstack(embeddings[:, t:], axis=1)
avg_future_inp = tf.add_n(future_inp) / len(future_inp)
else:
avg_future_inp = tf.zeros_like(inp)
# Cumulative input.
concat_inp = tf.concat([avg_past_inp, inp, avg_future_inp], axis=1)
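# Each position is represented by [mean of past embeddings, current embedding,
# mean of the current-and-future embeddings], i.e. 3 * dis_embedding_dim
# features.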
# Hidden activations.
hidden = tf.nn.relu(tf.nn.xw_plus_b(concat_inp, W, b, name="scores"))
# Add dropout
with tf.variable_scope("dropout"):
hidden = tf.nn.dropout(hidden, FLAGS.keep_prob)
# Output.
output = tf.nn.xw_plus_b(hidden, W_out, b_out, name="output")
predictions.append(output)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
# ZoneoutWrapper.
from regularization import zoneout
FLAGS = tf.app.flags.FLAGS
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph.
G will now impute tokens that have been masked from the input sequence.
"""
tf.logging.warning(
'Unidirectional generative model is not a useful model for this MaskGAN '
'because future context is needed. Use only for debugging purposes.')
init_scale = 0.05
initializer = tf.random_uniform_initializer(-init_scale, init_scale)
with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
def lstm_cell():
return tf.contrib.rnn.LayerNormBasicLSTMCell(
hparams.gen_rnn_size, reuse=reuse)
attn_cell = lstm_cell
if FLAGS.zoneout_drop_prob > 0.0:
def attn_cell():
return zoneout.ZoneoutWrapper(
lstm_cell(),
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
cell_gen = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.get_variable('softmax_w',
[hparams.gen_rnn_size, FLAGS.vocab_size])
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the model is the first token to provide context. The
# model will then predict token t > 0.
if t == 0:
# Always provide the real input at t = 0.
state_gen = initial_state
rnn_inp = rnn_inputs[:, t]
# If the target at the last time-step was present, read in the real.
# If the target at the last time-step was not present, read in the fake.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# Use teacher forcing.
if (is_training and
FLAGS.gen_training_strategy == 'cross_entropy') or is_validating:
rnn_inp = real_rnn_inp
else:
# Note that targets[t - 1] == inputs[t].
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Real sample.
real = targets[:, t]
# Fake sample.
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
# Output for Generator will either be generated or the target.
# If present: Return real.
# If not present: Return fake.
output = tf.where(targets_present[:, t], real, fake)
# Append to lists.
sequence.append(output)
logits.append(logit)
log_probs.append(log_prob)
# Produce the RNN state had the model operated only
# over real data.
real_state_gen = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
rnn_inp = rnn_inputs[:, t]
# RNN.
rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)
final_state = real_state_gen
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1), initial_state, final_state)
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph.
Args:
hparams: Hyperparameters for the MaskGAN.
FLAGS: Current flags.
sequence: [FLAGS.batch_size, FLAGS.sequence_length]
is_training:
reuse
Returns:
predictions:
"""
tf.logging.warning(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.LayerNormBasicLSTMCell(
hparams.dis_rnn_size, reuse=reuse)
attn_cell = lstm_cell
if FLAGS.zoneout_drop_prob > 0.0:
def attn_cell():
return zoneout.ZoneoutWrapper(
lstm_cell(),
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn') as vs:
predictions = []
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
# NAS Code..
from nas_utils import configs
from nas_utils import custom_cell
from nas_utils import variational_dropout
FLAGS = tf.app.flags.FLAGS
def get_config():
return configs.AlienConfig2()
LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph.
G will now impute tokens that have been masked from the input sequence.
"""
tf.logging.info(
'Unidirectional generative model is not a useful model for this MaskGAN '
'because future context is needed. Use only for debugging purposes.')
config = get_config()
config.keep_prob = [hparams.gen_nas_keep_prob_0, hparams.gen_nas_keep_prob_1]
configs.print_config(config)
init_scale = config.init_scale
initializer = tf.random_uniform_initializer(-init_scale, init_scale)
with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
# Neural architecture search cell.
cell = custom_cell.Alien(config.hidden_size)
if is_training:
[h2h_masks, _, _,
output_mask] = variational_dropout.generate_variational_dropout_masks(
hparams, config.keep_prob)
else:
output_mask = None
cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.matrix_transpose(embedding)
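# Tie the softmax projection to the transposed embedding matrix (weight
# tying); no separate softmax_w variable is created.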
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the model is the first token to provide context. The
# model will then predict token t > 0.
if t == 0:
# Always provide the real input at t = 0.
state_gen = initial_state
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or (is_training and
FLAGS.gen_training_strategy == 'cross_entropy'):
rnn_inp = real_rnn_inp
else:
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
if is_training:
state_gen = list(state_gen)
for layer_num, per_layer_state in enumerate(state_gen):
per_layer_state = LSTMTuple(
per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
state_gen[layer_num] = per_layer_state
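# The same hidden-to-hidden dropout masks are reused at every time step
# (variational dropout) rather than being resampled per step.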
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
if is_training:
rnn_out = output_mask * rnn_out
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Real sample.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
# Output for Generator will either be generated or the input.
#
# If present: Return real.
# If not present: Return fake.
output = tf.where(targets_present[:, t], real, fake)
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
# Produce the RNN state had the model operated only
# over real data.
real_state_gen = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
rnn_inp = rnn_inputs[:, t]
# RNN.
rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)
final_state = real_state_gen
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1), initial_state, final_state)
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
tf.logging.info(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
config = get_config()
config.keep_prob = [hparams.dis_nas_keep_prob_0, hparams.dis_nas_keep_prob_1]
configs.print_config(config)
with tf.variable_scope('dis', reuse=reuse):
# Neural architecture search cell.
cell = custom_cell.Alien(config.hidden_size)
if is_training:
[h2h_masks, _, _,
output_mask] = variational_dropout.generate_variational_dropout_masks(
hparams, config.keep_prob)
else:
output_mask = None
cell_dis = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn') as vs:
predictions = []
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
if is_training:
state_dis = list(state_dis)
for layer_num, per_layer_state in enumerate(state_dis):
per_layer_state = LSTMTuple(
per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
state_dis[layer_num] = per_layer_state
# RNN.
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
if is_training:
rnn_out = output_mask * rnn_out
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from regularization import variational_dropout
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams,
sequence,
is_training,
reuse=None,
initial_state=None):
"""Define the Discriminator graph."""
tf.logging.info(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
if initial_state:
state_dis = [[tf.identity(x) for x in inner_initial_state]
for inner_initial_state in initial_state]
else:
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)
with tf.variable_scope('rnn') as vs:
predictions, rnn_outs = [], []
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
if is_training:
rnn_out *= output_mask
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
rnn_outs.append(rnn_out)
predictions = tf.stack(predictions, axis=1)
if FLAGS.baseline_method == 'critic':
with tf.variable_scope('critic', reuse=reuse) as critic_scope:
rnn_outs = tf.stack(rnn_outs, axis=1)
values = tf.contrib.layers.linear(rnn_outs, 1, scope=critic_scope)
return tf.squeeze(predictions, axis=2), tf.squeeze(values, axis=2)
else:
return tf.squeeze(predictions, axis=2), None
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph.
G will now impute tokens that have been masked from the input sequence.
"""
tf.logging.warning(
'Unidirectional generative model is not a useful model for this MaskGAN '
'because future context is needed. Use only for debugging purposes.')
init_scale = 0.05
initializer = tf.random_uniform_initializer(-init_scale, init_scale)
with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(hparams.gen_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell_gen = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.get_variable('softmax_w',
[hparams.gen_rnn_size, FLAGS.vocab_size])
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
fake = None
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the model is the first token to provide context. The
# model will then predict token t > 0.
if t == 0:
# Always provide the real input at t = 0.
state_gen = initial_state
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or (is_training and
FLAGS.gen_training_strategy == 'cross_entropy'):
rnn_inp = real_rnn_inp
else:
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Real sample.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
# Output for Generator will either be generated or the input.
#
# If present: Return real.
# If not present: Return fake.
output = tf.where(targets_present[:, t], real, fake)
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
# Produce the RNN state had the model operated only
# over real data.
real_state_gen = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
rnn_inp = rnn_inputs[:, t]
# RNN.
rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)
final_state = real_state_gen
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1), initial_state, final_state)
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
tf.logging.warning(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn') as vs:
predictions = []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Rollout RNN model definitions which call rnn_zaremba code."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from losses import losses
from model_utils import helper
from model_utils import model_construction
from model_utils import model_losses
from model_utils import model_optimization
FLAGS = tf.app.flags.FLAGS
def create_rollout_MaskGAN(hparams, is_training):
"""Create the MaskGAN model.
Args:
hparams: Hyperparameters for the MaskGAN.
is_training: Boolean indicating operational mode (train/inference).
evaluated with a teacher forcing regime.
Return:
model: Namedtuple for specifying the MaskGAN."""
global_step = tf.Variable(0, name='global_step', trainable=False)
new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate')
learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False)
learning_rate_update = tf.assign(learning_rate, new_learning_rate)
new_rate = tf.placeholder(tf.float32, [], name='new_rate')
percent_real_var = tf.Variable(0.0, trainable=False)
percent_real_update = tf.assign(percent_real_var, new_rate)
## Placeholders.
inputs = tf.placeholder(
tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
present = tf.placeholder(
tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])
inv_present = tf.placeholder(
tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])
## Rollout Generator.
fwd_gen_rollouts = rollout_generator(
hparams, inputs, present, is_training=is_training, is_validating=False)
inv_gen_rollouts = rollout_generator(
hparams,
inputs,
inv_present,
is_training=is_training,
is_validating=False,
reuse=True)
## Rollout Discriminator.
fwd_dis_rollouts = rollout_discriminator(
hparams, fwd_gen_rollouts, is_training=is_training)
inv_dis_rollouts = rollout_discriminator(
hparams, inv_gen_rollouts, is_training=is_training, reuse=True)
## Discriminator Loss.
[dis_loss, dis_loss_pred, dis_loss_inv_pred] = rollout_discriminator_loss(
fwd_dis_rollouts, present, inv_dis_rollouts, inv_present)
## Average log-perplexity for only the missing words. To compute this, the
# logits are still produced with teacher forcing, that is, the ground-truth
# tokens are fed in at each time step so that the measurement is valid.
# TODO(liamfedus): Fix the naming convention.
with tf.variable_scope('gen_rollout'):
_, fwd_eval_logits, _ = model_construction.create_generator(
hparams,
inputs,
present,
is_training=False,
is_validating=True,
reuse=True)
avg_log_perplexity = model_losses.calculate_log_perplexity(
fwd_eval_logits, inputs, present)
## Generator Loss.
# 1. Cross Entropy losses on missing tokens.
[fwd_cross_entropy_losses,
inv_cross_entropy_losses] = rollout_masked_cross_entropy_loss(
inputs, present, inv_present, fwd_gen_rollouts, inv_gen_rollouts)
# 2. GAN losses on missing tokens.
[fwd_RL_loss,
fwd_RL_statistics, fwd_averages_op] = rollout_reinforce_objective(
hparams, fwd_gen_rollouts, fwd_dis_rollouts, present)
[inv_RL_loss,
inv_RL_statistics, inv_averages_op] = rollout_reinforce_objective(
hparams, inv_gen_rollouts, inv_dis_rollouts, inv_present)
# TODO(liamfedus): Generalize this to use all logs.
[fwd_sequence, fwd_logits, fwd_log_probs] = fwd_gen_rollouts[-1]
[inv_sequence, inv_logits, inv_log_probs] = inv_gen_rollouts[-1]
# TODO(liamfedus): Generalize this to use all logs.
fwd_predictions = fwd_dis_rollouts[-1]
inv_predictions = inv_dis_rollouts[-1]
# TODO(liamfedus): Generalize this to use all logs.
[fwd_log_probs, fwd_rewards, fwd_advantages,
fwd_baselines] = fwd_RL_statistics[-1]
[inv_log_probs, inv_rewards, inv_advantages,
inv_baselines] = inv_RL_statistics[-1]
## Pre-training.
if FLAGS.gen_pretrain_steps:
# TODO(liamfedus): Rewrite this.
fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses)
gen_pretrain_op = model_optimization.create_gen_pretrain_op(
hparams, fwd_cross_entropy_loss, global_step)
else:
gen_pretrain_op = tf.no_op('gen_pretrain_no_op')
if FLAGS.dis_pretrain_steps:
dis_pretrain_op = model_optimization.create_dis_pretrain_op(
hparams, dis_loss, global_step)
else:
dis_pretrain_op = tf.no_op('dis_pretrain_no_op')
## Generator Train Op.
# 1. Cross-Entropy.
if FLAGS.gen_training_strategy == 'cross_entropy':
gen_loss = tf.reduce_mean(
fwd_cross_entropy_losses + inv_cross_entropy_losses) / 2.
[gen_train_op, gen_grads,
gen_vars] = model_optimization.create_gen_train_op(
hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')
# 2. GAN (REINFORCE)
elif FLAGS.gen_training_strategy == 'reinforce':
gen_loss = (fwd_RL_loss + inv_RL_loss) / 2.
[gen_train_op, gen_grads,
gen_vars] = model_optimization.create_reinforce_gen_train_op(
hparams, learning_rate, gen_loss, fwd_averages_op, inv_averages_op,
global_step)
else:
raise NotImplementedError
## Discriminator Train Op.
dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op(
hparams, dis_loss, global_step)
## Summaries.
with tf.name_scope('general'):
tf.summary.scalar('percent_real', percent_real_var)
tf.summary.scalar('learning_rate', learning_rate)
with tf.name_scope('generator_losses'):
tf.summary.scalar('gen_loss', tf.reduce_mean(gen_loss))
tf.summary.scalar('gen_loss_fwd_cross_entropy',
tf.reduce_mean(fwd_cross_entropy_losses))
tf.summary.scalar('gen_loss_inv_cross_entropy',
tf.reduce_mean(inv_cross_entropy_losses))
with tf.name_scope('REINFORCE'):
with tf.name_scope('objective'):
tf.summary.scalar('fwd_RL_loss', tf.reduce_mean(fwd_RL_loss))
tf.summary.scalar('inv_RL_loss', tf.reduce_mean(inv_RL_loss))
with tf.name_scope('rewards'):
helper.variable_summaries(fwd_rewards, 'fwd_rewards')
helper.variable_summaries(inv_rewards, 'inv_rewards')
with tf.name_scope('advantages'):
helper.variable_summaries(fwd_advantages, 'fwd_advantages')
helper.variable_summaries(inv_advantages, 'inv_advantages')
with tf.name_scope('baselines'):
helper.variable_summaries(fwd_baselines, 'fwd_baselines')
helper.variable_summaries(inv_baselines, 'inv_baselines')
with tf.name_scope('log_probs'):
helper.variable_summaries(fwd_log_probs, 'fwd_log_probs')
helper.variable_summaries(inv_log_probs, 'inv_log_probs')
with tf.name_scope('discriminator_losses'):
tf.summary.scalar('dis_loss', dis_loss)
tf.summary.scalar('dis_loss_fwd_sequence', dis_loss_pred)
tf.summary.scalar('dis_loss_inv_sequence', dis_loss_inv_pred)
with tf.name_scope('logits'):
helper.variable_summaries(fwd_logits, 'fwd_logits')
helper.variable_summaries(inv_logits, 'inv_logits')
for v, g in zip(gen_vars, gen_grads):
helper.variable_summaries(v, v.op.name)
helper.variable_summaries(g, 'grad/' + v.op.name)
for v, g in zip(dis_vars, dis_grads):
helper.variable_summaries(v, v.op.name)
helper.variable_summaries(g, 'grad/' + v.op.name)
merge_summaries_op = tf.summary.merge_all()
# Model saver.
saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5)
# Named tuple that captures elements of the MaskGAN model.
Model = collections.namedtuple('Model', [
'inputs', 'present', 'inv_present', 'percent_real_update', 'new_rate',
'fwd_sequence', 'fwd_logits', 'fwd_rewards', 'fwd_advantages',
'fwd_log_probs', 'fwd_predictions', 'fwd_cross_entropy_losses',
'inv_sequence', 'inv_logits', 'inv_rewards', 'inv_advantages',
'inv_log_probs', 'inv_predictions', 'inv_cross_entropy_losses',
'avg_log_perplexity', 'dis_loss', 'gen_loss', 'dis_train_op',
'gen_train_op', 'gen_pretrain_op', 'dis_pretrain_op',
'merge_summaries_op', 'global_step', 'new_learning_rate',
'learning_rate_update', 'saver'
])
model = Model(
inputs, present, inv_present, percent_real_update, new_rate, fwd_sequence,
fwd_logits, fwd_rewards, fwd_advantages, fwd_log_probs, fwd_predictions,
fwd_cross_entropy_losses, inv_sequence, inv_logits, inv_rewards,
inv_advantages, inv_log_probs, inv_predictions, inv_cross_entropy_losses,
avg_log_perplexity, dis_loss, gen_loss, dis_train_op, gen_train_op,
gen_pretrain_op, dis_pretrain_op, merge_summaries_op, global_step,
new_learning_rate, learning_rate_update, saver)
return model
def rollout_generator(hparams,
inputs,
input_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph which does rollouts.
G will now impute tokens that have been masked from the input sequence.
"""
rollouts = []
with tf.variable_scope('gen_rollout'):
for n in xrange(FLAGS.num_rollouts):
if n > 0:
# TODO(liamfedus): Why is it necessary here to manually set reuse?
reuse = True
tf.get_variable_scope().reuse_variables()
[sequence, logits, log_probs] = model_construction.create_generator(
hparams,
inputs,
input_present,
is_training,
is_validating,
reuse=reuse)
rollouts.append([sequence, logits, log_probs])
# Length assertion.
assert len(rollouts) == FLAGS.num_rollouts
return rollouts
def rollout_discriminator(hparams, gen_rollouts, is_training, reuse=None):
"""Define the Discriminator graph which does rollouts.
D will score each sequence imputed by the Generator's rollouts.
"""
rollout_predictions = []
with tf.variable_scope('dis_rollout'):
for n, rollout in enumerate(gen_rollouts):
if n > 0:
# TODO(liamfedus): Why is it necessary here to manually set reuse?
reuse = True
tf.get_variable_scope().reuse_variables()
[sequence, _, _] = rollout
predictions = model_construction.create_discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
# Predictions for each rollout.
rollout_predictions.append(predictions)
# Length assertion.
assert len(rollout_predictions) == FLAGS.num_rollouts
return rollout_predictions
def rollout_reinforce_objective(hparams, gen_rollouts, dis_rollouts, present):
cumulative_gen_objective = 0.
cumulative_averages_op = []
cumulative_statistics = []
assert len(gen_rollouts) == len(dis_rollouts)
for gen_rollout, dis_rollout in zip(gen_rollouts, dis_rollouts):
[_, _, log_probs] = gen_rollout
dis_predictions = dis_rollout
[
final_gen_objective, log_probs, rewards, advantages, baselines,
maintain_averages_op
] = model_losses.calculate_reinforce_objective(hparams, log_probs,
dis_predictions, present)
# Accumulate results.
cumulative_gen_objective += final_gen_objective
cumulative_averages_op.append(maintain_averages_op)
cumulative_statistics.append([log_probs, rewards, advantages, baselines])
# Group all the averaging operations.
cumulative_averages_op = tf.group(*cumulative_averages_op)
cumulative_gen_objective /= FLAGS.num_rollouts
[log_probs, rewards, advantages, baselines] = cumulative_statistics[-1]
# Length assertion.
assert len(cumulative_statistics) == FLAGS.num_rollouts
return [
cumulative_gen_objective, cumulative_statistics, cumulative_averages_op
]
def rollout_masked_cross_entropy_loss(inputs, present, inv_present,
fwd_rollouts, inv_rollouts):
cumulative_fwd_cross_entropy_losses = tf.zeros(
shape=[FLAGS.batch_size, FLAGS.sequence_length])
cumulative_inv_cross_entropy_losses = tf.zeros(
shape=[FLAGS.batch_size, FLAGS.sequence_length])
for fwd_rollout, inv_rollout in zip(fwd_rollouts, inv_rollouts):
[_, fwd_logits, _] = fwd_rollout
[_, inv_logits, _] = inv_rollout
[fwd_cross_entropy_losses,
inv_cross_entropy_losses] = model_losses.create_masked_cross_entropy_loss(
inputs, present, inv_present, fwd_logits, inv_logits)
cumulative_fwd_cross_entropy_losses = tf.add(
cumulative_fwd_cross_entropy_losses, fwd_cross_entropy_losses)
cumulative_inv_cross_entropy_losses = tf.add(
cumulative_inv_cross_entropy_losses, inv_cross_entropy_losses)
return [
cumulative_fwd_cross_entropy_losses, cumulative_inv_cross_entropy_losses
]
def rollout_discriminator_loss(fwd_rollouts, present, inv_rollouts,
inv_present):
dis_loss = 0
dis_loss_pred = 0
dis_loss_inv_pred = 0
for fwd_predictions, inv_predictions in zip(fwd_rollouts, inv_rollouts):
dis_loss_pred += losses.discriminator_loss(fwd_predictions, present)
dis_loss_inv_pred += losses.discriminator_loss(inv_predictions, inv_present)
dis_loss_pred /= FLAGS.num_rollouts
dis_loss_inv_pred /= FLAGS.num_rollouts
dis_loss = (dis_loss_pred + dis_loss_inv_pred) / 2.
return [dis_loss, dis_loss_pred, dis_loss_inv_pred]
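# Descriptive note (comment only): both prediction losses are first averaged
# over the rollouts, and the total Discriminator loss is then the mean of the
# forward and inverse terms. For example, with num_rollouts=2, forward losses
# (0.7, 0.5) and inverse losses (0.3, 0.5): dis_loss = (0.6 + 0.4) / 2 = 0.5.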
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from models import attention_utils
# ZoneoutWrapper.
from regularization import zoneout
FLAGS = tf.app.flags.FLAGS
def transform_input_with_is_missing_token(inputs, targets_present):
"""Transforms the inputs to have missing tokens when it's masked out. The
mask is for the targets, so therefore, to determine if an input at time t is
masked, we have to check if the target at time t - 1 is masked out.
e.g.
inputs = [a, b, c, d]
targets = [b, c, d, e]
targets_present = [1, 0, 1, 0]
then,
transformed_input = [a, b, <missing>, d]
Args:
inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
up to, but not including, vocab_size.
targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
True representing the presence of the word.
Returns:
transformed_input: tf.int32 Tensor of shape [batch_size, sequence_length]
which takes on value of inputs when the input is present and takes on
value=vocab_size to indicate a missing token.
"""
# To fill in if the input is missing.
input_missing = tf.constant(
FLAGS.vocab_size,
dtype=tf.int32,
shape=[FLAGS.batch_size, FLAGS.sequence_length])
# The 0th input will always be present to MaskGAN.
zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])
# Input present mask.
inputs_present = tf.concat(
[zeroth_input_present, targets_present[:, :-1]], axis=1)
transformed_input = tf.where(inputs_present, inputs, input_missing)
return transformed_input
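# Worked example (illustrative values only): with FLAGS.vocab_size == 10,
# inputs = [[3, 5, 2, 8]] and targets_present = [[True, False, True, False]],
# the input-present mask is [[True, True, False, True]] (the first input is
# always present), so transformed_input == [[3, 5, 10, 8]], where 10 is the
# out-of-vocabulary id standing in for the missing token.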
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
"""Define the Encoder graph."""
# We will use the same variable from the decoder.
if FLAGS.seq2seq_share_embedding:
with tf.variable_scope('decoder/rnn'):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('encoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.LayerNormBasicLSTMCell(
hparams.gen_rnn_size, reuse=reuse)
attn_cell = lstm_cell
if FLAGS.zoneout_drop_prob > 0.0:
def attn_cell():
return zoneout.ZoneoutWrapper(
lstm_cell(),
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
cell = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)
# Add a missing token for inputs not present.
real_inputs = inputs
masked_inputs = transform_input_with_is_missing_token(
inputs, targets_present)
with tf.variable_scope('rnn'):
hidden_states = []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size + 1, hparams.gen_rnn_size])
real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)
state = initial_state
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_inp = masked_rnn_inputs[:, t]
rnn_out, state = cell(rnn_inp, state)
hidden_states.append(rnn_out)
final_masked_state = state
hidden_states = tf.stack(hidden_states, axis=1)
# Produce the RNN state had the model operated only
# over real data.
real_state = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
# RNN.
rnn_inp = real_rnn_inputs[:, t]
rnn_out, real_state = cell(rnn_inp, real_state)
final_state = real_state
return (hidden_states, final_masked_state), initial_state, final_state
def gen_decoder(hparams,
inputs,
targets,
targets_present,
encoding_state,
is_training,
is_validating,
reuse=None):
"""Define the Decoder graph. The Decoder will now impute tokens that
have been masked from the input seqeunce.
"""
gen_decoder_rnn_size = hparams.gen_rnn_size
with tf.variable_scope('decoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.LayerNormBasicLSTMCell(
gen_decoder_rnn_size, reuse=reuse)
attn_cell = lstm_cell
if FLAGS.zoneout_drop_prob > 0.0:
def attn_cell():
return zoneout.ZoneoutWrapper(
lstm_cell(),
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
cell_gen = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
# Hidden encoder states.
hidden_vector_encodings = encoding_state[0]
# Carry forward the final state tuple from the encoder.
# State tuples.
state_gen = encoding_state[1]
if FLAGS.attention_option is not None:
(attention_keys, attention_values, _,
attention_construct_fn) = attention_utils.prepare_attention(
hidden_vector_encodings,
FLAGS.attention_option,
num_units=gen_decoder_rnn_size,
reuse=reuse)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, gen_decoder_rnn_size])
softmax_w = tf.get_variable('softmax_w',
[gen_decoder_rnn_size, FLAGS.vocab_size])
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the Decoder.
if t == 0:
# Always provide the real input at t = 0.
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or (is_training and
FLAGS.gen_training_strategy == 'cross_entropy'):
rnn_inp = real_rnn_inp
else:
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
if FLAGS.attention_option is not None:
rnn_out = attention_construct_fn(rnn_out, attention_keys,
attention_values)
# # TODO(liamfedus): Assert not "monotonic" attention_type.
# # TODO(liamfedus): FLAGS.attention_type.
# context_state = revised_attention_utils._empty_state()
# rnn_out, context_state = attention_construct_fn(
# rnn_out, attention_keys, attention_values, context_state, t)
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Output for Decoder.
# If input is present: Return real at t+1.
# If input is not present: Return fake for t+1.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
output = tf.where(targets_present[:, t], real, fake)
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1))
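# Descriptive note (comment only): the stacked returns above have shapes
# sequence: [batch_size, sequence_length], logits: [batch_size,
# sequence_length, vocab_size] and log_probs: [batch_size, sequence_length],
# since each per-step logit is [batch_size, vocab_size] and each sampled
# token and its log-probability are [batch_size].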
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph."""
with tf.variable_scope('gen', reuse=reuse):
encoder_states, initial_state, final_state = gen_encoder(
hparams, inputs, targets_present, is_training=is_training, reuse=reuse)
stacked_sequence, stacked_logits, stacked_log_probs = gen_decoder(
hparams,
inputs,
targets,
targets_present,
encoder_states,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
return (stacked_sequence, stacked_logits, stacked_log_probs, initial_state,
final_state)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import tensorflow as tf
from models import attention_utils
# NAS code.
from nas_utils import configs
from nas_utils import custom_cell
from nas_utils import variational_dropout
FLAGS = tf.app.flags.FLAGS
def get_config():
return configs.AlienConfig2()
LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])
def transform_input_with_is_missing_token(inputs, targets_present):
"""Transforms the inputs to have missing tokens when it's masked out. The
mask is for the targets, so therefore, to determine if an input at time t is
masked, we have to check if the target at time t - 1 is masked out.
e.g.
inputs = [a, b, c, d]
targets = [b, c, d, e]
targets_present = [1, 0, 1, 0]
then,
transformed_input = [a, b, <missing>, d]
Args:
inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
up to, but not including, vocab_size.
targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
True representing the presence of the word.
Returns:
transformed_input: tf.int32 Tensor of shape [batch_size, sequence_length]
which takes on value of inputs when the input is present and takes on
value=vocab_size to indicate a missing token.
"""
# To fill in if the input is missing.
input_missing = tf.constant(
FLAGS.vocab_size,
dtype=tf.int32,
shape=[FLAGS.batch_size, FLAGS.sequence_length])
# The 0th input will always be present to MaskGAN.
zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])
# Input present mask.
inputs_present = tf.concat(
[zeroth_input_present, targets_present[:, :-1]], axis=1)
transformed_input = tf.where(inputs_present, inputs, input_missing)
return transformed_input
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
"""Define the Encoder graph.
Args:
hparams: Hyperparameters for the MaskGAN.
inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
up to, but not including, vocab_size.
targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
True representing the presence of the target.
is_training: Boolean indicating operational mode (train/inference).
reuse (Optional): Whether to reuse the variables.
Returns:
Tuple of (hidden_states, final_state).
"""
config = get_config()
configs.print_config(config)
# We will use the same variable from the decoder.
if FLAGS.seq2seq_share_embedding:
with tf.variable_scope('decoder/rnn'):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('encoder', reuse=reuse):
# Neural architecture search cell.
cell = custom_cell.Alien(config.hidden_size)
if is_training:
[h2h_masks, h2i_masks, _,
output_mask] = variational_dropout.generate_variational_dropout_masks(
hparams, config.keep_prob)
else:
h2i_masks, output_mask = None, None
cell = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)
# Add a missing token for inputs not present.
real_inputs = inputs
masked_inputs = transform_input_with_is_missing_token(
inputs, targets_present)
with tf.variable_scope('rnn'):
hidden_states = []
# Split the embedding into two parts so that we can load the PTB
# weights into one part of the Variable.
if not FLAGS.seq2seq_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
missing_embedding = tf.get_variable('missing_embedding',
[1, hparams.gen_rnn_size])
embedding = tf.concat([embedding, missing_embedding], axis=0)
real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)
if is_training and FLAGS.keep_prob < 1:
masked_rnn_inputs = tf.nn.dropout(masked_rnn_inputs, FLAGS.keep_prob)
state = initial_state
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_inp = masked_rnn_inputs[:, t]
if is_training:
state = list(state)
for layer_num, per_layer_state in enumerate(state):
per_layer_state = LSTMTuple(
per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
state[layer_num] = per_layer_state
rnn_out, state = cell(rnn_inp, state, h2i_masks)
if is_training:
rnn_out = output_mask * rnn_out
hidden_states.append(rnn_out)
final_masked_state = state
hidden_states = tf.stack(hidden_states, axis=1)
# Produce the RNN state had the model operated only
# over real data.
real_state = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
# RNN.
rnn_inp = real_rnn_inputs[:, t]
rnn_out, real_state = cell(rnn_inp, real_state)
final_state = real_state
return (hidden_states, final_masked_state), initial_state, final_state
def gen_decoder(hparams,
inputs,
targets,
targets_present,
encoding_state,
is_training,
is_validating,
reuse=None):
"""Define the Decoder graph. The Decoder will now impute tokens that
have been masked from the input seqeunce.
"""
config = get_config()
gen_decoder_rnn_size = hparams.gen_rnn_size
if FLAGS.seq2seq_share_embedding:
with tf.variable_scope('decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, gen_decoder_rnn_size])
with tf.variable_scope('decoder', reuse=reuse):
# Neural architecture search cell.
cell = custom_cell.Alien(config.hidden_size)
if is_training:
[h2h_masks, _, _,
output_mask] = variational_dropout.generate_variational_dropout_masks(
hparams, config.keep_prob)
else:
output_mask = None
cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
# Hidden encoder states.
hidden_vector_encodings = encoding_state[0]
# Carry forward the final state tuple from the encoder.
# State tuples.
state_gen = encoding_state[1]
if FLAGS.attention_option is not None:
(attention_keys, attention_values, _,
attention_construct_fn) = attention_utils.prepare_attention(
hidden_vector_encodings,
FLAGS.attention_option,
num_units=gen_decoder_rnn_size,
reuse=reuse)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
if not FLAGS.seq2seq_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, gen_decoder_rnn_size])
softmax_w = tf.matrix_transpose(embedding)
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the Decoder.
if t == 0:
# Always provide the real input at t = 0.
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or (is_training and
FLAGS.gen_training_strategy == 'cross_entropy'):
rnn_inp = real_rnn_inp
else:
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
if is_training:
state_gen = list(state_gen)
for layer_num, per_layer_state in enumerate(state_gen):
per_layer_state = LSTMTuple(
per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
state_gen[layer_num] = per_layer_state
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
if is_training:
rnn_out = output_mask * rnn_out
if FLAGS.attention_option is not None:
rnn_out = attention_construct_fn(rnn_out, attention_keys,
attention_values)
# # TODO(liamfedus): Assert not "monotonic" attention_type.
# # TODO(liamfedus): FLAGS.attention_type.
# context_state = revised_attention_utils._empty_state()
# rnn_out, context_state = attention_construct_fn(
# rnn_out, attention_keys, attention_values, context_state, t)
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Output for Decoder.
# If input is present: Return real at t+1.
# If input is not present: Return fake for t+1.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
output = tf.where(targets_present[:, t], real, fake)
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1))
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph."""
with tf.variable_scope('gen', reuse=reuse):
encoder_states, initial_state, final_state = gen_encoder(
hparams, inputs, targets_present, is_training=is_training, reuse=reuse)
stacked_sequence, stacked_logits, stacked_log_probs = gen_decoder(
hparams,
inputs,
targets,
targets_present,
encoder_states,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
return (stacked_sequence, stacked_logits, stacked_log_probs, initial_state,
final_state)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from models import attention_utils
from regularization import variational_dropout
FLAGS = tf.app.flags.FLAGS
def transform_input_with_is_missing_token(inputs, targets_present):
"""Transforms the inputs to have missing tokens when it's masked out. The
mask is for the targets, so therefore, to determine if an input at time t is
masked, we have to check if the target at time t - 1 is masked out.
e.g.
inputs = [a, b, c, d]
targets = [b, c, d, e]
targets_present = [1, 0, 1, 0]
which computes,
inputs_present = [1, 1, 0, 1]
and outputs,
transformed_input = [a, b, <missing>, d]
Args:
inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
up to, but not including, vocab_size.
targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
True representing the presence of the word.
Returns:
transformed_input: tf.int32 Tensor of shape [batch_size, sequence_length]
which takes on value of inputs when the input is present and takes on
value=vocab_size to indicate a missing token.
"""
# To fill in if the input is missing.
input_missing = tf.constant(
FLAGS.vocab_size,
dtype=tf.int32,
shape=[FLAGS.batch_size, FLAGS.sequence_length])
# The 0th input will always be present to MaskGAN.
zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])
# Input present mask.
inputs_present = tf.concat(
[zeroth_input_present, targets_present[:, :-1]], axis=1)
transformed_input = tf.where(inputs_present, inputs, input_missing)
return transformed_input
# TODO(adai): IMDB labels placeholder to encoder.
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
"""Define the Encoder graph.
Args:
hparams: Hyperparameters for the MaskGAN.
inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
up to, but not including, vocab_size.
targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
True representing the presence of the target.
is_training: Boolean indicating operational mode (train/inference).
reuse (Optional): Whether to reuse the variables.
Returns:
Tuple of (hidden_states, final_state).
"""
# We will use the same variable from the decoder.
if FLAGS.seq2seq_share_embedding:
with tf.variable_scope('decoder/rnn'):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('encoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.gen_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.gen_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.gen_rnn_size,
hparams.gen_vd_keep_prob, hparams.gen_vd_keep_prob)
cell = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)
# Add a missing token for inputs not present.
real_inputs = inputs
masked_inputs = transform_input_with_is_missing_token(
inputs, targets_present)
with tf.variable_scope('rnn') as scope:
hidden_states = []
# Split the embedding into two parts so that we can load the PTB
# weights into one part of the Variable.
if not FLAGS.seq2seq_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
missing_embedding = tf.get_variable('missing_embedding',
[1, hparams.gen_rnn_size])
embedding = tf.concat([embedding, missing_embedding], axis=0)
# TODO(adai): Perhaps append IMDB labels placeholder to input at
# each time point.
real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)
state = initial_state
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(
tf.stack([FLAGS.batch_size, 1, units]))
return tf.floor(random_tensor) / keep_prob
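    # Descriptive note (comment only): make_mask builds an inverted-dropout
    # mask whose entries are 0 with probability (1 - keep_prob) and
    # 1 / keep_prob otherwise, so each entry has expected value 1. The extra
    # middle dimension of 1 lets the same mask broadcast across every time
    # step of the dynamic_rnn outputs, which is what makes this dropout
    # "variational".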
if is_training:
output_mask = make_mask(hparams.gen_vd_keep_prob, hparams.gen_rnn_size)
hidden_states, state = tf.nn.dynamic_rnn(
cell, masked_rnn_inputs, initial_state=state, scope=scope)
if is_training:
hidden_states *= output_mask
final_masked_state = state
# Produce the RNN state had the model operated only
# over real data.
real_state = initial_state
_, real_state = tf.nn.dynamic_rnn(
cell, real_rnn_inputs, initial_state=real_state, scope=scope)
final_state = real_state
return (hidden_states, final_masked_state), initial_state, final_state
# TODO(adai): IMDB labels placeholder to encoder.
def gen_encoder_cnn(hparams, inputs, targets_present, is_training, reuse=None):
"""Define the CNN Encoder graph."""
del reuse
sequence = transform_input_with_is_missing_token(inputs, targets_present)
# TODO(liamfedus): Make this a hyperparameter.
dis_filter_sizes = [3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
# Keeping track of l2 regularization loss (optional)
# l2_loss = tf.constant(0.0)
with tf.variable_scope('encoder', reuse=True):
with tf.variable_scope('rnn'):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
cnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
# Create a convolution layer for each filter size
conv_outputs = []
for filter_size in dis_filter_sizes:
with tf.variable_scope('conv-%s' % filter_size):
# Convolution Layer
filter_shape = [
filter_size, hparams.gen_rnn_size, hparams.dis_num_filters
]
W = tf.get_variable(
name='W', initializer=tf.truncated_normal(filter_shape, stddev=0.1))
b = tf.get_variable(
name='b',
initializer=tf.constant(0.1, shape=[hparams.dis_num_filters]))
conv = tf.nn.conv1d(cnn_inputs, W, stride=1, padding='SAME', name='conv')
# Apply nonlinearity
h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')
conv_outputs.append(h)
# Combine all the pooled features
dis_num_filters_total = hparams.dis_num_filters * len(dis_filter_sizes)
h_conv = tf.concat(conv_outputs, axis=2)
h_conv_flat = tf.reshape(h_conv, [-1, dis_num_filters_total])
# Add dropout
if is_training:
with tf.variable_scope('dropout'):
h_conv_flat = tf.nn.dropout(h_conv_flat, hparams.gen_vd_keep_prob)
# Final (unnormalized) scores and predictions
with tf.variable_scope('output'):
W = tf.get_variable(
'W',
shape=[dis_num_filters_total, hparams.gen_rnn_size],
initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(
name='b', initializer=tf.constant(0.1, shape=[hparams.gen_rnn_size]))
# l2_loss += tf.nn.l2_loss(W)
# l2_loss += tf.nn.l2_loss(b)
predictions = tf.nn.xw_plus_b(h_conv_flat, W, b, name='predictions')
predictions = tf.reshape(
predictions,
shape=[FLAGS.batch_size, FLAGS.sequence_length, hparams.gen_rnn_size])
final_state = tf.reduce_mean(predictions, 1)
return predictions, (final_state, final_state)
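# Descriptive shape walkthrough for gen_encoder_cnn (comment only):
# cnn_inputs is [batch_size, sequence_length, gen_rnn_size]; each SAME-padded
# conv1d produces [batch_size, sequence_length, dis_num_filters]; the
# concatenation over the 10 filter sizes is [batch_size, sequence_length,
# 10 * dis_num_filters], which is projected back to gen_rnn_size per time
# step, and final_state is the mean of those projections over time.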
# TODO(adai): IMDB labels placeholder to decoder.
def gen_decoder(hparams,
inputs,
targets,
targets_present,
encoding_state,
is_training,
is_validating,
reuse=None):
"""Define the Decoder graph. The Decoder will now impute tokens that
have been masked from the input seqeunce.
"""
gen_decoder_rnn_size = hparams.gen_rnn_size
targets = tf.Print(targets, [targets], message='targets', summarize=50)
if FLAGS.seq2seq_share_embedding:
with tf.variable_scope('decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('decoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
gen_decoder_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.gen_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.gen_rnn_size,
hparams.gen_vd_keep_prob, hparams.gen_vd_keep_prob)
cell_gen = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
# Hidden encoder states.
hidden_vector_encodings = encoding_state[0]
# Carry forward the final state tuple from the encoder.
# State tuples.
state_gen = encoding_state[1]
if FLAGS.attention_option is not None:
(attention_keys, attention_values, _,
attention_construct_fn) = attention_utils.prepare_attention(
hidden_vector_encodings,
FLAGS.attention_option,
num_units=gen_decoder_rnn_size,
reuse=reuse)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.gen_vd_keep_prob, hparams.gen_rnn_size)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
if not FLAGS.seq2seq_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.matrix_transpose(embedding)
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
# TODO(adai): Perhaps append IMDB labels placeholder to input at
# each time point.
rnn_outs = []
fake = None
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the Decoder.
if t == 0:
# Always provide the real input at t = 0.
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or FLAGS.gen_training_strategy == 'cross_entropy':
rnn_inp = real_rnn_inp
else:
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
if FLAGS.attention_option is not None:
rnn_out = attention_construct_fn(rnn_out, attention_keys,
attention_values)
if is_training:
rnn_out *= output_mask
rnn_outs.append(rnn_out)
if FLAGS.gen_training_strategy != 'cross_entropy':
logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b)
# Output for Decoder.
# If input is present: Return real at t+1.
# If input is not present: Return fake for t+1.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
if FLAGS.use_gen_mode:
fake = categorical.mode()
else:
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
output = tf.where(targets_present[:, t], real, fake)
else:
real = targets[:, t]
logit = tf.zeros(tf.stack([FLAGS.batch_size, FLAGS.vocab_size]))
log_prob = tf.zeros(tf.stack([FLAGS.batch_size]))
output = real
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
if FLAGS.gen_training_strategy == 'cross_entropy':
logits = tf.nn.bias_add(
tf.matmul(
tf.reshape(tf.stack(rnn_outs, 1), [-1, gen_decoder_rnn_size]),
softmax_w), softmax_b)
logits = tf.reshape(logits,
[-1, FLAGS.sequence_length, FLAGS.vocab_size])
else:
logits = tf.stack(logits, axis=1)
return (tf.stack(sequence, axis=1), logits, tf.stack(log_probs, axis=1))
def dis_encoder(hparams, masked_inputs, is_training, reuse=None,
embedding=None):
"""Define the Discriminator encoder. Reads in the masked inputs for context
and produces the hidden states of the encoder."""
with tf.variable_scope('encoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn'):
hidden_states = []
missing_embedding = tf.get_variable('missing_embedding',
[1, hparams.dis_rnn_size])
embedding = tf.concat([embedding, missing_embedding], axis=0)
masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = masked_rnn_inputs[:, t]
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
if is_training:
rnn_out *= output_mask
hidden_states.append(rnn_out)
final_state = state_dis
return (tf.stack(hidden_states, axis=1), final_state)
def dis_decoder(hparams,
sequence,
encoding_state,
is_training,
reuse=None,
embedding=None):
"""Define the Discriminator decoder. Read in the sequence and predict
at each time point."""
sequence = tf.cast(sequence, tf.int32)
with tf.variable_scope('decoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
# Hidden encoder states.
hidden_vector_encodings = encoding_state[0]
# Carry forward the final state tuple from the encoder.
# State tuples.
state = encoding_state[1]
if FLAGS.attention_option is not None:
(attention_keys, attention_values, _,
attention_construct_fn) = attention_utils.prepare_attention(
hidden_vector_encodings,
FLAGS.attention_option,
num_units=hparams.dis_rnn_size,
reuse=reuse)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)
with tf.variable_scope('rnn') as vs:
predictions = []
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
rnn_out, state = cell_dis(rnn_in, state)
if FLAGS.attention_option is not None:
rnn_out = attention_construct_fn(rnn_out, attention_keys,
attention_values)
if is_training:
rnn_out *= output_mask
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
def discriminator(hparams,
inputs,
targets_present,
sequence,
is_training,
reuse=None):
"""Define the Discriminator graph."""
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
else:
    # Create a Discriminator-only embedding, shared by its encoder and decoder.
with tf.variable_scope('dis/decoder/rnn', reuse=reuse):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
# Mask the input sequence.
masked_inputs = transform_input_with_is_missing_token(inputs, targets_present)
# Confirm masking.
masked_inputs = tf.Print(
masked_inputs, [inputs, targets_present, masked_inputs, sequence],
message='inputs, targets_present, masked_inputs, sequence',
summarize=10)
with tf.variable_scope('dis', reuse=reuse):
encoder_states = dis_encoder(
hparams,
masked_inputs,
is_training=is_training,
reuse=reuse,
embedding=embedding)
predictions = dis_decoder(
hparams,
sequence,
encoder_states,
is_training=is_training,
reuse=reuse,
embedding=embedding)
# if FLAGS.baseline_method == 'critic':
# with tf.variable_scope('critic', reuse=reuse) as critic_scope:
# values = tf.contrib.layers.linear(rnn_outs, 1, scope=critic_scope)
# values = tf.squeeze(values, axis=2)
# else:
# values = None
return predictions
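# Descriptive note (comment only): as constructed by dis_decoder above, the
# returned predictions are unnormalized per-token scores of shape
# [batch_size, sequence_length].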
# TODO(adai): IMDB labels placeholder to encoder/decoder.
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph."""
with tf.variable_scope('gen', reuse=reuse):
encoder_states, initial_state, final_state = gen_encoder(
hparams, inputs, targets_present, is_training=is_training, reuse=reuse)
stacked_sequence, stacked_logits, stacked_log_probs = gen_decoder(
hparams,
inputs,
targets,
targets_present,
encoder_states,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
return (stacked_sequence, stacked_logits, stacked_log_probs, initial_state,
final_state, encoder_states)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from models import attention_utils
FLAGS = tf.app.flags.FLAGS
def transform_input_with_is_missing_token(inputs, targets_present):
"""Transforms the inputs to have missing tokens when it's masked out. The
mask is for the targets, so therefore, to determine if an input at time t is
masked, we have to check if the target at time t - 1 is masked out.
e.g.
inputs = [a, b, c, d]
targets = [b, c, d, e]
targets_present = [1, 0, 1, 0]
then,
transformed_input = [a, b, <missing>, d]
Args:
inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
up to, but not including, vocab_size.
targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
True representing the presence of the word.
Returns:
transformed_input: tf.int32 Tensor of shape [batch_size, sequence_length]
which takes on value of inputs when the input is present and takes on
value=vocab_size to indicate a missing token.
"""
# To fill in if the input is missing.
input_missing = tf.constant(FLAGS.vocab_size,
dtype=tf.int32,
shape=[FLAGS.batch_size, FLAGS.sequence_length])
# The 0th input will always be present to MaskGAN.
zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])
# Input present mask.
inputs_present = tf.concat(
[zeroth_input_present, targets_present[:, :-1]], axis=1)
transformed_input = tf.where(inputs_present, inputs, input_missing)
return transformed_input
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
"""Define the Encoder graph.
Args:
hparams: Hyperparameters for the MaskGAN.
inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
up to, but not including, vocab_size.
targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
True representing the presence of the target.
is_training: Boolean indicating operational mode (train/inference).
reuse (Optional): Whether to reuse the variables.
Returns:
Tuple of (hidden_states, final_state).
"""
with tf.variable_scope('encoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(hparams.gen_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)
# Add a missing token for inputs not present.
real_inputs = inputs
masked_inputs = transform_input_with_is_missing_token(inputs,
targets_present)
with tf.variable_scope('rnn'):
hidden_states = []
# Split the embedding into two parts so that we can load the PTB
# weights into one part of the Variable.
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
missing_embedding = tf.get_variable('missing_embedding',
[1, hparams.gen_rnn_size])
embedding = tf.concat([embedding, missing_embedding], axis=0)
real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)
if is_training and FLAGS.keep_prob < 1:
masked_rnn_inputs = tf.nn.dropout(masked_rnn_inputs, FLAGS.keep_prob)
state = initial_state
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_inp = masked_rnn_inputs[:, t]
rnn_out, state = cell(rnn_inp, state)
hidden_states.append(rnn_out)
final_masked_state = state
hidden_states = tf.stack(hidden_states, axis=1)
# Produce the RNN state had the model operated only
# over real data.
real_state = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
# RNN.
rnn_inp = real_rnn_inputs[:, t]
rnn_out, real_state = cell(rnn_inp, real_state)
final_state = real_state
return (hidden_states, final_masked_state), initial_state, final_state
def gen_decoder(hparams,
inputs,
targets,
targets_present,
encoding_state,
is_training,
is_validating,
reuse=None):
"""Define the Decoder graph. The Decoder will now impute tokens that
have been masked from the input seqeunce.
"""
gen_decoder_rnn_size = hparams.gen_rnn_size
with tf.variable_scope('decoder', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(gen_decoder_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell_gen = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
# Hidden encoder states.
hidden_vector_encodings = encoding_state[0]
# Carry forward the final state tuple from the encoder.
# State tuples.
state_gen = encoding_state[1]
if FLAGS.attention_option is not None:
(attention_keys, attention_values, _,
attention_construct_fn) = attention_utils.prepare_attention(
hidden_vector_encodings,
FLAGS.attention_option,
num_units=gen_decoder_rnn_size,
reuse=reuse)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.matrix_transpose(embedding)
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
rnn_outs = []
fake = None
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the Decoder.
if t == 0:
# Always provide the real input at t = 0.
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or FLAGS.gen_training_strategy == 'cross_entropy':
rnn_inp = real_rnn_inp
else:
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
if FLAGS.attention_option is not None:
rnn_out = attention_construct_fn(rnn_out, attention_keys,
attention_values)
rnn_outs.append(rnn_out)
if FLAGS.gen_training_strategy != 'cross_entropy':
logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b)
# Output for Decoder.
# If input is present: Return real at t+1.
# If input is not present: Return fake for t+1.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
output = tf.where(targets_present[:, t], real, fake)
else:
batch_size = tf.shape(rnn_out)[0]
logit = tf.zeros(tf.stack([batch_size, FLAGS.vocab_size]))
log_prob = tf.zeros(tf.stack([batch_size]))
output = targets[:, t]
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
if FLAGS.gen_training_strategy == 'cross_entropy':
logits = tf.nn.bias_add(
tf.matmul(
tf.reshape(tf.stack(rnn_outs, 1), [-1, gen_decoder_rnn_size]),
softmax_w), softmax_b)
logits = tf.reshape(logits,
[-1, FLAGS.sequence_length, FLAGS.vocab_size])
else:
logits = tf.stack(logits, axis=1)
return (tf.stack(sequence, axis=1), logits, tf.stack(log_probs, axis=1))
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph."""
with tf.variable_scope('gen', reuse=reuse):
encoder_states, initial_state, final_state = gen_encoder(
hparams, inputs, targets_present, is_training=is_training, reuse=reuse)
stacked_sequence, stacked_logits, stacked_log_probs = gen_decoder(
hparams,
inputs,
targets,
targets_present,
encoder_states,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
return (stacked_sequence, stacked_logits, stacked_log_probs, initial_state,
final_state)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def print_config(config):
print("-" * 10, "Configuration Specs", "-" * 10)
for item in dir(config):
if list(item)[0] != "_":
print(item, getattr(config, item))
print("-" * 29)
class AlienConfig2(object):
"""Base 8 740 shared embeddings, gets 64.0 (mean: std: min: max: )."""
init_scale = 0.05
learning_rate = 1.0
max_grad_norm = 10
num_layers = 2
num_steps = 25
hidden_size = 740
max_epoch = 70
max_max_epoch = 250
keep_prob = [1 - 0.15, 1 - 0.45]
lr_decay = 0.95
batch_size = 20
vocab_size = 10000
weight_decay = 1e-4
share_embeddings = True
cell = "alien"
dropout_type = "variational"
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import tensorflow as tf
flags = tf.flags
FLAGS = tf.app.flags.FLAGS
LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])
def cell_depth(num):
num /= 2
val = np.log2(1 + num)
assert abs(val - int(val)) == 0
return int(val)
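# Worked example (comment only): the Alien cell below supplies 32
# architecture parameters; the last two are the cell-injection parameters,
# so cell_depth(30) computes 30 / 2 = 15 pairwise units and
# log2(1 + 15) = 4 layers, giving units_per_layer == [8, 4, 2, 1].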
class GenericMultiRNNCell(tf.contrib.rnn.RNNCell):
"""More generic version of MultiRNNCell that allows you to pass in a dropout mask"""
def __init__(self, cells):
"""Create a RNN cell composed sequentially of a number of RNNCells.
Args:
cells: list of RNNCells that will be composed in this order.
state_is_tuple: If True, accepted and returned states are n-tuples, where
`n = len(cells)`. If False, the states are all
concatenated along the column axis. This latter behavior will soon be
deprecated.
Raises:
ValueError: if cells is empty (not allowed), or at least one of the cells
returns a state tuple but the flag `state_is_tuple` is `False`.
"""
self._cells = cells
@property
def state_size(self):
return tuple(cell.state_size for cell in self._cells)
@property
def output_size(self):
return self._cells[-1].output_size
def __call__(self, inputs, state, input_masks=None, scope=None):
"""Run this multi-layer cell on inputs, starting from state."""
with tf.variable_scope(scope or type(self).__name__):
cur_inp = inputs
new_states = []
for i, cell in enumerate(self._cells):
with tf.variable_scope('Cell%d' % i):
cur_state = state[i]
if input_masks is not None:
cur_inp *= input_masks[i]
cur_inp, new_state = cell(cur_inp, cur_state)
new_states.append(new_state)
new_states = tuple(new_states)
return cur_inp, new_states
class AlienRNNBuilder(tf.contrib.rnn.RNNCell):
def __init__(self, num_units, params, additional_params, base_size):
self.num_units = num_units
self.cell_create_index = additional_params[0]
self.cell_inject_index = additional_params[1]
self.base_size = base_size
    # Cell injection parameters are always the last two.
    self.cell_params = params[-2:]
    params = params[:-2]
self.depth = cell_depth(len(params))
self.params = params
    # Start with the biggest layer.
    self.units_per_layer = [2**i for i in range(self.depth)][::-1]
def __call__(self, inputs, state, scope=None):
with tf.variable_scope(scope or type(self).__name__):
definition1 = ['add', 'elem_mult', 'max']
definition2 = [tf.identity, tf.tanh, tf.sigmoid, tf.nn.relu, tf.sin]
layer_outputs = [[] for _ in range(self.depth)]
with tf.variable_scope('rnn_builder'):
curr_index = 0
c, h = state
# Run all dense matrix multiplications at once
big_h_mat = tf.get_variable(
'big_h_mat', [self.num_units,
self.base_size * self.num_units], tf.float32)
big_inputs_mat = tf.get_variable(
'big_inputs_mat', [self.num_units,
self.base_size * self.num_units], tf.float32)
big_h_output = tf.matmul(h, big_h_mat)
big_inputs_output = tf.matmul(inputs, big_inputs_mat)
h_splits = tf.split(big_h_output, self.base_size, axis=1)
inputs_splits = tf.split(big_inputs_output, self.base_size, axis=1)
for layer_num, units in enumerate(self.units_per_layer):
for unit_num in range(units):
with tf.variable_scope(
'layer_{}_unit_{}'.format(layer_num, unit_num)):
if layer_num == 0:
prev1_mat = h_splits[unit_num]
prev2_mat = inputs_splits[unit_num]
else:
prev1_mat = layer_outputs[layer_num - 1][2 * unit_num]
prev2_mat = layer_outputs[layer_num - 1][2 * unit_num + 1]
if definition1[self.params[curr_index]] == 'add':
output = prev1_mat + prev2_mat
elif definition1[self.params[curr_index]] == 'elem_mult':
output = prev1_mat * prev2_mat
elif definition1[self.params[curr_index]] == 'max':
output = tf.maximum(prev1_mat, prev2_mat)
if curr_index / 2 == self.cell_create_index: # Take the new cell before the activation
new_c = tf.identity(output)
output = definition2[self.params[curr_index + 1]](output)
if curr_index / 2 == self.cell_inject_index:
if definition1[self.cell_params[0]] == 'add':
output += c
elif definition1[self.cell_params[0]] == 'elem_mult':
output *= c
elif definition1[self.cell_params[0]] == 'max':
output = tf.maximum(output, c)
output = definition2[self.cell_params[1]](output)
layer_outputs[layer_num].append(output)
curr_index += 2
new_h = layer_outputs[-1][-1]
return new_h, LSTMTuple(new_c, new_h)
@property
def state_size(self):
return LSTMTuple(self.num_units, self.num_units)
@property
def output_size(self):
return self.num_units
class Alien(AlienRNNBuilder):
"""Base 8 Cell."""
def __init__(self, num_units):
params = [
0, 2, 0, 3, 0, 2, 1, 3, 0, 1, 0, 2, 0, 1, 0, 2, 1, 1, 0, 1, 1, 1, 0, 2,
1, 0, 0, 1, 1, 1, 0, 1
]
additional_params = [12, 8]
base_size = 8
super(Alien, self).__init__(num_units, params, additional_params, base_size)
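# Illustrative usage sketch (comments only), mirroring gen_encoder above;
# `hidden_size`, `num_layers`, `inputs` and `h2i_masks` are assumed to be
# provided by the caller:
#
#   cell = GenericMultiRNNCell([Alien(hidden_size)] * num_layers)
#   state = cell.zero_state(batch_size, tf.float32)
#   for t in xrange(sequence_length):
#     # input_masks optionally applies a per-layer dropout mask to each
#     # layer's input.
#     output, state = cell(inputs[:, t], state, input_masks=h2i_masks)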
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variational Dropout."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def generate_dropout_masks(keep_prob, shape, amount):
masks = []
for _ in range(amount):
    dropout_mask = tf.random_uniform(shape) + keep_prob
    dropout_mask = tf.floor(dropout_mask) / keep_prob
masks.append(dropout_mask)
return masks
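# Descriptive note (comment only): each mask entry is 0 with probability
# (1 - keep_prob) and 1 / keep_prob otherwise (e.g. 0 or 1.25 for
# keep_prob = 0.8), so multiplying activations by the mask leaves their
# expected value unchanged.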
def generate_variational_dropout_masks(hparams, keep_prob):
[batch_size, num_steps, size, num_layers] = [
FLAGS.batch_size, FLAGS.sequence_length, hparams.gen_rnn_size,
hparams.gen_num_layers
]
if len(keep_prob) == 2:
emb_keep_prob = keep_prob[0] # keep prob for embedding matrix
h2h_keep_prob = emb_keep_prob # keep prob for hidden to hidden connections
h2i_keep_prob = keep_prob[1] # keep prob for hidden to input connections
out_keep_prob = h2i_keep_prob # keep probability for output state
else:
emb_keep_prob = keep_prob[0] # keep prob for embedding matrix
h2h_keep_prob = keep_prob[1] # keep prob for hidden to hidden connections
h2i_keep_prob = keep_prob[2] # keep prob for hidden to input connections
out_keep_prob = keep_prob[3] # keep probability for output state
h2i_masks = [] # Masks for input to recurrent connections
h2h_masks = [] # Masks for recurrent to recurrent connections
# Input word dropout mask
emb_masks = generate_dropout_masks(emb_keep_prob, [num_steps, 1], batch_size)
output_mask = generate_dropout_masks(out_keep_prob, [batch_size, size], 1)[0]
h2i_masks = generate_dropout_masks(h2i_keep_prob, [batch_size, size],
num_layers)
h2h_masks = generate_dropout_masks(h2h_keep_prob, [batch_size, size],
num_layers)
return h2h_masks, h2i_masks, emb_masks, output_mask
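# Descriptive note (comment only): h2h_masks and h2i_masks are lists of
# num_layers masks of shape [batch_size, size], output_mask is a single
# [batch_size, size] mask, and emb_masks is a list of batch_size masks of
# shape [num_steps, 1].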
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pretraining functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from data import imdb_loader
from data import ptb_loader
# Data.
from model_utils import model_utils
from models import evaluation_utils
tf.app.flags.DEFINE_integer(
'gen_pretrain_steps', None,
'The number of steps to pretrain the generator with cross entropy loss.')
tf.app.flags.DEFINE_integer(
'dis_pretrain_steps', None,
'The number of steps to pretrain the discriminator.')
FLAGS = tf.app.flags.FLAGS
def pretrain_generator(sv, sess, model, data, log, id_to_word,
data_ngram_counts, is_chief):
"""Pretrain the generator with classic language modeling training."""
print('\nPretraining generator for %d steps.' % FLAGS.gen_pretrain_steps)
log.write(
'\nPretraining generator for %d steps.\n' % FLAGS.gen_pretrain_steps)
is_pretraining = True
while is_pretraining:
costs = 0.
iters = 0
if FLAGS.data_set == 'ptb':
iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
FLAGS.sequence_length,
FLAGS.epoch_size_override)
elif FLAGS.data_set == 'imdb':
iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
FLAGS.sequence_length)
for x, y, _ in iterator:
# For pretraining with cross entropy loss, we have all tokens in the
# forward sequence present (all True).
model_utils.assign_percent_real(sess, model.percent_real_update,
model.new_rate, 1.0)
p = np.ones(shape=[FLAGS.batch_size, FLAGS.sequence_length], dtype=bool)
pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}
[losses, cost_eval, _, step] = sess.run(
[
model.fake_cross_entropy_losses, model.avg_log_perplexity,
model.gen_pretrain_op, model.global_step
],
feed_dict=pretrain_feed)
costs += cost_eval
iters += FLAGS.sequence_length
      # Calculate rolling perplexity.
perplexity = np.exp(costs / iters)
# Summaries.
if is_chief and step % FLAGS.summaries_every == 0:
# Graph summaries.
summary_str = sess.run(
model.merge_summaries_op, feed_dict=pretrain_feed)
sv.SummaryComputed(sess, summary_str)
# Additional summary.
        for n, data_ngram_count in data_ngram_counts.items():
avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
int(n))
summary_percent_str = tf.Summary(value=[
tf.Summary.Value(
tag='general/%s-grams_percent_correct' % n,
simple_value=avg_percent_captured)
])
sv.SummaryComputed(sess, summary_percent_str, global_step=step)
summary_perplexity_str = tf.Summary(value=[
tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
])
sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
# Printing and logging
if is_chief and step % FLAGS.print_every == 0:
print('global_step: %d' % step)
print(' generator loss: %.3f' % np.mean(losses))
print(' perplexity: %.3f' % perplexity)
log.write('global_step: %d\n' % step)
log.write(' generator loss: %.3f\n' % np.mean(losses))
log.write(' perplexity: %.3f\n' % perplexity)
        for n, data_ngram_count in data_ngram_counts.items():
avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
int(n))
print(' percent of %s-grams captured: %.3f.\n' %
(n, avg_percent_captured))
log.write(' percent of %s-grams captured: %.3f.\n\n' %
(n, avg_percent_captured))
evaluation_utils.generate_logs(sess, model, log, id_to_word,
pretrain_feed)
if step >= FLAGS.gen_pretrain_steps:
is_pretraining = False
break
return
def pretrain_discriminator(sv, sess, model, data, log, id_to_word,
data_ngram_counts, is_chief):
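  """Pretrain the discriminator on randomly masked real sequences."""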
print('\nPretraining discriminator for %d steps.' % FLAGS.dis_pretrain_steps)
log.write(
'\nPretraining discriminator for %d steps.\n' % FLAGS.dis_pretrain_steps)
is_pretraining = True
while is_pretraining:
cumulative_costs = 0.
iters = 0
if FLAGS.data_set == 'ptb':
iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
FLAGS.sequence_length,
FLAGS.epoch_size_override)
elif FLAGS.data_set == 'imdb':
iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
FLAGS.sequence_length)
for x, y, _ in iterator:
is_present_rate = FLAGS.is_present_rate
# is_present_rate = np.random.uniform(low=0.0, high=1.0)
model_utils.assign_percent_real(sess, model.percent_real_update,
model.new_rate, is_present_rate)
# Randomly mask out tokens.
p = model_utils.generate_mask()
pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}
[_, dis_loss_eval, gen_log_perplexity_eval, step] = sess.run(
[
model.dis_pretrain_op, model.dis_loss, model.avg_log_perplexity,
model.global_step
],
feed_dict=pretrain_feed)
cumulative_costs += gen_log_perplexity_eval
iters += 1
      # Calculate rolling perplexity.
perplexity = np.exp(cumulative_costs / iters)
# Summaries.
if is_chief and step % FLAGS.summaries_every == 0:
# Graph summaries.
summary_str = sess.run(
model.merge_summaries_op, feed_dict=pretrain_feed)
sv.SummaryComputed(sess, summary_str)
# Additional summary.
        for n, data_ngram_count in data_ngram_counts.items():
avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
int(n))
summary_percent_str = tf.Summary(value=[
tf.Summary.Value(
tag='general/%s-grams_percent_correct' % n,
simple_value=avg_percent_captured)
])
sv.SummaryComputed(sess, summary_percent_str, global_step=step)
summary_perplexity_str = tf.Summary(value=[
tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
])
sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
# Printing and logging
if is_chief and step % FLAGS.print_every == 0:
print('global_step: %d' % step)
print(' discriminator loss: %.3f' % dis_loss_eval)
print(' perplexity: %.3f' % perplexity)
log.write('global_step: %d\n' % step)
log.write(' discriminator loss: %.3f\n' % dis_loss_eval)
log.write(' perplexity: %.3f\n' % perplexity)
        for n, data_ngram_count in data_ngram_counts.items():
avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
int(n))
print(' percent of %s-grams captured: %.3f.\n' %
(n, avg_percent_captured))
log.write(' percent of %s-grams captured: %.3f.\n\n' %
(n, avg_percent_captured))
evaluation_utils.generate_logs(sess, model, log, id_to_word,
pretrain_feed)
if step >= FLAGS.dis_pretrain_steps + int(FLAGS.gen_pretrain_steps or 0):
is_pretraining = False
break
return
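# --- Illustrative sketch (not part of the original file) ---
# One way the two pretraining phases might be sequenced from a training loop,
# assuming `sv`, `sess`, `model`, `train_data`, `log`, `id_to_word`,
# `data_ngram_counts` and `is_chief` are constructed elsewhere.
#
#   if FLAGS.gen_pretrain_steps:
#     pretrain_generator(sv, sess, model, train_data, log, id_to_word,
#                        data_ngram_counts, is_chief)
#   if FLAGS.dis_pretrain_steps:
#     pretrain_discriminator(sv, sess, model, train_data, log, id_to_word,
#                            data_ngram_counts, is_chief)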