Commit 87fae3f7 authored by Andrew M. Dai

Added new MaskGAN model.

parent 813dd09a
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from regularization import variational_dropout
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams,
sequence,
is_training,
reuse=None,
initial_state=None):
"""Define the Discriminator graph."""
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_fwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
cell_bwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
if initial_state:
state_fwd = [[tf.identity(x) for x in inner_initial_state]
for inner_initial_state in initial_state]
state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
else:
state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob,
2 * hparams.dis_rnn_size)
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
rnn_inputs = tf.unstack(rnn_inputs, axis=1)
with tf.variable_scope('rnn') as vs:
outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)
if is_training:
outputs *= output_mask
# Prediction is linear output for Discriminator.
predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
predictions = tf.transpose(predictions, [1, 0, 2])
if FLAGS.baseline_method == 'critic':
with tf.variable_scope('critic', reuse=reuse) as critic_scope:
values = tf.contrib.layers.linear(outputs, 1, scope=critic_scope)
values = tf.transpose(values, [1, 0, 2])
return tf.squeeze(predictions, axis=2), tf.squeeze(values, axis=2)
else:
return tf.squeeze(predictions, axis=2), None
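# A minimal, self-contained sketch (NumPy only, hypothetical shapes) of the
# inverted-dropout mask built by make_mask() above: each entry is 0 with
# probability 1 - keep_prob and 1/keep_prob otherwise, so the expected value
# of a masked activation matches the unmasked one at training time.
import numpy as np

def make_mask_np(keep_prob, batch_size, units):
  random_tensor = keep_prob + np.random.uniform(size=(batch_size, units))
  return np.floor(random_tensor) / keep_prob

example_mask = make_mask_np(0.75, batch_size=4, units=8)
assert set(np.unique(example_mask)) <= {0.0, 1.0 / 0.75}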
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the bidirectional Discriminator graph."""
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell_fwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
cell_bwd = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
rnn_inputs = tf.unstack(rnn_inputs, axis=1)
with tf.variable_scope('rnn') as vs:
outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)
# Prediction is linear output for Discriminator.
predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
predictions = tf.transpose(predictions, [1, 0, 2])
return tf.squeeze(predictions, axis=2)
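# Shape bookkeeping sketch (NumPy, hypothetical sizes): static_bidirectional_rnn
# returns a length-T list of [batch, 2 * rnn_size] outputs (forward and backward
# outputs concatenated); the linear layer maps each step to a single logit,
# giving [T, batch, 1], which the code above transposes and squeezes into
# [batch, T] per-token predictions.
import numpy as np

T, batch, rnn_size = 5, 3, 4
outputs = np.zeros((T, batch, 2 * rnn_size))   # stacked per-step RNN outputs
w = np.zeros((2 * rnn_size, 1))                # stand-in for the linear layer
logits = outputs @ w                           # [T, batch, 1]
logits = np.transpose(logits, (1, 0, 2))
assert np.squeeze(logits, axis=2).shape == (batch, T)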
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple CNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
del is_training
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
"If you wish to share Discriminator/Generator embeddings, they must be"
" same dimension.")
with tf.variable_scope("gen/rnn", reuse=True):
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.gen_rnn_size])
dis_filter_sizes = [3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
with tf.variable_scope("dis", reuse=reuse):
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.dis_rnn_size])
cnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
# Create a convolution layer for each filter size
conv_outputs = []
for filter_size in dis_filter_sizes:
with tf.variable_scope("conv-%s" % filter_size):
# Convolution Layer
filter_shape = [
filter_size, hparams.dis_rnn_size, hparams.dis_num_filters
]
W = tf.get_variable(
name="W", initializer=tf.truncated_normal(filter_shape, stddev=0.1))
b = tf.get_variable(
name="b",
initializer=tf.constant(0.1, shape=[hparams.dis_num_filters]))
conv = tf.nn.conv1d(
cnn_inputs, W, stride=1, padding="SAME", name="conv")
# Apply nonlinearity
h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
conv_outputs.append(h)
# Combine all the pooled features
dis_num_filters_total = hparams.dis_num_filters * len(dis_filter_sizes)
h_conv = tf.concat(conv_outputs, axis=2)
h_conv_flat = tf.reshape(h_conv, [-1, dis_num_filters_total])
# Add dropout
with tf.variable_scope("dropout"):
h_drop = tf.nn.dropout(h_conv_flat, FLAGS.keep_prob)
with tf.variable_scope("fully_connected"):
fc = tf.contrib.layers.fully_connected(
h_drop, num_outputs=dis_num_filters_total // 2)
# Final (unnormalized) scores and predictions
with tf.variable_scope("output"):
W = tf.get_variable(
"W",
shape=[dis_num_filters_total // 2, 1],
initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(name="b", initializer=tf.constant(0.1, shape=[1]))
predictions = tf.nn.xw_plus_b(fc, W, b, name="predictions")
predictions = tf.reshape(
predictions, shape=[FLAGS.batch_size, FLAGS.sequence_length])
return predictions
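# Shape sketch (NumPy, hypothetical sizes): each conv1d branch above uses SAME
# padding, so the sequence length is preserved; concatenating the branches on
# the channel axis gives [batch, seq_len, num_filters_total] features, which
# flatten to one row per token before the per-token output score.
import numpy as np

batch, seq_len, dis_num_filters = 2, 7, 3
filter_sizes = [3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
branches = [np.zeros((batch, seq_len, dis_num_filters)) for _ in filter_sizes]
h_conv = np.concatenate(branches, axis=2)
num_filters_total = dis_num_filters * len(filter_sizes)
h_conv_flat = h_conv.reshape(-1, num_filters_total)
assert h_conv_flat.shape == (batch * seq_len, num_filters_total)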
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critic model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange
import tensorflow as tf
from regularization import variational_dropout
FLAGS = tf.app.flags.FLAGS
def critic_seq2seq_vd_derivative(hparams, sequence, is_training, reuse=None):
"""Define the Critic graph which is derived from the seq2seq_vd
Discriminator. This will be initialized with the same parameters as the
language model and will share the forward RNN components with the
Discriminator. This estimates the V(s_t), where the state
s_t = x_0,...,x_t-1.
"""
assert FLAGS.discriminator_model == 'seq2seq_vd'
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
else:
with tf.variable_scope('dis/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
with tf.variable_scope(
'dis/decoder/rnn/multi_rnn_cell', reuse=True) as dis_scope:
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=True)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_critic = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
with tf.variable_scope('critic', reuse=reuse):
state_dis = cell_critic.zero_state(FLAGS.batch_size, tf.float32)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)
with tf.variable_scope('rnn') as vs:
values = []
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
if t == 0:
rnn_in = tf.zeros_like(rnn_inputs[:, 0])
else:
rnn_in = rnn_inputs[:, t - 1]
rnn_out, state_dis = cell_critic(rnn_in, state_dis, scope=dis_scope)
if is_training:
rnn_out *= output_mask
# Prediction is linear output for Discriminator.
value = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
values.append(value)
values = tf.stack(values, axis=1)
return tf.squeeze(values, axis=2)
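# Input-shifting sketch (pure Python, hypothetical tokens): the critic above
# estimates V(s_t) with s_t = x_0, ..., x_{t-1}, so at step t it reads token
# t - 1 (and an all-zeros embedding at t = 0) rather than token t itself.
tokens = ['x0', 'x1', 'x2', 'x3']
critic_inputs = ['<zeros>'] + tokens[:-1]
assert critic_inputs == ['<zeros>', 'x0', 'x1', 'x2']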
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import Counter
# Dependency imports
import numpy as np
from scipy.special import expit
import tensorflow as tf
from model_utils import helper
from model_utils import n_gram
FLAGS = tf.app.flags.FLAGS
def print_and_log_losses(log, step, is_present_rate, avg_dis_loss,
avg_gen_loss):
"""Prints and logs losses to the log file.
Args:
log: GFile for logs.
step: Global step.
is_present_rate: Current masking rate.
avg_dis_loss: List of Discriminator losses.
avg_gen_loss: List of Generator losses.
"""
print('global_step: %d' % step)
print(' is_present_rate: %.3f' % is_present_rate)
print(' D train loss: %.5f' % np.mean(avg_dis_loss))
print(' G train loss: %.5f' % np.mean(avg_gen_loss))
log.write('\nglobal_step: %d\n' % step)
log.write((' is_present_rate: %.3f\n' % is_present_rate))
log.write(' D train loss: %.5f\n' % np.mean(avg_dis_loss))
log.write(' G train loss: %.5f\n' % np.mean(avg_gen_loss))
def print_and_log(log, id_to_word, sequence_eval, max_num_to_print=5):
"""Helper function for printing and logging evaluated sequences."""
indices_arr = np.asarray(sequence_eval)
samples = helper.convert_to_human_readable(id_to_word, indices_arr,
max_num_to_print)
for i, sample in enumerate(samples):
print('Sample', i, '. ', sample)
log.write('\nSample ' + str(i) + '. ' + sample)
log.write('\n')
print('\n')
log.flush()
return samples
def zip_seq_pred_crossent(id_to_word, sequences, predictions, cross_entropy):
"""Zip together the sequences, predictions, cross entropy."""
indices = np.asarray(sequences)
batch_of_metrics = []
for ind_batch, pred_batch, crossent_batch in zip(indices, predictions,
cross_entropy):
metrics = []
for index, pred, crossent in zip(ind_batch, pred_batch, crossent_batch):
metrics.append([str(id_to_word[index]), pred, crossent])
batch_of_metrics.append(metrics)
return batch_of_metrics
def zip_metrics(indices, *args):
"""Zip together the indices matrices with the provided metrics matrices."""
batch_of_metrics = []
for metrics_batch in zip(indices, *args):
metrics = []
for m in zip(*metrics_batch):
metrics.append(m)
batch_of_metrics.append(metrics)
return batch_of_metrics
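# Toy usage of zip_metrics() above (hypothetical values): it pairs each token
# index with the per-position metrics passed alongside it.
example_batch = zip_metrics([[1, 2]], [[0.9, 0.1]], [[0.5, -0.5]])
assert example_batch[0] == [(1, 0.9, 0.5), (2, 0.1, -0.5)]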
def print_formatted(present, id_to_word, log, batch_of_tuples):
"""Print and log metrics."""
num_cols = len(batch_of_tuples[0][0])
repeat_float_format = '{:<12.3f} '
repeat_str_format = '{:<13}'
format_str = ''.join(
['[{:<1}] {:<20}',
str(repeat_float_format * (num_cols - 1))])
# TODO(liamfedus): Generalize the logging. This is sloppy.
header_format_str = ''.join(
['[{:<1}] {:<20}',
str(repeat_str_format * (num_cols - 1))])
header_str = header_format_str.format('p', 'Word', 'p(real)', 'log-perp',
'log(p(a))', 'r', 'R=V*(s)', 'b=V(s)',
'A(a,s)')
for i, batch in enumerate(batch_of_tuples):
print(' Sample: %d' % i)
log.write(' Sample %d.\n' % i)
print(' ', header_str)
log.write(' ' + str(header_str) + '\n')
for j, t in enumerate(batch):
t = list(t)
t[0] = id_to_word[t[0]]
buffer_str = format_str.format(int(present[i][j]), *t)
print(' ', buffer_str)
log.write(' ' + str(buffer_str) + '\n')
log.flush()
def generate_RL_logs(sess, model, log, id_to_word, feed):
"""Generate complete logs while running with REINFORCE."""
# Impute Sequences.
[
p,
fake_sequence_eval,
fake_predictions_eval,
_,
fake_cross_entropy_losses_eval,
_,
fake_log_probs_eval,
fake_rewards_eval,
fake_baselines_eval,
cumulative_rewards_eval,
fake_advantages_eval,
] = sess.run(
[
model.present,
model.fake_sequence,
model.fake_predictions,
model.real_predictions,
model.fake_cross_entropy_losses,
model.fake_logits,
model.fake_log_probs,
model.fake_rewards,
model.fake_baselines,
model.cumulative_rewards,
model.fake_advantages,
],
feed_dict=feed)
indices = np.asarray(fake_sequence_eval)
# Convert Discriminator linear layer to probability.
fake_prob_eval = expit(fake_predictions_eval)
# Add metrics.
fake_tuples = zip_metrics(indices, fake_prob_eval,
fake_cross_entropy_losses_eval, fake_log_probs_eval,
fake_rewards_eval, cumulative_rewards_eval,
fake_baselines_eval, fake_advantages_eval)
# real_tuples = zip_metrics(indices, )
# Print forward sequences.
tuples_to_print = fake_tuples[:FLAGS.max_num_to_print]
print_formatted(p, id_to_word, log, tuples_to_print)
print('Samples')
log.write('Samples\n')
samples = print_and_log(log, id_to_word, fake_sequence_eval,
FLAGS.max_num_to_print)
return samples
def generate_logs(sess, model, log, id_to_word, feed):
"""Impute Sequences using the model for a particular feed and send it to
logs."""
# Impute Sequences.
[
p, sequence_eval, fake_predictions_eval, fake_cross_entropy_losses_eval,
fake_logits_eval
] = sess.run(
[
model.present, model.fake_sequence, model.fake_predictions,
model.fake_cross_entropy_losses, model.fake_logits
],
feed_dict=feed)
# Convert Discriminator linear layer to probability.
fake_prob_eval = expit(fake_predictions_eval)
# Forward Masked Tuples.
fake_tuples = zip_seq_pred_crossent(id_to_word, sequence_eval, fake_prob_eval,
fake_cross_entropy_losses_eval)
tuples_to_print = fake_tuples[:FLAGS.max_num_to_print]
if FLAGS.print_verbose:
print('fake_logits_eval')
print(fake_logits_eval)
for i, batch in enumerate(tuples_to_print):
print(' Sample %d.' % i)
log.write(' Sample %d.\n' % i)
for j, pred in enumerate(batch):
buffer_str = ('[{:<1}] {:<20} {:<7.3f} {:<7.3f}').format(
int(p[i][j]), pred[0], pred[1], pred[2])
print(' ', buffer_str)
log.write(' ' + str(buffer_str) + '\n')
log.flush()
print('Samples')
log.write('Samples\n')
samples = print_and_log(log, id_to_word, sequence_eval,
FLAGS.max_num_to_print)
return samples
def create_merged_ngram_dictionaries(indices, n):
"""Generate a single dictionary for the full batch.
Args:
indices: List of lists of indices.
n: Degree of n-grams.
Returns:
Dictionary of hashed(n-gram tuples) to counts in the batch of indices.
"""
ngram_dicts = []
for ind in indices:
ngrams = n_gram.find_all_ngrams(ind, n=n)
ngram_counts = n_gram.construct_ngrams_dict(ngrams)
ngram_dicts.append(ngram_counts)
merged_gen_dict = Counter()
for ngram_dict in ngram_dicts:
merged_gen_dict += Counter(ngram_dict)
return merged_gen_dict
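# Tiny worked example (pure Python, hypothetical bigram counts) of the Counter
# merge performed by create_merged_ngram_dictionaries() above.
from collections import Counter

example_dicts = [{('a', 'b'): 2, ('b', 'c'): 1}, {('a', 'b'): 1}]
merged = Counter()
for d in example_dicts:
  merged += Counter(d)
assert merged[('a', 'b')] == 3 and merged[('b', 'c')] == 1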
def sequence_ngram_evaluation(sess, sequence, log, feed, data_ngram_count, n):
"""Calculates the percent of ngrams produced in the sequence is present in
data_ngram_count.
Args:
sess: tf.Session.
sequence: Sequence Tensor from the MaskGAN model.
log: gFile log.
feed: Feed to evaluate.
data_ngram_count: Dictionary of hashed(n-gram tuples) to counts in the
data_set.
Returns:
avg_percent_captured: Percent of produced ngrams that appear in the
data_ngram_count.
"""
del log
# Impute sequence.
[sequence_eval] = sess.run([sequence], feed_dict=feed)
indices = sequence_eval
# Retrieve the counts across the batch of indices.
gen_ngram_counts = create_merged_ngram_dictionaries(indices, n=n)
return n_gram.percent_unique_ngrams_in_train(data_ngram_count,
gen_ngram_counts)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple FNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
del is_training
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
"If you wish to share Discriminator/Generator embeddings, they must be"
" same dimension.")
with tf.variable_scope("gen/rnn", reuse=True):
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope("dis", reuse=reuse):
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable("embedding",
[FLAGS.vocab_size, hparams.dis_rnn_size])
embeddings = tf.nn.embedding_lookup(embedding, sequence)
# Input matrices.
W = tf.get_variable(
"W",
initializer=tf.truncated_normal(
shape=[3 * hparams.dis_embedding_dim, hparams.dis_hidden_dim],
stddev=0.1))
b = tf.get_variable(
"b", initializer=tf.constant(0.1, shape=[hparams.dis_hidden_dim]))
# Output matrices.
W_out = tf.get_variable(
"W_out",
initializer=tf.truncated_normal(
shape=[hparams.dis_hidden_dim, 1], stddev=0.1))
b_out = tf.get_variable("b_out", initializer=tf.constant(0.1, shape=[1]))
predictions = []
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
inp = embeddings[:, t]
if t > 0:
past_inp = tf.unstack(embeddings[:, 0:t], axis=1)
avg_past_inp = tf.add_n(past_inp) / len(past_inp)
else:
avg_past_inp = tf.zeros_like(inp)
if t < FLAGS.sequence_length:
future_inp = tf.unstack(embeddings[:, t:], axis=1)
avg_future_inp = tf.add_n(future_inp) / len(future_inp)
else:
avg_future_inp = tf.zeros_like(inp)
# Cumulative input.
concat_inp = tf.concat([avg_past_inp, inp, avg_future_inp], axis=1)
# Hidden activations.
hidden = tf.nn.relu(tf.nn.xw_plus_b(concat_inp, W, b, name="scores"))
# Add dropout
with tf.variable_scope("dropout"):
hidden = tf.nn.dropout(hidden, FLAGS.keep_prob)
# Output.
output = tf.nn.xw_plus_b(hidden, W_out, b_out, name="output")
predictions.append(output)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
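# Context-averaging sketch (NumPy, hypothetical embeddings): at step t the FNN
# discriminator above concatenates the mean of the embeddings before t, the
# embedding at t, and the mean of the embeddings from t onward, giving a
# 3 * embedding_dim input per token.
import numpy as np

embeddings = np.arange(12, dtype=np.float32).reshape(1, 4, 3)  # [batch, T, dim]
t = 2
avg_past = embeddings[:, :t].mean(axis=1)
avg_future = embeddings[:, t:].mean(axis=1)
concat_inp = np.concatenate([avg_past, embeddings[:, t], avg_future], axis=1)
assert concat_inp.shape == (1, 9)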
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange
import tensorflow as tf
# ZoneoutWrapper.
from regularization import zoneout
FLAGS = tf.app.flags.FLAGS
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph.
G will now impute tokens that have been masked from the input sequence.
"""
tf.logging.warning(
'Unidirectional generative model is not a useful model for this MaskGAN '
'because future context is needed. Use only for debugging purposes.')
init_scale = 0.05
initializer = tf.random_uniform_initializer(-init_scale, init_scale)
with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
def lstm_cell():
return tf.contrib.rnn.LayerNormBasicLSTMCell(
hparams.gen_rnn_size, reuse=reuse)
attn_cell = lstm_cell
if FLAGS.zoneout_drop_prob > 0.0:
def attn_cell():
return zoneout.ZoneoutWrapper(
lstm_cell(),
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
cell_gen = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.get_variable('softmax_w',
[hparams.gen_rnn_size, FLAGS.vocab_size])
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the model is the first token to provide context. The
# model will then predict token t > 0.
if t == 0:
# Always provide the real input at t = 0.
state_gen = initial_state
rnn_inp = rnn_inputs[:, t]
# If the target at the last time-step was present, read in the real.
# If the target at the last time-step was not present, read in the fake.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# Use teacher forcing.
if (is_training and
FLAGS.gen_training_strategy == 'cross_entropy') or is_validating:
rnn_inp = real_rnn_inp
else:
# Note that targets[:, t - 1] == inputs[:, t].
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Real sample.
real = targets[:, t]
# Fake sample.
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
# Output for Generator will either be generated or the target.
# If present: Return real.
# If not present: Return fake.
output = tf.where(targets_present[:, t], real, fake)
# Append to lists.
sequence.append(output)
logits.append(logit)
log_probs.append(log_prob)
# Produce the RNN state had the model operated only
# over real data.
real_state_gen = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
rnn_inp = rnn_inputs[:, t]
# RNN.
rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)
final_state = real_state_gen
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1), initial_state, final_state)
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph.
Args:
hparams: Hyperparameters for the MaskGAN.
sequence: [FLAGS.batch_size, FLAGS.sequence_length] tensor of token indices.
is_training: Boolean, whether the graph is built for training.
reuse: Whether to reuse the variable scope.
Returns:
predictions: [FLAGS.batch_size, FLAGS.sequence_length] per-token linear outputs.
"""
tf.logging.warning(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.LayerNormBasicLSTMCell(
hparams.dis_rnn_size, reuse=reuse)
attn_cell = lstm_cell
if FLAGS.zoneout_drop_prob > 0.0:
def attn_cell():
return zoneout.ZoneoutWrapper(
lstm_cell(),
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn') as vs:
predictions = []
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
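# Masked-input sketch (NumPy, hypothetical values): outside of teacher forcing,
# the generator above feeds the real token embedding wherever the previous
# target was present and the previously sampled ("fake") embedding elsewhere.
import numpy as np

targets_present_prev = np.array([True, False, True])
real_rnn_inp = np.array([[1.0], [2.0], [3.0]])
fake_rnn_inp = np.array([[9.0], [9.0], [9.0]])
rnn_inp = np.where(targets_present_prev[:, None], real_rnn_inp, fake_rnn_inp)
assert rnn_inp.ravel().tolist() == [1.0, 9.0, 3.0]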
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from six.moves import xrange
import tensorflow as tf
# NAS code.
from nas_utils import configs
from nas_utils import custom_cell
from nas_utils import variational_dropout
FLAGS = tf.app.flags.FLAGS
def get_config():
return configs.AlienConfig2()
LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph.
G will now impute tokens that have been masked from the input sequence.
"""
tf.logging.info(
'Unidirectional generative model is not a useful model for this MaskGAN '
'because future context is needed. Use only for debugging purposes.')
config = get_config()
config.keep_prob = [hparams.gen_nas_keep_prob_0, hparams.gen_nas_keep_prob_1]
configs.print_config(config)
init_scale = config.init_scale
initializer = tf.random_uniform_initializer(-init_scale, init_scale)
with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
# Neural architecture search cell.
cell = custom_cell.Alien(config.hidden_size)
if is_training:
[h2h_masks, _, _,
output_mask] = variational_dropout.generate_variational_dropout_masks(
hparams, config.keep_prob)
else:
output_mask = None
cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.matrix_transpose(embedding)
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the model is the first token to provide context. The
# model will then predict token t > 0.
if t == 0:
# Always provide the real input at t = 0.
state_gen = initial_state
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or (is_training and
FLAGS.gen_training_strategy == 'cross_entropy'):
rnn_inp = real_rnn_inp
else:
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
if is_training:
state_gen = list(state_gen)
for layer_num, per_layer_state in enumerate(state_gen):
per_layer_state = LSTMTuple(
per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
state_gen[layer_num] = per_layer_state
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
if is_training:
rnn_out = output_mask * rnn_out
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Real sample.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
# Output for Generator will either be generated or the input.
#
# If present: Return real.
# If not present: Return fake.
output = tf.where(targets_present[:, t], real, fake)
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
# Produce the RNN state had the model operated only
# over real data.
real_state_gen = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
rnn_inp = rnn_inputs[:, t]
# RNN.
rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)
final_state = real_state_gen
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1), initial_state, final_state)
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
tf.logging.info(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
config = get_config()
config.keep_prob = [hparams.dis_nas_keep_prob_0, hparams.dis_nas_keep_prob_1]
configs.print_config(config)
with tf.variable_scope('dis', reuse=reuse):
# Neural architecture search cell.
cell = custom_cell.Alien(config.hidden_size)
if is_training:
[h2h_masks, _, _,
output_mask] = variational_dropout.generate_variational_dropout_masks(
hparams, config.keep_prob)
else:
output_mask = None
cell_dis = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn') as vs:
predictions = []
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
if is_training:
state_dis = list(state_dis)
for layer_num, per_layer_state in enumerate(state_dis):
per_layer_state = LSTMTuple(
per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
state_dis[layer_num] = per_layer_state
# RNN.
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
if is_training:
rnn_out = output_mask * rnn_out
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
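# Per-layer state-masking sketch (NumPy, hypothetical shapes): during training
# the NAS cells above multiply each layer's hidden state h by a dropout mask
# that is fixed across time steps (variational dropout on the h-to-h path).
import numpy as np

num_layers, batch, hidden = 2, 3, 4
h2h_masks = [np.ones((batch, hidden)) for _ in range(num_layers)]
state = [(np.zeros((batch, hidden)), np.ones((batch, hidden)))  # (c, h) pairs
         for _ in range(num_layers)]
state = [(c, h * h2h_masks[i]) for i, (c, h) in enumerate(state)]
assert state[1][1].shape == (batch, hidden)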
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange
import tensorflow as tf
from regularization import variational_dropout
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams,
sequence,
is_training,
reuse=None,
initial_state=None):
"""Define the Discriminator graph."""
tf.logging.info(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/decoder/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(
hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and hparams.dis_vd_keep_prob < 1:
def attn_cell():
return variational_dropout.VariationalDropoutWrapper(
lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
if initial_state:
state_dis = [[tf.identity(x) for x in inner_initial_state]
for inner_initial_state in initial_state]
else:
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
def make_mask(keep_prob, units):
random_tensor = keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
return tf.floor(random_tensor) / keep_prob
if is_training:
output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)
with tf.variable_scope('rnn') as vs:
predictions, rnn_outs = [], []
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
if is_training:
rnn_out *= output_mask
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
rnn_outs.append(rnn_out)
predictions = tf.stack(predictions, axis=1)
if FLAGS.baseline_method == 'critic':
with tf.variable_scope('critic', reuse=reuse) as critic_scope:
rnn_outs = tf.stack(rnn_outs, axis=1)
values = tf.contrib.layers.linear(rnn_outs, 1, scope=critic_scope)
return tf.squeeze(predictions, axis=2), tf.squeeze(values, axis=2)
else:
return tf.squeeze(predictions, axis=2), None
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def generator(hparams,
inputs,
targets,
targets_present,
is_training,
is_validating,
reuse=None):
"""Define the Generator graph.
G will now impute tokens that have been masked from the input sequence.
"""
tf.logging.warning(
'Unidirectional generative model is not a useful model for this MaskGAN '
'because future context is needed. Use only for debugging purposes.')
init_scale = 0.05
initializer = tf.random_uniform_initializer(-init_scale, init_scale)
with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(hparams.gen_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell_gen = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.gen_num_layers)],
state_is_tuple=True)
initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn'):
sequence, logits, log_probs = [], [], []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
softmax_w = tf.get_variable('softmax_w',
[hparams.gen_rnn_size, FLAGS.vocab_size])
softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
fake = None
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
# Input to the model is the first token to provide context. The
# model will then predict token t > 0.
if t == 0:
# Always provide the real input at t = 0.
state_gen = initial_state
rnn_inp = rnn_inputs[:, t]
# If the input is present, read in the input at t.
# If the input is not present, read in the previously generated.
else:
real_rnn_inp = rnn_inputs[:, t]
fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
# While validating, the decoder should be operating in teacher
# forcing regime. Also, if we're just training with cross_entropy
# use teacher forcing.
if is_validating or (is_training and
FLAGS.gen_training_strategy == 'cross_entropy'):
rnn_inp = real_rnn_inp
else:
rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
fake_rnn_inp)
# RNN.
rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
logit = tf.matmul(rnn_out, softmax_w) + softmax_b
# Real sample.
real = targets[:, t]
categorical = tf.contrib.distributions.Categorical(logits=logit)
fake = categorical.sample()
log_prob = categorical.log_prob(fake)
# Output for Generator will either be generated or the input.
#
# If present: Return real.
# If not present: Return fake.
output = tf.where(targets_present[:, t], real, fake)
# Add to lists.
sequence.append(output)
log_probs.append(log_prob)
logits.append(logit)
# Produce the RNN state had the model operated only
# over real data.
real_state_gen = initial_state
for t in xrange(FLAGS.sequence_length):
tf.get_variable_scope().reuse_variables()
rnn_inp = rnn_inputs[:, t]
# RNN.
rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)
final_state = real_state_gen
return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
log_probs, axis=1), initial_state, final_state)
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the Discriminator graph."""
tf.logging.warning(
'Unidirectional Discriminative model is not a useful model for this '
'MaskGAN because future context is needed. Use only for debugging '
'purposes.')
sequence = tf.cast(sequence, tf.int32)
with tf.variable_scope('dis', reuse=reuse):
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(hparams.dis_rnn_size,
forget_bias=0.0,
state_is_tuple=True,
reuse=reuse)
attn_cell = lstm_cell
if is_training and FLAGS.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=FLAGS.keep_prob)
cell_dis = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(hparams.dis_num_layers)],
state_is_tuple=True)
state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)
with tf.variable_scope('rnn') as vs:
predictions = []
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
if is_training and FLAGS.keep_prob < 1:
rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)
for t in xrange(FLAGS.sequence_length):
if t > 0:
tf.get_variable_scope().reuse_variables()
rnn_in = rnn_inputs[:, t]
rnn_out, state_dis = cell_dis(rnn_in, state_dis)
# Prediction is linear output for Discriminator.
pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
predictions.append(pred)
predictions = tf.stack(predictions, axis=1)
return tf.squeeze(predictions, axis=2)
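# Sampling sketch (NumPy, hypothetical logits): the generator above samples the
# "fake" token from a categorical distribution over the softmax logits and
# keeps its log-probability for REINFORCE-style updates.
import numpy as np

logits = np.array([[0.1, 2.0, -1.0]])
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
fake = np.array([np.random.choice(len(p), p=p) for p in probs])
log_prob = np.log(probs[np.arange(len(fake)), fake])
assert fake.shape == (1,) and log_prob.shape == (1,)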
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def print_config(config):
print("-" * 10, "Configuration Specs", "-" * 10)
for item in dir(config):
if not item.startswith("_"):
print(item, getattr(config, item))
print("-" * 29)
class AlienConfig2(object):
"""Base 8 740 shared embeddings, gets 64.0 (mean: std: min: max: )."""
init_scale = 0.05
learning_rate = 1.0
max_grad_norm = 10
num_layers = 2
num_steps = 25
hidden_size = 740
max_epoch = 70
max_max_epoch = 250
keep_prob = [1 - 0.15, 1 - 0.45]
lr_decay = 0.95
batch_size = 20
vocab_size = 10000
weight_decay = 1e-4
share_embeddings = True
cell = "alien"
dropout_type = "variational"