Unverified Commit 7d16fc45 authored by Andrew M Dai, committed by GitHub

Merge pull request #3486 from a-dai/master

Added new MaskGAN model.
parents c5c6eaf2 87fae3f7
@@ -18,6 +18,7 @@
/research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
/research/lfads/ @jazcollins @susillo
/research/lm_1b/ @oriolvinyals @panyx0718
/research/maskgan/ @a-dai
/research/namignizer/ @knathanieltucker
/research/neural_gpu/ @lukaszkaiser
/research/neural_programmer/ @arvind2505
@@ -38,6 +38,7 @@ installation](https://www.tensorflow.org/install).
- [lfads](lfads): sequential variational autoencoder for analyzing
neuroscience data.
- [lm_1b](lm_1b): language modeling on the one billion word benchmark.
- [maskgan](maskgan): text generation with GANs.
- [namignizer](namignizer): recognize and generate names.
- [neural_gpu](neural_gpu): highly parallel neural computer.
- [neural_programmer](neural_programmer): neural network augmented with logic
# MaskGAN: Better Text Generation via Filling in the ______
Code for [*MaskGAN: Better Text Generation via Filling in the
______*](https://arxiv.org/abs/1801.07736) published at ICLR 2018.
## Requirements
* TensorFlow >= v1.3
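A minimal sketch for confirming the installed TensorFlow version meets this requirement (not part of the MaskGAN code itself):
```python
import tensorflow as tf
print(tf.__version__)  # expect 1.3.0 or later
```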
## Instructions
Warning: The open-source version of this code is still in the process of being
tested. Pretraining may not work correctly.
For training on PTB:
1. (Optional) Pretrain a language model (LM) on PTB and store the checkpoint in /tmp/pretrain-lm/.
Detailed instructions are a work in progress.
2. (Optional) Run MaskGAN in MLE pretraining mode:
```bash
python train_mask_gan.py \
--data_dir='/tmp/ptb' \
--batch_size=20 \
--sequence_length=20 \
--base_directory='/tmp/maskGAN' \
--hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.00074876,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=1,gen_learning_rate_decay=0.95" \
--mode='TRAIN' \
--max_steps=100000 \
--language_model_ckpt_dir=/tmp/pretrain-lm/ \
--generator_model='seq2seq_vd' \
--discriminator_model='rnn_zaremba' \
--is_present_rate=0.5 \
--summaries_every=10 \
--print_every=250 \
--max_num_to_print=3 \
--gen_training_strategy=cross_entropy \
--seq2seq_share_embedding
```
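Before moving to GAN mode, you can optionally confirm that the MLE run wrote a checkpoint under `--base_directory`. A minimal sketch using the standard TensorFlow checkpoint API, assuming the paths from the command above:
```python
import tensorflow as tf
# Prints the newest checkpoint path under /tmp/maskGAN, or None if none exists.
print(tf.train.latest_checkpoint('/tmp/maskGAN'))
```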
3. Run MaskGAN in GAN mode:
```bash
python train_mask_gan.py \
--data_dir='/tmp/ptb' \
--batch_size=128 \
--sequence_length=20 \
--base_directory='/tmp/maskGAN' \
--mask_strategy=contiguous \
--maskgan_ckpt='/tmp/maskGAN' \
--hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.000038877,gen_learning_rate_decay=1.0,gen_full_learning_rate_steps=2000000,gen_vd_keep_prob=0.33971,rl_discount_rate=0.89072,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=2,dis_pretrain_learning_rate=0.005,critic_learning_rate=5.1761e-7,dis_vd_keep_prob=0.71940" \
--mode='TRAIN' \
--max_steps=100000 \
--generator_model='seq2seq_vd' \
--discriminator_model='seq2seq_vd' \
--is_present_rate=0.5 \
--summaries_every=250 \
--print_every=250 \
--max_num_to_print=3 \
--gen_training_strategy='reinforce' \
--seq2seq_share_embedding=true \
--baseline_method=critic \
--attention_option=luong
```
4. Generate samples:
```bash
python generate_samples.py \
--data_dir /tmp/ptb/ \
--data_set=ptb \
--batch_size=256 \
--sequence_length=20 \
--base_directory /tmp/imdbsample/ \
--hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,gen_vd_keep_prob=0.33971" \
--generator_model=seq2seq_vd \
--discriminator_model=seq2seq_vd \
--is_present_rate=0.0 \
--maskgan_ckpt=/tmp/maskGAN \
--seq2seq_share_embedding=True \
--dis_share_embedding=True \
--attention_option=luong \
--mask_strategy=contiguous \
--baseline_method=critic \
--number_epochs=4
```
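`generate_samples.py` writes one sample per line to `reviews.txt` under `--output_path` (default `/tmp`); when `--output_masked_logs` is set, imputed tokens are prefixed with `*`. A minimal sketch, assuming the default output path, for inspecting the first few samples:
```python
# Print the first five generated samples.
with open('/tmp/reviews.txt') as f:
  for i, line in enumerate(f):
    print('Sample %d: %s' % (i, line.strip()))
    if i >= 4:
      break
```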
## Contact for Issues
* Liam Fedus, @liamb315 <liam.fedus@gmail.com>
* Andrew M. Dai, @a-dai <adai@google.com>
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""IMDB data loader and helpers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import numpy as np
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_boolean('prefix_label', True,
                            'Whether to prefix the class label to each review.')
np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)
EOS_INDEX = 88892
def _read_words(filename, use_prefix=True):
all_words = []
sequence_example = tf.train.SequenceExample()
for r in tf.python_io.tf_record_iterator(filename):
sequence_example.ParseFromString(r)
if FLAGS.prefix_label and use_prefix:
label = sequence_example.context.feature['class'].int64_list.value[0]
review_words = [EOS_INDEX + 1 + label]
else:
review_words = []
review_words.extend([
f.int64_list.value[0]
for f in sequence_example.feature_lists.feature_list['token_id'].feature
])
all_words.append(review_words)
return all_words
def build_vocab(vocab_file):
word_to_id = {}
with tf.gfile.GFile(vocab_file, 'r') as f:
index = 0
for word in f:
word_to_id[word.strip()] = index
index += 1
word_to_id['<eos>'] = EOS_INDEX
return word_to_id
def imdb_raw_data(data_path=None):
"""Load IMDB raw data from data directory "data_path".
Reads IMDB tf record files containing integer ids,
and performs mini-batching of the inputs.
Args:
data_path: string path to the directory where simple-examples.tgz has
been extracted.
Returns:
tuple (train_data, valid_data)
where each of the data objects can be passed to IMDBIterator.
"""
train_path = os.path.join(data_path, 'train_lm.tfrecords')
valid_path = os.path.join(data_path, 'test_lm.tfrecords')
train_data = _read_words(train_path)
valid_data = _read_words(valid_path)
return train_data, valid_data
def imdb_iterator(raw_data, batch_size, num_steps, epoch_size_override=None):
"""Iterate on the raw IMDB data.
This generates batch_size pointers into the raw IMDB data, and allows
minibatch iteration along these pointers.
Args:
raw_data: one of the raw data outputs from imdb_raw_data.
batch_size: int, the batch size.
num_steps: int, the number of unrolls.
  Yields:
    Triples (x, y, w) of batched data: x and y are matrices of shape
    [batch_size, num_steps], with y the same data time-shifted to the right
    by one, and w a weight matrix with 1 indicating a word was present and 0
    indicating padding.
"""
del epoch_size_override
data_len = len(raw_data)
num_batches = data_len // batch_size - 1
for batch in range(num_batches):
x = np.zeros([batch_size, num_steps], dtype=np.int32)
y = np.zeros([batch_size, num_steps], dtype=np.int32)
w = np.zeros([batch_size, num_steps], dtype=np.float)
for i in range(batch_size):
data_index = batch * batch_size + i
example = raw_data[data_index]
if len(example) > num_steps:
final_x = example[:num_steps]
final_y = example[1:(num_steps + 1)]
w[i] = 1
else:
to_fill_in = num_steps - len(example)
final_x = example + [EOS_INDEX] * to_fill_in
final_y = final_x[1:] + [EOS_INDEX]
w[i] = [1] * len(example) + [0] * to_fill_in
x[i] = final_x
y[i] = final_y
yield (x, y, w)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PTB data loader and helpers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
# Dependency imports
import numpy as np
import tensorflow as tf
EOS_INDEX = 0
def _read_words(filename):
with tf.gfile.GFile(filename, "r") as f:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
def build_vocab(filename):
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))
print("<eos>:", word_to_id["<eos>"])
global EOS_INDEX
EOS_INDEX = word_to_id["<eos>"]
return word_to_id
def _file_to_word_ids(filename, word_to_id):
data = _read_words(filename)
return [word_to_id[word] for word in data if word in word_to_id]
def ptb_raw_data(data_path=None):
"""Load PTB raw data from data directory "data_path".
Reads PTB text files, converts strings to integer ids,
and performs mini-batching of the inputs.
The PTB dataset comes from Tomas Mikolov's webpage:
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
Args:
data_path: string path to the directory where simple-examples.tgz has
been extracted.
Returns:
tuple (train_data, valid_data, test_data, vocabulary)
where each of the data objects can be passed to PTBIterator.
"""
train_path = os.path.join(data_path, "ptb.train.txt")
valid_path = os.path.join(data_path, "ptb.valid.txt")
test_path = os.path.join(data_path, "ptb.test.txt")
word_to_id = build_vocab(train_path)
train_data = _file_to_word_ids(train_path, word_to_id)
valid_data = _file_to_word_ids(valid_path, word_to_id)
test_data = _file_to_word_ids(test_path, word_to_id)
vocabulary = len(word_to_id)
return train_data, valid_data, test_data, vocabulary
def ptb_iterator(raw_data, batch_size, num_steps, epoch_size_override=None):
"""Iterate on the raw PTB data.
This generates batch_size pointers into the raw PTB data, and allows
minibatch iteration along these pointers.
Args:
raw_data: one of the raw data outputs from ptb_raw_data.
batch_size: int, the batch size.
num_steps: int, the number of unrolls.
  Yields:
    Triples (x, y, w) of batched data, each a matrix of shape
    [batch_size, num_steps]. The second element is the same data time-shifted
    to the right by one, and the third is an all-ones weight matrix.
Raises:
ValueError: if batch_size or num_steps are too high.
"""
raw_data = np.array(raw_data, dtype=np.int32)
data_len = len(raw_data)
batch_len = data_len // batch_size
data = np.full([batch_size, batch_len], EOS_INDEX, dtype=np.int32)
for i in range(batch_size):
data[i] = raw_data[batch_len * i:batch_len * (i + 1)]
if epoch_size_override:
epoch_size = epoch_size_override
else:
epoch_size = (batch_len - 1) // num_steps
if epoch_size == 0:
raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
# print("Number of batches per epoch: %d" % epoch_size)
for i in range(epoch_size):
x = data[:, i * num_steps:(i + 1) * num_steps]
y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
w = np.ones_like(x)
yield (x, y, w)
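# Minimal usage sketch (assumes the PTB text files ptb.train.txt etc. are in
# /tmp/ptb): pull one batch from ptb_iterator above and print the shapes of the
# (x, y, w) triple it yields.
if __name__ == '__main__':
  train, _, _, vocab_size = ptb_raw_data('/tmp/ptb')
  print('Vocabulary size: %d' % vocab_size)
  for x, y, w in ptb_iterator(train, batch_size=20, num_steps=20):
    # Each array is (20, 20); y is the data time-shifted by one step.
    print(x.shape, y.shape, w.shape)
    break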
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate samples from the MaskGAN.
Launch command:
python generate_samples.py
--data_dir=/tmp/data/imdb --data_set=imdb
--batch_size=256 --sequence_length=20 --base_directory=/tmp/imdb
--hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,
gen_vd_keep_prob=1.0" --generator_model=seq2seq_vd
--discriminator_model=seq2seq_vd --is_present_rate=0.5
--maskgan_ckpt=/tmp/model.ckpt-45494
--seq2seq_share_embedding=True --dis_share_embedding=True
--attention_option=luong --mask_strategy=contiguous --baseline_method=critic
--number_epochs=4
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import os
# Dependency imports
import numpy as np
import tensorflow as tf
import train_mask_gan
from data import imdb_loader
from data import ptb_loader
# Data.
from model_utils import helper
from model_utils import model_utils
SAMPLE_TRAIN = 'TRAIN'
SAMPLE_VALIDATION = 'VALIDATION'
## Sample Generation.
## Binary and setup FLAGS.
tf.app.flags.DEFINE_enum('sample_mode', 'TRAIN',
[SAMPLE_TRAIN, SAMPLE_VALIDATION],
'Dataset to sample from.')
tf.app.flags.DEFINE_string('output_path', '/tmp', 'Model output directory.')
tf.app.flags.DEFINE_boolean(
'output_masked_logs', False,
'Whether to display for human evaluation (show masking).')
tf.app.flags.DEFINE_integer('number_epochs', 1,
'The number of epochs to produce.')
FLAGS = tf.app.flags.FLAGS
def get_iterator(data):
"""Return the data iterator."""
if FLAGS.data_set == 'ptb':
iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
FLAGS.sequence_length,
FLAGS.epoch_size_override)
elif FLAGS.data_set == 'imdb':
iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
FLAGS.sequence_length)
return iterator
def convert_to_human_readable(id_to_word, arr, p, max_num_to_print):
"""Convert a np.array of indices into words using id_to_word dictionary.
Return max_num_to_print results.
"""
assert arr.ndim == 2
samples = []
for sequence_id in xrange(min(len(arr), max_num_to_print)):
sample = []
for i, index in enumerate(arr[sequence_id, :]):
if p[sequence_id, i] == 1:
sample.append(str(id_to_word[index]))
else:
sample.append('*' + str(id_to_word[index]))
buffer_str = ' '.join(sample)
samples.append(buffer_str)
return samples
def write_unmasked_log(log, id_to_word, sequence_eval):
"""Helper function for logging evaluated sequences without mask."""
indices_arr = np.asarray(sequence_eval)
samples = helper.convert_to_human_readable(id_to_word, indices_arr,
FLAGS.batch_size)
for sample in samples:
log.write(sample + '\n')
log.flush()
return samples
def write_masked_log(log, id_to_word, sequence_eval, present_eval):
indices_arr = np.asarray(sequence_eval)
samples = convert_to_human_readable(id_to_word, indices_arr, present_eval,
FLAGS.batch_size)
for sample in samples:
log.write(sample + '\n')
log.flush()
return samples
def generate_logs(sess, model, log, id_to_word, feed):
"""Impute Sequences using the model for a particular feed and send it to
logs.
"""
# Impute Sequences.
[p, inputs_eval, sequence_eval] = sess.run(
[model.present, model.inputs, model.fake_sequence], feed_dict=feed)
# Add the 0th time-step for coherence.
first_token = np.expand_dims(inputs_eval[:, 0], axis=1)
sequence_eval = np.concatenate((first_token, sequence_eval), axis=1)
# 0th token always present.
p = np.concatenate((np.ones((FLAGS.batch_size, 1)), p), axis=1)
if FLAGS.output_masked_logs:
samples = write_masked_log(log, id_to_word, sequence_eval, p)
else:
samples = write_unmasked_log(log, id_to_word, sequence_eval)
return samples
def generate_samples(hparams, data, id_to_word, log_dir, output_file):
""""Generate samples.
Args:
hparams: Hyperparameters for the MaskGAN.
data: Data to evaluate.
id_to_word: Dictionary of indices to words.
log_dir: Log directory.
output_file: Output file for the samples.
"""
# Boolean indicating operational mode.
is_training = False
# Set a random seed to keep fixed mask.
np.random.seed(0)
with tf.Graph().as_default():
# Construct the model.
model = train_mask_gan.create_MaskGAN(hparams, is_training)
## Retrieve the initial savers.
init_savers = model_utils.retrieve_init_savers(hparams)
## Initial saver function to supervisor.
init_fn = partial(model_utils.init_fn, init_savers)
is_chief = FLAGS.task == 0
# Create the supervisor. It will take care of initialization, summaries,
# checkpoints, and recovery.
    sv = tf.train.Supervisor(
logdir=log_dir,
is_chief=is_chief,
saver=model.saver,
global_step=model.global_step,
recovery_wait_secs=30,
summary_op=None,
init_fn=init_fn)
# Get an initialized, and possibly recovered session. Launch the
# services: Checkpointing, Summaries, step counting.
#
# When multiple replicas of this program are running the services are
# only launched by the 'chief' replica.
with sv.managed_session(
FLAGS.master, start_standard_services=False) as sess:
# Generator statefulness over the epoch.
[gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
[model.eval_initial_state, model.fake_gen_initial_state])
for n in xrange(FLAGS.number_epochs):
print('Epoch number: %d' % n)
# print('Percent done: %.2f' % float(n) / float(FLAGS.number_epochs))
iterator = get_iterator(data)
for x, y, _ in iterator:
if FLAGS.eval_language_model:
is_present_rate = 0.
else:
is_present_rate = FLAGS.is_present_rate
tf.logging.info(
'Evaluating on is_present_rate=%.3f.' % is_present_rate)
model_utils.assign_percent_real(sess, model.percent_real_update,
model.new_rate, is_present_rate)
# Randomly mask out tokens.
p = model_utils.generate_mask()
eval_feed = {model.inputs: x, model.targets: y, model.present: p}
if FLAGS.data_set == 'ptb':
# Statefulness for *evaluation* Generator.
for i, (c, h) in enumerate(model.eval_initial_state):
eval_feed[c] = gen_initial_state_eval[i].c
eval_feed[h] = gen_initial_state_eval[i].h
# Statefulness for the Generator.
for i, (c, h) in enumerate(model.fake_gen_initial_state):
eval_feed[c] = fake_gen_initial_state_eval[i].c
eval_feed[h] = fake_gen_initial_state_eval[i].h
[gen_initial_state_eval, fake_gen_initial_state_eval, _] = sess.run(
[
model.eval_final_state, model.fake_gen_final_state,
model.global_step
],
feed_dict=eval_feed)
generate_logs(sess, model, output_file, id_to_word, eval_feed)
output_file.close()
print('Closing output_file.')
return
def main(_):
hparams = train_mask_gan.create_hparams()
log_dir = FLAGS.base_directory
tf.gfile.MakeDirs(FLAGS.output_path)
output_file = tf.gfile.GFile(
os.path.join(FLAGS.output_path, 'reviews.txt'), mode='w')
# Load data set.
if FLAGS.data_set == 'ptb':
raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
train_data, valid_data, _, _ = raw_data
elif FLAGS.data_set == 'imdb':
raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
train_data, valid_data = raw_data
else:
raise NotImplementedError
# Generating more data on train set.
if FLAGS.sample_mode == SAMPLE_TRAIN:
data_set = train_data
elif FLAGS.sample_mode == SAMPLE_VALIDATION:
data_set = valid_data
else:
raise NotImplementedError
  # Dictionary and reverse dictionary.
if FLAGS.data_set == 'ptb':
word_to_id = ptb_loader.build_vocab(
os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
elif FLAGS.data_set == 'imdb':
word_to_id = imdb_loader.build_vocab(
os.path.join(FLAGS.data_dir, 'vocab.txt'))
id_to_word = {v: k for k, v in word_to_id.iteritems()}
FLAGS.vocab_size = len(id_to_word)
print('Vocab size: %d' % FLAGS.vocab_size)
generate_samples(hparams, data_set, id_to_word, log_dir, output_file)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Losses for Generator and Discriminator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def discriminator_loss(predictions, labels, missing_tokens):
"""Discriminator loss based on predictions and labels.
Args:
predictions: Discriminator linear predictions Tensor of shape [batch_size,
sequence_length]
labels: Labels for predictions, Tensor of shape [batch_size,
sequence_length]
missing_tokens: Indicator for the missing tokens. Evaluate the loss only
on the tokens that were missing.
Returns:
loss: Scalar tf.float32 loss.
"""
loss = tf.losses.sigmoid_cross_entropy(labels,
predictions,
weights=missing_tokens)
loss = tf.Print(
loss, [loss, labels, missing_tokens],
message='loss, labels, missing_tokens',
summarize=25,
first_n=25)
return loss
def cross_entropy_loss_matrix(gen_labels, gen_logits):
"""Computes the cross entropy loss for G.
Args:
gen_labels: Labels for the correct token.
gen_logits: Generator logits.
Returns:
loss_matrix: Loss matrix of shape [batch_size, sequence_length].
"""
cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=gen_labels, logits=gen_logits)
return cross_entropy_loss
def GAN_loss_matrix(dis_predictions):
"""Computes the cross entropy loss for G.
Args:
dis_predictions: Discriminator predictions.
Returns:
loss_matrix: Loss matrix of shape [batch_size, sequence_length].
"""
eps = tf.constant(1e-7, tf.float32)
gan_loss_matrix = -tf.log(dis_predictions + eps)
return gan_loss_matrix
def generator_GAN_loss(predictions):
"""Generator GAN loss based on Discriminator predictions."""
return -tf.log(tf.reduce_mean(predictions))
def generator_blended_forward_loss(gen_logits, gen_labels, dis_predictions,
is_real_input):
"""Computes the masked-loss for G. This will be a blend of cross-entropy
loss where the true label is known and GAN loss where the true label has been
masked.
Args:
gen_logits: Generator logits.
gen_labels: Labels for the correct token.
dis_predictions: Discriminator predictions.
is_real_input: Tensor indicating whether the label is present.
Returns:
loss: Scalar tf.float32 total loss.
"""
cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=gen_labels, logits=gen_logits)
gan_loss = -tf.log(dis_predictions)
loss_matrix = tf.where(is_real_input, cross_entropy_loss, gan_loss)
return tf.reduce_mean(loss_matrix)
def wasserstein_generator_loss(gen_logits, gen_labels, dis_values,
is_real_input):
"""Computes the masked-loss for G. This will be a blend of cross-entropy
loss where the true label is known and GAN loss where the true label is
missing.
Args:
gen_logits: Generator logits.
gen_labels: Labels for the correct token.
dis_values: Discriminator values Tensor of shape [batch_size,
sequence_length].
is_real_input: Tensor indicating whether the label is present.
Returns:
loss: Scalar tf.float32 total loss.
"""
cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=gen_labels, logits=gen_logits)
# Maximize the dis_values (minimize the negative)
gan_loss = -dis_values
loss_matrix = tf.where(is_real_input, cross_entropy_loss, gan_loss)
loss = tf.reduce_mean(loss_matrix)
return loss
def wasserstein_discriminator_loss(real_values, fake_values):
"""Wasserstein discriminator loss.
Args:
real_values: Value given by the Wasserstein Discriminator to real data.
fake_values: Value given by the Wasserstein Discriminator to fake data.
Returns:
loss: Scalar tf.float32 loss.
"""
real_avg = tf.reduce_mean(real_values)
fake_avg = tf.reduce_mean(fake_values)
wasserstein_loss = real_avg - fake_avg
return wasserstein_loss
def wasserstein_discriminator_loss_intrabatch(values, is_real_input):
"""Wasserstein discriminator loss. This is an odd variant where the value
difference is between the real tokens and the fake tokens within a single
batch.
Args:
values: Value given by the Wasserstein Discriminator of shape [batch_size,
sequence_length] to an imputed batch (real and fake).
is_real_input: tf.bool Tensor of shape [batch_size, sequence_length]. If
true, it indicates that the label is known.
Returns:
wasserstein_loss: Scalar tf.float32 loss.
"""
zero_tensor = tf.constant(0., dtype=tf.float32, shape=[])
present = tf.cast(is_real_input, tf.float32)
missing = tf.cast(1 - present, tf.float32)
# Counts for real and fake tokens.
real_count = tf.reduce_sum(present)
fake_count = tf.reduce_sum(missing)
# Averages for real and fake token values.
  real = tf.multiply(values, present)
  fake = tf.multiply(values, missing)
real_avg = tf.reduce_sum(real) / real_count
fake_avg = tf.reduce_sum(fake) / fake_count
# If there are no real or fake entries in the batch, we assign an average
# value of zero.
real_avg = tf.where(tf.equal(real_count, 0), zero_tensor, real_avg)
fake_avg = tf.where(tf.equal(fake_count, 0), zero_tensor, fake_avg)
wasserstein_loss = real_avg - fake_avg
return wasserstein_loss
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Random helper functions for converting between indices and one-hot encodings
as well as printing/logging helpers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def variable_summaries(var, name):
"""Attach a lot of summaries to a Tensor."""
mean = tf.reduce_mean(var)
tf.summary.scalar('mean/' + name, mean)
  with tf.name_scope('stddev'):
    stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
  tf.summary.scalar('stddev/' + name, stddev)
tf.summary.scalar('max/' + name, tf.reduce_max(var))
tf.summary.scalar('min/' + name, tf.reduce_min(var))
tf.summary.histogram(name, var)
def zip_seq_pred_crossent(id_to_word, sequences, predictions, cross_entropy):
"""Zip together the sequences, predictions, cross entropy."""
indices = convert_to_indices(sequences)
batch_of_metrics = []
for ind_batch, pred_batch, crossent_batch in zip(indices, predictions,
cross_entropy):
metrics = []
for index, pred, crossent in zip(ind_batch, pred_batch, crossent_batch):
metrics.append([str(id_to_word[index]), pred, crossent])
batch_of_metrics.append(metrics)
return batch_of_metrics
def print_and_log(log, id_to_word, sequence_eval, max_num_to_print=5):
"""Helper function for printing and logging evaluated sequences."""
indices_eval = convert_to_indices(sequence_eval)
indices_arr = np.asarray(indices_eval)
samples = convert_to_human_readable(id_to_word, indices_arr, max_num_to_print)
for i, sample in enumerate(samples):
print('Sample', i, '. ', sample)
log.write('\nSample ' + str(i) + '. ' + sample)
log.write('\n')
print('\n')
log.flush()
def convert_to_human_readable(id_to_word, arr, max_num_to_print):
"""Convert a np.array of indices into words using id_to_word dictionary.
Return max_num_to_print results.
"""
assert arr.ndim == 2
samples = []
for sequence_id in xrange(min(len(arr), max_num_to_print)):
buffer_str = ' '.join(
[str(id_to_word[index]) for index in arr[sequence_id, :]])
samples.append(buffer_str)
return samples
def index_to_vocab_array(indices, vocab_size, sequence_length):
"""Convert the indices into an array with vocab_size one-hot encoding."""
# Extract properties of the indices.
num_batches = len(indices)
shape = list(indices.shape)
shape.append(vocab_size)
# Construct the vocab_size array.
new_arr = np.zeros(shape)
for n in xrange(num_batches):
indices_batch = indices[n]
new_arr_batch = new_arr[n]
# We map all indices greater than the vocabulary size to an unknown
# character.
indices_batch = np.where(indices_batch < vocab_size, indices_batch,
vocab_size - 1)
# Convert indices to vocab_size dimensions.
new_arr_batch[np.arange(sequence_length), indices_batch] = 1
return new_arr
def convert_to_indices(sequences):
"""Convert a list of size [batch_size, sequence_length, vocab_size] to
a list of size [batch_size, sequence_length] where the vocab element is
denoted by the index.
"""
batch_of_indices = []
for sequence in sequences:
indices = []
for embedding in sequence:
indices.append(np.argmax(embedding))
batch_of_indices.append(indices)
return batch_of_indices
def convert_and_zip(id_to_word, sequences, predictions):
"""Helper function for printing or logging. Retrieves list of sequences
and predictions and zips them together.
"""
indices = convert_to_indices(sequences)
batch_of_indices_predictions = []
for index_batch, pred_batch in zip(indices, predictions):
indices_predictions = []
for index, pred in zip(index_batch, pred_batch):
indices_predictions.append([str(id_to_word[index]), pred])
batch_of_indices_predictions.append(indices_predictions)
return batch_of_indices_predictions
def recursive_length(item):
"""Recursively determine the total number of elements in nested list."""
if type(item) == list:
return sum(recursive_length(subitem) for subitem in item)
else:
return 1.
def percent_correct(real_sequence, fake_sequences):
"""Determine the percent of tokens correctly generated within a batch."""
identical = 0.
for fake_sequence in fake_sequences:
for real, fake in zip(real_sequence, fake_sequence):
if real == fake:
identical += 1.
return identical / recursive_length(fake_sequences)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model construction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from models import bidirectional
from models import bidirectional_vd
from models import bidirectional_zaremba
from models import cnn
from models import critic_vd
from models import feedforward
from models import rnn
from models import rnn_nas
from models import rnn_vd
from models import rnn_zaremba
from models import seq2seq
from models import seq2seq_nas
from models import seq2seq_vd
from models import seq2seq_zaremba
FLAGS = tf.app.flags.FLAGS
# TODO(adai): IMDB labels placeholder to model.
def create_generator(hparams,
inputs,
targets,
present,
is_training,
is_validating,
reuse=None):
"""Create the Generator model specified by the FLAGS and hparams.
  Args:
    hparams: Hyperparameters for the MaskGAN.
    inputs: tf.int32 Tensor of the sequence input of shape [batch_size,
      sequence_length].
    targets: tf.int32 Tensor of the sequence target of shape [batch_size,
      sequence_length].
    present: tf.bool Tensor indicating the presence or absence of the token
      of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    is_validating: Whether the model is being run in validation mode for
      calculating the perplexity.
    reuse (Optional): Whether to reuse the model.
Returns:
Tuple of the (sequence, logits, log_probs) of the Generator. Sequence
and logits have shape [batch_size, sequence_length, vocab_size]. The
log_probs will have shape [batch_size, sequence_length]. Log_probs
corresponds to the log probability of selecting the words.
"""
if FLAGS.generator_model == 'rnn':
(sequence, logits, log_probs, initial_state, final_state) = rnn.generator(
hparams,
inputs,
targets,
present,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
elif FLAGS.generator_model == 'rnn_zaremba':
(sequence, logits, log_probs, initial_state,
final_state) = rnn_zaremba.generator(
hparams,
inputs,
targets,
present,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
elif FLAGS.generator_model == 'seq2seq':
(sequence, logits, log_probs, initial_state,
final_state) = seq2seq.generator(
hparams,
inputs,
targets,
present,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
elif FLAGS.generator_model == 'seq2seq_zaremba':
(sequence, logits, log_probs, initial_state,
final_state) = seq2seq_zaremba.generator(
hparams,
inputs,
targets,
present,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
elif FLAGS.generator_model == 'rnn_nas':
(sequence, logits, log_probs, initial_state,
final_state) = rnn_nas.generator(
hparams,
inputs,
targets,
present,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
elif FLAGS.generator_model == 'seq2seq_nas':
(sequence, logits, log_probs, initial_state,
final_state) = seq2seq_nas.generator(
hparams,
inputs,
targets,
present,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
elif FLAGS.generator_model == 'seq2seq_vd':
(sequence, logits, log_probs, initial_state, final_state,
encoder_states) = seq2seq_vd.generator(
hparams,
inputs,
targets,
present,
is_training=is_training,
is_validating=is_validating,
reuse=reuse)
else:
raise NotImplementedError
return (sequence, logits, log_probs, initial_state, final_state,
encoder_states)
def create_discriminator(hparams,
sequence,
is_training,
reuse=None,
initial_state=None,
inputs=None,
present=None):
"""Create the Discriminator model specified by the FLAGS and hparams.
Args:
hparams: Hyperparameters for the MaskGAN.
sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length]
is_training: Whether the model is training.
reuse (Optional): Whether to reuse the model.
Returns:
predictions: tf.float32 Tensor of predictions of shape [batch_size,
sequence_length]
"""
if FLAGS.discriminator_model == 'cnn':
predictions = cnn.discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
elif FLAGS.discriminator_model == 'fnn':
predictions = feedforward.discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
elif FLAGS.discriminator_model == 'rnn':
predictions = rnn.discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
elif FLAGS.discriminator_model == 'bidirectional':
predictions = bidirectional.discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
elif FLAGS.discriminator_model == 'bidirectional_zaremba':
predictions = bidirectional_zaremba.discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
elif FLAGS.discriminator_model == 'seq2seq_vd':
predictions = seq2seq_vd.discriminator(
hparams,
inputs,
present,
sequence,
is_training=is_training,
reuse=reuse)
elif FLAGS.discriminator_model == 'rnn_zaremba':
predictions = rnn_zaremba.discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
elif FLAGS.discriminator_model == 'rnn_nas':
predictions = rnn_nas.discriminator(
hparams, sequence, is_training=is_training, reuse=reuse)
elif FLAGS.discriminator_model == 'rnn_vd':
predictions = rnn_vd.discriminator(
hparams,
sequence,
is_training=is_training,
reuse=reuse,
initial_state=initial_state)
elif FLAGS.discriminator_model == 'bidirectional_vd':
predictions = bidirectional_vd.discriminator(
hparams,
sequence,
is_training=is_training,
reuse=reuse,
initial_state=initial_state)
else:
raise NotImplementedError
return predictions
def create_critic(hparams, sequence, is_training, reuse=None):
"""Create the Critic model specified by the FLAGS and hparams.
Args:
hparams: Hyperparameters for the MaskGAN.
sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length]
is_training: Whether the model is training.
reuse (Optional): Whether to reuse the model.
Returns:
values: tf.float32 Tensor of predictions of shape [batch_size,
sequence_length]
"""
if FLAGS.baseline_method == 'critic':
if FLAGS.discriminator_model == 'seq2seq_vd':
values = critic_vd.critic_seq2seq_vd_derivative(
hparams, sequence, is_training, reuse=reuse)
else:
raise NotImplementedError
else:
raise NotImplementedError
return values
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model loss construction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
# Useful for REINFORCE baseline.
from losses import losses
FLAGS = tf.app.flags.FLAGS
def create_dis_loss(fake_predictions, real_predictions, targets_present):
"""Compute Discriminator loss across real/fake."""
missing = tf.cast(targets_present, tf.int32)
missing = 1 - missing
missing = tf.cast(missing, tf.bool)
real_labels = tf.ones([FLAGS.batch_size, FLAGS.sequence_length])
dis_loss_real = tf.losses.sigmoid_cross_entropy(
real_labels, real_predictions, weights=missing)
dis_loss_fake = tf.losses.sigmoid_cross_entropy(
targets_present, fake_predictions, weights=missing)
dis_loss = (dis_loss_fake + dis_loss_real) / 2.
return dis_loss, dis_loss_fake, dis_loss_real
def create_critic_loss(cumulative_rewards, estimated_values, present):
"""Compute Critic loss in estimating the value function. This should be an
estimate only for the missing elements."""
missing = tf.cast(present, tf.int32)
missing = 1 - missing
missing = tf.cast(missing, tf.bool)
loss = tf.losses.mean_squared_error(
labels=cumulative_rewards, predictions=estimated_values, weights=missing)
return loss
def create_masked_cross_entropy_loss(targets, present, logits):
"""Calculate the cross entropy loss matrices for the masked tokens."""
cross_entropy_losses = losses.cross_entropy_loss_matrix(targets, logits)
# Zeros matrix.
zeros_losses = tf.zeros(
shape=[FLAGS.batch_size, FLAGS.sequence_length], dtype=tf.float32)
missing_ce_loss = tf.where(present, zeros_losses, cross_entropy_losses)
return missing_ce_loss
def calculate_reinforce_objective(hparams,
log_probs,
dis_predictions,
present,
estimated_values=None):
"""Calculate the REINFORCE objectives. The REINFORCE objective should
only be on the tokens that were missing. Specifically, the final Generator
reward should be based on the Discriminator predictions on missing tokens.
The log probaibilities should be only for missing tokens and the baseline
should be calculated only on the missing tokens.
For this model, we optimize the reward is the log of the *conditional*
probability the Discriminator assigns to the distribution. Specifically, for
a Discriminator D which outputs probability of real, given the past context,
r_t = log D(x_t|x_0,x_1,...x_{t-1})
And the policy for Generator G is the log-probability of taking action x2
given the past context.
Args:
hparams: MaskGAN hyperparameters.
    log_probs: tf.float32 Tensor of log probabilities of the tokens selected by
the Generator. Shape [batch_size, sequence_length].
dis_predictions: tf.float32 Tensor of the predictions from the
Discriminator. Shape [batch_size, sequence_length].
present: tf.bool Tensor indicating which tokens are present. Shape
[batch_size, sequence_length].
estimated_values: tf.float32 Tensor of estimated state values of tokens.
Shape [batch_size, sequence_length]
Returns:
final_gen_objective: Final REINFORCE objective for the sequence.
rewards: tf.float32 Tensor of rewards for sequence of shape [batch_size,
sequence_length]
advantages: tf.float32 Tensor of advantages for sequence of shape
[batch_size, sequence_length]
baselines: tf.float32 Tensor of baselines for sequence of shape
[batch_size, sequence_length]
maintain_averages_op: ExponentialMovingAverage apply average op to
maintain the baseline.
"""
# Final Generator objective.
final_gen_objective = 0.
gamma = hparams.rl_discount_rate
eps = 1e-7
# Generator rewards are log-probabilities.
eps = tf.constant(1e-7, tf.float32)
dis_predictions = tf.nn.sigmoid(dis_predictions)
rewards = tf.log(dis_predictions + eps)
# Apply only for missing elements.
zeros = tf.zeros_like(present, dtype=tf.float32)
log_probs = tf.where(present, zeros, log_probs)
rewards = tf.where(present, zeros, rewards)
# Unstack Tensors into lists.
rewards_list = tf.unstack(rewards, axis=1)
log_probs_list = tf.unstack(log_probs, axis=1)
missing = 1. - tf.cast(present, tf.float32)
missing_list = tf.unstack(missing, axis=1)
# Cumulative Discounted Returns. The true value function V*(s).
cumulative_rewards = []
for t in xrange(FLAGS.sequence_length):
cum_value = tf.zeros(shape=[FLAGS.batch_size])
for s in xrange(t, FLAGS.sequence_length):
cum_value += missing_list[s] * np.power(gamma, (s - t)) * rewards_list[s]
cumulative_rewards.append(cum_value)
cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
## REINFORCE with different baselines.
# We create a separate critic functionality for the Discriminator. This
# will need to operate unidirectionally and it may take in the past context.
if FLAGS.baseline_method == 'critic':
# Critic loss calculated from the estimated value function \hat{V}(s)
# versus the true value function V*(s).
critic_loss = create_critic_loss(cumulative_rewards, estimated_values,
present)
# Baselines are coming from the critic's estimated state values.
baselines = tf.unstack(estimated_values, axis=1)
## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
advantages = []
for t in xrange(FLAGS.sequence_length):
log_probability = log_probs_list[t]
cum_advantage = tf.zeros(shape=[FLAGS.batch_size])
for s in xrange(t, FLAGS.sequence_length):
cum_advantage += missing_list[s] * np.power(gamma,
(s - t)) * rewards_list[s]
cum_advantage -= baselines[t]
# Clip advantages.
cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
FLAGS.advantage_clipping)
advantages.append(missing_list[t] * cum_advantage)
final_gen_objective += tf.multiply(
log_probability, missing_list[t] * tf.stop_gradient(cum_advantage))
maintain_averages_op = None
baselines = tf.stack(baselines, axis=1)
advantages = tf.stack(advantages, axis=1)
# Split the batch into half. Use half for MC estimates for REINFORCE.
# Use the other half to establish a baseline.
elif FLAGS.baseline_method == 'dis_batch':
# TODO(liamfedus): Recheck.
[rewards_half, baseline_half] = tf.split(
rewards, num_or_size_splits=2, axis=0)
[log_probs_half, _] = tf.split(log_probs, num_or_size_splits=2, axis=0)
[reward_present_half, baseline_present_half] = tf.split(
present, num_or_size_splits=2, axis=0)
# Unstack to lists.
baseline_list = tf.unstack(baseline_half, axis=1)
baseline_missing = 1. - tf.cast(baseline_present_half, tf.float32)
baseline_missing_list = tf.unstack(baseline_missing, axis=1)
baselines = []
for t in xrange(FLAGS.sequence_length):
# Calculate baseline only for missing tokens.
num_missing = tf.reduce_sum(baseline_missing_list[t])
avg_baseline = tf.reduce_sum(
baseline_missing_list[t] * baseline_list[t], keep_dims=True) / (
num_missing + eps)
baseline = tf.tile(avg_baseline, multiples=[FLAGS.batch_size / 2])
baselines.append(baseline)
# Unstack to lists.
rewards_list = tf.unstack(rewards_half, axis=1)
log_probs_list = tf.unstack(log_probs_half, axis=1)
reward_missing = 1. - tf.cast(reward_present_half, tf.float32)
reward_missing_list = tf.unstack(reward_missing, axis=1)
## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
advantages = []
for t in xrange(FLAGS.sequence_length):
log_probability = log_probs_list[t]
cum_advantage = tf.zeros(shape=[FLAGS.batch_size / 2])
for s in xrange(t, FLAGS.sequence_length):
cum_advantage += reward_missing_list[s] * np.power(gamma, (s - t)) * (
rewards_list[s] - baselines[s])
# Clip advantages.
cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
FLAGS.advantage_clipping)
advantages.append(reward_missing_list[t] * cum_advantage)
final_gen_objective += tf.multiply(
log_probability,
reward_missing_list[t] * tf.stop_gradient(cum_advantage))
# Cumulative Discounted Returns. The true value function V*(s).
cumulative_rewards = []
for t in xrange(FLAGS.sequence_length):
cum_value = tf.zeros(shape=[FLAGS.batch_size / 2])
for s in xrange(t, FLAGS.sequence_length):
cum_value += reward_missing_list[s] * np.power(gamma, (
s - t)) * rewards_list[s]
cumulative_rewards.append(cum_value)
cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
rewards = rewards_half
critic_loss = None
maintain_averages_op = None
baselines = tf.stack(baselines, axis=1)
advantages = tf.stack(advantages, axis=1)
# Exponential Moving Average baseline.
elif FLAGS.baseline_method == 'ema':
# TODO(liamfedus): Recheck.
# Lists of rewards and Log probabilities of the actions taken only for
# missing tokens.
ema = tf.train.ExponentialMovingAverage(decay=hparams.baseline_decay)
maintain_averages_op = ema.apply(rewards_list)
baselines = []
for r in rewards_list:
baselines.append(ema.average(r))
## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
advantages = []
for t in xrange(FLAGS.sequence_length):
log_probability = log_probs_list[t]
# Calculate the forward advantage only on the missing tokens.
cum_advantage = tf.zeros(shape=[FLAGS.batch_size])
for s in xrange(t, FLAGS.sequence_length):
cum_advantage += missing_list[s] * np.power(gamma, (s - t)) * (
rewards_list[s] - baselines[s])
# Clip advantages.
cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
FLAGS.advantage_clipping)
advantages.append(missing_list[t] * cum_advantage)
final_gen_objective += tf.multiply(
log_probability, missing_list[t] * tf.stop_gradient(cum_advantage))
critic_loss = None
baselines = tf.stack(baselines, axis=1)
advantages = tf.stack(advantages, axis=1)
elif FLAGS.baseline_method is None:
num_missing = tf.reduce_sum(missing)
final_gen_objective += tf.reduce_sum(rewards) / (num_missing + eps)
baselines = tf.zeros_like(rewards)
critic_loss = None
maintain_averages_op = None
advantages = cumulative_rewards
else:
raise NotImplementedError
return [
final_gen_objective, log_probs, rewards, advantages, baselines,
maintain_averages_op, critic_loss, cumulative_rewards
]
def calculate_log_perplexity(logits, targets, present):
"""Calculate the average log perplexity per *missing* token.
Args:
logits: tf.float32 Tensor of the logits of shape [batch_size,
sequence_length, vocab_size].
targets: tf.int32 Tensor of the sequence target of shape [batch_size,
sequence_length].
present: tf.bool Tensor indicating the presence or absence of the token
of shape [batch_size, sequence_length].
Returns:
avg_log_perplexity: Scalar indicating the average log perplexity per
missing token in the batch.
"""
# logits = tf.Print(logits, [logits], message='logits:', summarize=50)
# targets = tf.Print(targets, [targets], message='targets:', summarize=50)
eps = 1e-12
logits = tf.reshape(logits, [-1, FLAGS.vocab_size])
# Only calculate log-perplexity on missing tokens.
weights = tf.cast(present, tf.float32)
weights = 1. - weights
weights = tf.reshape(weights, [-1])
num_missing = tf.reduce_sum(weights)
log_perplexity = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
[logits], [tf.reshape(targets, [-1])], [weights])
avg_log_perplexity = tf.reduce_sum(log_perplexity) / (num_missing + eps)
return avg_log_perplexity
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model optimization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def create_dis_pretrain_op(hparams, dis_loss, global_step):
"""Create a train op for pretraining."""
  with tf.name_scope('pretrain_discriminator'):
optimizer = tf.train.AdamOptimizer(hparams.dis_pretrain_learning_rate)
dis_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('dis')
]
if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
shared_embedding = [
v for v in tf.trainable_variables()
if v.op.name == 'gen/decoder/rnn/embedding'
][0]
dis_vars.append(shared_embedding)
dis_grads = tf.gradients(dis_loss, dis_vars)
dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
FLAGS.grad_clipping)
dis_pretrain_op = optimizer.apply_gradients(
zip(dis_grads_clipped, dis_vars), global_step=global_step)
return dis_pretrain_op
def create_gen_pretrain_op(hparams, cross_entropy_loss, global_step):
"""Create a train op for pretraining."""
with tf.name_scope('pretrain_generator'):
optimizer = tf.train.AdamOptimizer(hparams.gen_pretrain_learning_rate)
gen_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('gen')
]
gen_grads = tf.gradients(cross_entropy_loss, gen_vars)
gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
FLAGS.grad_clipping)
gen_pretrain_op = optimizer.apply_gradients(
zip(gen_grads_clipped, gen_vars), global_step=global_step)
return gen_pretrain_op
def create_gen_train_op(hparams, learning_rate, gen_loss, global_step, mode):
"""Create Generator train op."""
del hparams
with tf.name_scope('train_generator'):
if FLAGS.generator_optimizer == 'sgd':
gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
elif FLAGS.generator_optimizer == 'adam':
gen_optimizer = tf.train.AdamOptimizer(learning_rate)
else:
raise NotImplementedError
gen_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('gen')
]
print('Optimizing Generator vars.')
for v in gen_vars:
print(v)
if mode == 'MINIMIZE':
gen_grads = tf.gradients(gen_loss, gen_vars)
elif mode == 'MAXIMIZE':
gen_grads = tf.gradients(-gen_loss, gen_vars)
else:
raise ValueError("Must be one of 'MINIMIZE' or 'MAXIMIZE'")
gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
FLAGS.grad_clipping)
gen_train_op = gen_optimizer.apply_gradients(
zip(gen_grads_clipped, gen_vars), global_step=global_step)
return gen_train_op, gen_grads_clipped, gen_vars
def create_reinforce_gen_train_op(hparams, learning_rate, final_gen_reward,
averages_op, global_step):
"""Create the Generator train_op when using REINFORCE.
Args:
hparams: MaskGAN hyperparameters.
learning_rate: tf.Variable scalar learning rate.
    final_gen_reward: Scalar final REINFORCE reward for the sequence.
averages_op: ExponentialMovingAverage apply average op to
maintain the baseline.
global_step: global_step tf.Variable.
Returns:
gen_train_op: Generator training op.
"""
del hparams
with tf.name_scope('train_generator'):
if FLAGS.generator_optimizer == 'sgd':
gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
elif FLAGS.generator_optimizer == 'adam':
gen_optimizer = tf.train.AdamOptimizer(learning_rate)
else:
raise NotImplementedError
gen_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('gen')
]
print('\nOptimizing Generator vars:')
for v in gen_vars:
print(v)
# Maximize reward.
gen_grads = tf.gradients(-final_gen_reward, gen_vars)
gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
FLAGS.grad_clipping)
maximize_op = gen_optimizer.apply_gradients(
zip(gen_grads_clipped, gen_vars), global_step=global_step)
# Group maintain averages op.
if averages_op:
gen_train_op = tf.group(maximize_op, averages_op)
else:
gen_train_op = maximize_op
return [gen_train_op, gen_grads, gen_vars]
def create_dis_train_op(hparams, dis_loss, global_step):
"""Create Discriminator train op."""
with tf.name_scope('train_discriminator'):
dis_optimizer = tf.train.AdamOptimizer(hparams.dis_learning_rate)
dis_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('dis')
]
if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
shared_embedding = [
v for v in tf.trainable_variables()
if v.op.name == 'gen/decoder/rnn/embedding'
][0]
dis_vars.append(shared_embedding)
print('\nOptimizing Discriminator vars:')
for v in dis_vars:
print(v)
dis_grads = tf.gradients(dis_loss, dis_vars)
dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
FLAGS.grad_clipping)
dis_train_op = dis_optimizer.apply_gradients(
zip(dis_grads_clipped, dis_vars), global_step=global_step)
return dis_train_op, dis_grads_clipped, dis_vars
def create_critic_train_op(hparams, critic_loss, global_step):
"""Create Discriminator train op."""
with tf.name_scope('train_critic'):
critic_optimizer = tf.train.AdamOptimizer(hparams.critic_learning_rate)
output_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('critic')
]
if FLAGS.critic_update_dis_vars:
if FLAGS.discriminator_model == 'bidirectional_vd':
critic_vars = [
v for v in tf.trainable_variables()
if v.op.name.startswith('dis/rnn')
]
elif FLAGS.discriminator_model == 'seq2seq_vd':
critic_vars = [
v for v in tf.trainable_variables()
if v.op.name.startswith('dis/decoder/rnn/multi_rnn_cell')
]
critic_vars.extend(output_vars)
else:
critic_vars = output_vars
print('\nOptimizing Critic vars:')
for v in critic_vars:
print(v)
critic_grads = tf.gradients(critic_loss, critic_vars)
critic_grads_clipped, _ = tf.clip_by_global_norm(critic_grads,
FLAGS.grad_clipping)
critic_train_op = critic_optimizer.apply_gradients(
zip(critic_grads_clipped, critic_vars), global_step=global_step)
return critic_train_op, critic_grads_clipped, critic_vars
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from model_utils import variable_mapping
FLAGS = tf.app.flags.FLAGS
def generate_mask():
"""Generate the mask to be fed into the model."""
if FLAGS.mask_strategy == 'random':
p = np.random.choice(
[True, False],
size=[FLAGS.batch_size, FLAGS.sequence_length],
p=[FLAGS.is_present_rate, 1. - FLAGS.is_present_rate])
elif FLAGS.mask_strategy == 'contiguous':
masked_length = int((1 - FLAGS.is_present_rate) * FLAGS.sequence_length) - 1
# Determine location to start masking.
start_mask = np.random.randint(
1, FLAGS.sequence_length - masked_length + 1, size=FLAGS.batch_size)
p = np.full([FLAGS.batch_size, FLAGS.sequence_length], True, dtype=bool)
# Create contiguous masked section to be False.
for i, index in enumerate(start_mask):
p[i, index:index + masked_length] = False
else:
raise NotImplementedError
return p
def assign_percent_real(session, percent_real_update, new_rate, current_rate):
"""Run assign operation where the we load the current_rate of percent
real into a Tensorflow variable.
Args:
session: Current tf.Session.
percent_real_update: tf.assign operation.
new_rate: tf.placeholder for the new rate.
current_rate: Percent of tokens that are currently real. Fake tokens
are the ones being imputed by the Generator.
"""
session.run(percent_real_update, feed_dict={new_rate: current_rate})
def assign_learning_rate(session, lr_update, lr_placeholder, new_lr):
"""Run assign operation where the we load the current_rate of percent
real into a Tensorflow variable.
Args:
session: Current tf.Session.
lr_update: tf.assign operation.
lr_placeholder: tf.placeholder for the new learning rate.
new_lr: New learning rate to use.
"""
session.run(lr_update, feed_dict={lr_placeholder: new_lr})
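# Illustrative sketch: the placeholder/assign pattern the two helpers above
# expect. The variable is created once at graph-construction time and the
# placeholder lets the training loop change its value without adding new ops.
# All names here are hypothetical.
def _example_make_lr_update(initial_lr=1e-3):
  learning_rate = tf.Variable(initial_lr, trainable=False, name='example_lr')
  lr_placeholder = tf.placeholder(tf.float32, shape=[], name='example_new_lr')
  lr_update = tf.assign(learning_rate, lr_placeholder)
  # Later: assign_learning_rate(session, lr_update, lr_placeholder, new_lr).
  return learning_rate, lr_update, lr_placeholder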
def clip_weights(variables, c_lower, c_upper):
"""Clip a list of weights to be within a certain range.
Args:
variables: List of tf.Variable weights.
c_lower: Lower bound for weights.
c_upper: Upper bound for weights.
"""
clip_ops = []
for var in variables:
clipped_var = tf.clip_by_value(var, c_lower, c_upper)
clip_ops.append(tf.assign(var, clipped_var))
return tf.group(*clip_ops)
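# Illustrative sketch (assumption: clip_weights is used for WGAN-style weight
# clipping of the Discriminator). Typical bounds are small and symmetric.
def _example_clip_dis_weights(c_value=0.01):
  dis_vars = [
      v for v in tf.trainable_variables() if v.op.name.startswith('dis')
  ]
  # Running the returned op in a session clamps every Discriminator weight
  # to [-c_value, c_value].
  return clip_weights(dis_vars, -c_value, c_value)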
def retrieve_init_savers(hparams):
"""Retrieve a dictionary of all the initial savers for the models.
Args:
hparams: MaskGAN hyperparameters.
"""
## Dictionary of init savers.
init_savers = {}
## Load Generator weights from MaskGAN checkpoint.
if FLAGS.maskgan_ckpt:
gen_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('gen')
]
init_saver = tf.train.Saver(var_list=gen_vars)
init_savers['init_saver'] = init_saver
## Load the Discriminator weights from the MaskGAN checkpoint if
# the weights are compatible.
if FLAGS.discriminator_model == 'seq2seq_vd':
dis_variable_maps = variable_mapping.dis_seq2seq_vd(hparams)
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
## Load weights from language model checkpoint.
if FLAGS.language_model_ckpt_dir:
if FLAGS.maskgan_ckpt is None:
## Generator Variables/Savers.
if FLAGS.generator_model == 'rnn_nas':
gen_variable_maps = variable_mapping.rnn_nas(hparams, model='gen')
gen_init_saver = tf.train.Saver(var_list=gen_variable_maps)
init_savers['gen_init_saver'] = gen_init_saver
elif FLAGS.generator_model == 'seq2seq_nas':
# Encoder.
gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq_nas(
hparams)
gen_encoder_init_saver = tf.train.Saver(
var_list=gen_encoder_variable_maps)
# Decoder.
gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq_nas(
hparams)
gen_decoder_init_saver = tf.train.Saver(
var_list=gen_decoder_variable_maps)
init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver
# seq2seq_vd derived from the same code base as seq2seq_zaremba.
elif (FLAGS.generator_model == 'seq2seq_zaremba' or
FLAGS.generator_model == 'seq2seq_vd'):
# Encoder.
gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq(
hparams)
gen_encoder_init_saver = tf.train.Saver(
var_list=gen_encoder_variable_maps)
# Decoder.
gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq(
hparams)
gen_decoder_init_saver = tf.train.Saver(
var_list=gen_decoder_variable_maps)
init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver
else:
raise NotImplementedError
## Discriminator Variables/Savers.
if FLAGS.discriminator_model == 'rnn_nas':
dis_variable_maps = variable_mapping.rnn_nas(hparams, model='dis')
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
# rnn_vd derived from the same code base as rnn_zaremba.
elif (FLAGS.discriminator_model == 'rnn_zaremba' or
FLAGS.discriminator_model == 'rnn_vd'):
dis_variable_maps = variable_mapping.rnn_zaremba(hparams, model='dis')
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
FLAGS.discriminator_model == 'bidirectional_vd'):
dis_fwd_variable_maps = variable_mapping.dis_fwd_bidirectional(hparams)
dis_bwd_variable_maps = variable_mapping.dis_bwd_bidirectional(hparams)
# Savers for the forward/backward Discriminator components.
dis_fwd_init_saver = tf.train.Saver(var_list=dis_fwd_variable_maps)
dis_bwd_init_saver = tf.train.Saver(var_list=dis_bwd_variable_maps)
init_savers['dis_fwd_init_saver'] = dis_fwd_init_saver
init_savers['dis_bwd_init_saver'] = dis_bwd_init_saver
elif FLAGS.discriminator_model == 'cnn':
dis_variable_maps = variable_mapping.cnn()
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
elif FLAGS.discriminator_model == 'seq2seq_vd':
# Encoder.
dis_encoder_variable_maps = variable_mapping.dis_encoder_seq2seq(hparams)
dis_encoder_init_saver = tf.train.Saver(
var_list=dis_encoder_variable_maps)
# Decoder.
dis_decoder_variable_maps = variable_mapping.dis_decoder_seq2seq(hparams)
dis_decoder_init_saver = tf.train.Saver(
var_list=dis_decoder_variable_maps)
init_savers['dis_encoder_init_saver'] = dis_encoder_init_saver
init_savers['dis_decoder_init_saver'] = dis_decoder_init_saver
return init_savers
def init_fn(init_savers, sess):
"""The init_fn to be passed to the Supervisor.
Args:
init_savers: Dictionary of init_savers. 'init_saver_name': init_saver.
sess: tf.Session.
"""
## Load Generator weights from MaskGAN checkpoint.
if FLAGS.maskgan_ckpt:
print('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
tf.logging.info('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
print('Asserting Generator is a seq2seq-variant.')
tf.logging.info('Asserting Generator is a seq2seq-variant.')
assert FLAGS.generator_model.startswith('seq2seq')
init_saver = init_savers['init_saver']
init_saver.restore(sess, FLAGS.maskgan_ckpt)
## Load the Discriminator weights from the MaskGAN checkpoint if
# the weights are compatible.
if FLAGS.discriminator_model == 'seq2seq_vd':
print('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
tf.logging.info('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
dis_init_saver = init_savers['dis_init_saver']
dis_init_saver.restore(sess, FLAGS.maskgan_ckpt)
## Load weights from language model checkpoint.
if FLAGS.language_model_ckpt_dir:
if FLAGS.maskgan_ckpt is None:
## Generator Models.
if FLAGS.generator_model == 'rnn_nas':
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Generator from %s.' % load_ckpt)
tf.logging.info('Restoring Generator from %s.' % load_ckpt)
gen_init_saver = init_savers['gen_init_saver']
gen_init_saver.restore(sess, load_ckpt)
elif FLAGS.generator_model.startswith('seq2seq'):
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Generator from %s.' % load_ckpt)
tf.logging.info('Restoring Generator from %s.' % load_ckpt)
gen_encoder_init_saver = init_savers['gen_encoder_init_saver']
gen_decoder_init_saver = init_savers['gen_decoder_init_saver']
gen_encoder_init_saver.restore(sess, load_ckpt)
gen_decoder_init_saver.restore(sess, load_ckpt)
## Discriminator Models.
if (FLAGS.discriminator_model == 'rnn_nas' or
FLAGS.discriminator_model == 'rnn_zaremba' or
FLAGS.discriminator_model == 'rnn_vd' or
FLAGS.discriminator_model == 'cnn'):
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Discriminator from %s.' % load_ckpt)
tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
dis_init_saver = init_savers['dis_init_saver']
dis_init_saver.restore(sess, load_ckpt)
elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
FLAGS.discriminator_model == 'bidirectional_vd'):
assert FLAGS.language_model_ckpt_dir_reversed is not None, (
'Need a reversed directory to fill in the backward components.')
load_fwd_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
load_bwd_ckpt = tf.train.latest_checkpoint(
FLAGS.language_model_ckpt_dir_reversed)
print('Restoring Discriminator from %s and %s.' % (load_fwd_ckpt,
load_bwd_ckpt))
tf.logging.info('Restoring Discriminator from %s and %s.' %
(load_fwd_ckpt, load_bwd_ckpt))
dis_fwd_init_saver = init_savers['dis_fwd_init_saver']
dis_bwd_init_saver = init_savers['dis_bwd_init_saver']
dis_fwd_init_saver.restore(sess, load_fwd_ckpt)
dis_bwd_init_saver.restore(sess, load_bwd_ckpt)
elif FLAGS.discriminator_model == 'seq2seq_vd':
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Discriminator from %s.' % load_ckpt)
tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
dis_encoder_init_saver = init_savers['dis_encoder_init_saver']
dis_decoder_init_saver = init_savers['dis_decoder_init_saver']
dis_encoder_init_saver.restore(sess, load_ckpt)
dis_decoder_init_saver.restore(sess, load_ckpt)
else:
return
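# Illustrative sketch: how the two helpers above are typically combined. The
# Supervisor invokes `init_fn` once after variable initialization, restoring
# whichever checkpoints the flags point at. `logdir` is hypothetical.
def _example_make_supervisor(hparams, logdir):
  init_savers = retrieve_init_savers(hparams)
  return tf.train.Supervisor(
      logdir=logdir,
      is_chief=True,
      init_fn=lambda sess: init_fn(init_savers, sess))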
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""We calculate n-Grams from the training text. We will use this as an
evaluation metric."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def hash_function(input_tuple):
"""Hash function for a tuple."""
return hash(input_tuple)
def find_all_ngrams(dataset, n):
"""Generate a list of all ngrams."""
  return zip(*[dataset[i:] for i in range(n)])
def construct_ngrams_dict(ngrams_list):
"""Construct a ngram dictionary which maps an ngram tuple to the number
of times it appears in the text."""
counts = {}
for t in ngrams_list:
key = hash_function(t)
if key in counts:
counts[key] += 1
else:
counts[key] = 1
return counts
def percent_unique_ngrams_in_train(train_ngrams_dict, gen_ngrams_dict):
  """Compute the fraction of n-grams produced by the model that also appear
  in the training text, counting each generated n-gram type once."""
  # *Total* number of n-grams produced by the generator (with multiplicity).
  total_ngrams_produced = 0
  for _, value in gen_ngrams_dict.items():
    total_ngrams_produced += value
  # Number of distinct generated n-grams that also occur in the training set.
  unique_ngrams_in_train = 0.
  for key in gen_ngrams_dict:
    if key in train_ngrams_dict:
      unique_ngrams_in_train += 1
  return float(unique_ngrams_in_train) / float(total_ngrams_produced)
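# Illustrative sketch: computing the n-gram overlap metric end to end on two
# toy token sequences. A lower value means the generated text repeats fewer
# n-grams seen in training. Names and data are hypothetical.
def _example_ngram_overlap(train_tokens, gen_tokens, n=4):
  train_counts = construct_ngrams_dict(list(find_all_ngrams(train_tokens, n)))
  gen_counts = construct_ngrams_dict(list(find_all_ngrams(gen_tokens, n)))
  return percent_unique_ngrams_in_train(train_counts, gen_counts)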
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention-based decoder functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.framework import function
__all__ = [
"prepare_attention", "attention_decoder_fn_train",
"attention_decoder_fn_inference"
]
def attention_decoder_fn_train(encoder_state,
attention_keys,
attention_values,
attention_score_fn,
attention_construct_fn,
name=None):
"""Attentional decoder function for `dynamic_rnn_decoder` during training.
The `attention_decoder_fn_train` is a training function for an
attention-based sequence-to-sequence model. It should be used when
`dynamic_rnn_decoder` is in the training mode.
The `attention_decoder_fn_train` is called with a set of the user arguments
and returns the `decoder_fn`, which can be passed to the
`dynamic_rnn_decoder`, such that
```
dynamic_fn_train = attention_decoder_fn_train(encoder_state)
outputs_train, state_train = dynamic_rnn_decoder(
decoder_fn=dynamic_fn_train, ...)
```
Further usage can be found in the `kernel_tests/seq2seq_test.py`.
Args:
encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
attention_keys: to be compared with target states.
attention_values: to be used to construct context vectors.
attention_score_fn: to compute similarity between key and target states.
attention_construct_fn: to build attention states.
name: (default: `None`) NameScope for the decoder function;
defaults to "simple_decoder_fn_train"
Returns:
A decoder function with the required interface of `dynamic_rnn_decoder`
intended for training.
"""
with tf.name_scope(name, "attention_decoder_fn_train", [
encoder_state, attention_keys, attention_values, attention_score_fn,
attention_construct_fn
]):
pass
def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
"""Decoder function used in the `dynamic_rnn_decoder` for training.
Args:
time: positive integer constant reflecting the current timestep.
cell_state: state of RNNCell.
cell_input: input provided by `dynamic_rnn_decoder`.
cell_output: output of RNNCell.
context_state: context state provided by `dynamic_rnn_decoder`.
Returns:
A tuple (done, next state, next input, emit output, next context state)
where:
done: `None`, which is used by the `dynamic_rnn_decoder` to indicate
that `sequence_lengths` in `dynamic_rnn_decoder` should be used.
next state: `cell_state`, this decoder function does not modify the
given state.
next input: `cell_input`, this decoder function does not modify the
given input. The input could be modified when applying e.g. attention.
emit output: `cell_output`, this decoder function does not modify the
given output.
next context state: `context_state`, this decoder function does not
modify the given context state. The context state could be modified when
applying e.g. beam search.
"""
with tf.name_scope(
name, "attention_decoder_fn_train",
[time, cell_state, cell_input, cell_output, context_state]):
if cell_state is None: # first call, return encoder_state
cell_state = encoder_state
# init attention
attention = _init_attention(encoder_state)
else:
# construct attention
attention = attention_construct_fn(cell_output, attention_keys,
attention_values)
cell_output = attention
# combine cell_input and attention
next_input = tf.concat([cell_input, attention], 1)
return (None, cell_state, next_input, cell_output, context_state)
return decoder_fn
def attention_decoder_fn_inference(output_fn,
encoder_state,
attention_keys,
attention_values,
attention_score_fn,
attention_construct_fn,
embeddings,
start_of_sequence_id,
end_of_sequence_id,
maximum_length,
num_decoder_symbols,
dtype=tf.int32,
name=None):
"""Attentional decoder function for `dynamic_rnn_decoder` during inference.
The `attention_decoder_fn_inference` is a simple inference function for a
sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is
in the inference mode.
The `attention_decoder_fn_inference` is called with user arguments
and returns the `decoder_fn`, which can be passed to the
`dynamic_rnn_decoder`, such that
```
dynamic_fn_inference = attention_decoder_fn_inference(...)
outputs_inference, state_inference = dynamic_rnn_decoder(
decoder_fn=dynamic_fn_inference, ...)
```
Further usage can be found in the `kernel_tests/seq2seq_test.py`.
Args:
output_fn: An output function to project your `cell_output` onto class
logits.
An example of an output function;
```
tf.variable_scope("decoder") as varscope
output_fn = lambda x: tf.contrib.layers.linear(x, num_decoder_symbols,
scope=varscope)
outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...)
logits_train = output_fn(outputs_train)
varscope.reuse_variables()
logits_inference, state_inference = seq2seq.dynamic_rnn_decoder(
output_fn=output_fn, ...)
```
If `None` is supplied it will act as an identity function, which
might be wanted when using the RNNCell `OutputProjectionWrapper`.
encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
attention_keys: to be compared with target states.
attention_values: to be used to construct context vectors.
attention_score_fn: to compute similarity between key and target states.
attention_construct_fn: to build attention states.
embeddings: The embeddings matrix used for the decoder sized
`[num_decoder_symbols, embedding_size]`.
start_of_sequence_id: The start of sequence ID in the decoder embeddings.
end_of_sequence_id: The end of sequence ID in the decoder embeddings.
    maximum_length: The maximum allowed number of time steps to decode.
num_decoder_symbols: The number of classes to decode at each time step.
dtype: (default: `tf.int32`) The default data type to use when
handling integer objects.
name: (default: `None`) NameScope for the decoder function;
defaults to "attention_decoder_fn_inference"
Returns:
A decoder function with the required interface of `dynamic_rnn_decoder`
intended for inference.
"""
with tf.name_scope(name, "attention_decoder_fn_inference", [
output_fn, encoder_state, attention_keys, attention_values,
attention_score_fn, attention_construct_fn, embeddings,
start_of_sequence_id, end_of_sequence_id, maximum_length,
num_decoder_symbols, dtype
]):
start_of_sequence_id = tf.convert_to_tensor(start_of_sequence_id, dtype)
end_of_sequence_id = tf.convert_to_tensor(end_of_sequence_id, dtype)
maximum_length = tf.convert_to_tensor(maximum_length, dtype)
num_decoder_symbols = tf.convert_to_tensor(num_decoder_symbols, dtype)
encoder_info = tf.contrib.framework.nest.flatten(encoder_state)[0]
batch_size = encoder_info.get_shape()[0].value
if output_fn is None:
output_fn = lambda x: x
if batch_size is None:
batch_size = tf.shape(encoder_info)[0]
def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
"""Decoder function used in the `dynamic_rnn_decoder` for inference.
The main difference between this decoder function and the `decoder_fn` in
`attention_decoder_fn_train` is how `next_cell_input` is calculated. In
    this decoder function we calculate the next input by applying an argmax
    across
the feature dimension of the output from the decoder. This is a
greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
use beam-search instead.
Args:
time: positive integer constant reflecting the current timestep.
cell_state: state of RNNCell.
cell_input: input provided by `dynamic_rnn_decoder`.
cell_output: output of RNNCell.
context_state: context state provided by `dynamic_rnn_decoder`.
Returns:
A tuple (done, next state, next input, emit output, next context state)
where:
      done: A boolean vector to indicate which sentences have reached an
`end_of_sequence_id`. This is used for early stopping by the
`dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with
all elements as `true` is returned.
next state: `cell_state`, this decoder function does not modify the
given state.
next input: The embedding from argmax of the `cell_output` is used as
`next_input`.
emit output: If `output_fn is None` the supplied `cell_output` is
returned, else the `output_fn` is used to update the `cell_output`
before calculating `next_input` and returning `cell_output`.
next context state: `context_state`, this decoder function does not
modify the given context state. The context state could be modified when
applying e.g. beam search.
Raises:
ValueError: if cell_input is not None.
"""
with tf.name_scope(
name, "attention_decoder_fn_inference",
[time, cell_state, cell_input, cell_output, context_state]):
if cell_input is not None:
raise ValueError(
"Expected cell_input to be None, but saw: %s" % cell_input)
if cell_output is None:
# invariant that this is time == 0
        next_input_id = (
            tf.ones([batch_size], dtype=dtype) * start_of_sequence_id)
        done = tf.zeros([batch_size], dtype=tf.bool)
cell_state = encoder_state
cell_output = tf.zeros([num_decoder_symbols], dtype=tf.float32)
cell_input = tf.gather(embeddings, next_input_id)
# init attention
attention = _init_attention(encoder_state)
else:
# construct attention
attention = attention_construct_fn(cell_output, attention_keys,
attention_values)
cell_output = attention
# argmax decoder
cell_output = output_fn(cell_output) # logits
next_input_id = tf.cast(tf.argmax(cell_output, 1), dtype=dtype)
done = tf.equal(next_input_id, end_of_sequence_id)
cell_input = tf.gather(embeddings, next_input_id)
# combine cell_input and attention
next_input = tf.concat([cell_input, attention], 1)
# if time > maxlen, return all true vector
      done = tf.cond(
          tf.greater(time, maximum_length),
          lambda: tf.ones([batch_size], dtype=tf.bool),
          lambda: done)
return (done, cell_state, next_input, cell_output, context_state)
return decoder_fn
## Helper functions ##
def prepare_attention(attention_states, attention_option, num_units,
reuse=None):
"""Prepare keys/values/functions for attention.
Args:
attention_states: hidden states to attend over.
attention_option: how to compute attention, either "luong" or "bahdanau".
num_units: hidden state dimension.
reuse: whether to reuse variable scope.
Returns:
attention_keys: to be compared with target states.
attention_values: to be used to construct context vectors.
attention_score_fn: to compute similarity between key and target states.
attention_construct_fn: to build attention states.
"""
# Prepare attention keys / values from attention_states
with tf.variable_scope("attention_keys", reuse=reuse) as scope:
attention_keys = tf.contrib.layers.linear(
attention_states, num_units, biases_initializer=None, scope=scope)
attention_values = attention_states
# Attention score function
attention_score_fn = _create_attention_score_fn("attention_score", num_units,
attention_option, reuse)
# Attention construction function
attention_construct_fn = _create_attention_construct_fn(
"attention_construct", num_units, attention_score_fn, reuse)
return (attention_keys, attention_values, attention_score_fn,
attention_construct_fn)
def _init_attention(encoder_state):
"""Initialize attention. Handling both LSTM and GRU.
Args:
encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
Returns:
attn: initial zero attention vector.
"""
# Multi- vs single-layer
# TODO(thangluong): is this the best way to check?
if isinstance(encoder_state, tuple):
top_state = encoder_state[-1]
else:
top_state = encoder_state
# LSTM vs GRU
if isinstance(top_state, tf.contrib.rnn.LSTMStateTuple):
attn = tf.zeros_like(top_state.h)
else:
attn = tf.zeros_like(top_state)
return attn
def _create_attention_construct_fn(name, num_units, attention_score_fn, reuse):
"""Function to compute attention vectors.
Args:
name: to label variables.
num_units: hidden state dimension.
attention_score_fn: to compute similarity between key and target states.
reuse: whether to reuse variable scope.
Returns:
attention_construct_fn: to build attention states.
"""
def construct_fn(attention_query, attention_keys, attention_values):
with tf.variable_scope(name, reuse=reuse) as scope:
context = attention_score_fn(attention_query, attention_keys,
attention_values)
concat_input = tf.concat([attention_query, context], 1)
attention = tf.contrib.layers.linear(
concat_input, num_units, biases_initializer=None, scope=scope)
return attention
return construct_fn
# keys: [batch_size, attention_length, attn_size]
# query: [batch_size, 1, attn_size]
# return weights [batch_size, attention_length]
@function.Defun(func_name="attn_add_fun", noinline=True)
def _attn_add_fun(v, keys, query):
return tf.reduce_sum(v * tf.tanh(keys + query), [2])
@function.Defun(func_name="attn_mul_fun", noinline=True)
def _attn_mul_fun(keys, query):
return tf.reduce_sum(keys * query, [2])
def _create_attention_score_fn(name,
num_units,
attention_option,
reuse,
dtype=tf.float32):
"""Different ways to compute attention scores.
Args:
name: to label variables.
num_units: hidden state dimension.
attention_option: how to compute attention, either "luong" or "bahdanau".
"bahdanau": additive (Bahdanau et al., ICLR'2015)
"luong": multiplicative (Luong et al., EMNLP'2015)
reuse: whether to reuse variable scope.
dtype: (default: `tf.float32`) data type to use.
Returns:
attention_score_fn: to compute similarity between key and target states.
"""
with tf.variable_scope(name, reuse=reuse):
if attention_option == "bahdanau":
query_w = tf.get_variable("attnW", [num_units, num_units], dtype=dtype)
score_v = tf.get_variable("attnV", [num_units], dtype=dtype)
def attention_score_fn(query, keys, values):
"""Put attention masks on attention_values using attention_keys and query.
Args:
query: A Tensor of shape [batch_size, num_units].
keys: A Tensor of shape [batch_size, attention_length, num_units].
values: A Tensor of shape [batch_size, attention_length, num_units].
Returns:
context_vector: A Tensor of shape [batch_size, num_units].
Raises:
        ValueError: if attention_option is neither "luong" nor "bahdanau".
"""
if attention_option == "bahdanau":
# transform query
query = tf.matmul(query, query_w)
# reshape query: [batch_size, 1, num_units]
query = tf.reshape(query, [-1, 1, num_units])
# attn_fun
scores = _attn_add_fun(score_v, keys, query)
elif attention_option == "luong":
# reshape query: [batch_size, 1, num_units]
query = tf.reshape(query, [-1, 1, num_units])
# attn_fun
scores = _attn_mul_fun(keys, query)
else:
raise ValueError("Unknown attention option %s!" % attention_option)
# Compute alignment weights
# scores: [batch_size, length]
# alignments: [batch_size, length]
# TODO(thangluong): not normalize over padding positions.
alignments = tf.nn.softmax(scores)
# Now calculate the attention-weighted vector.
alignments = tf.expand_dims(alignments, 2)
context_vector = tf.reduce_sum(alignments * values, [1])
context_vector.set_shape([None, num_units])
return context_vector
return attention_score_fn
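# Illustrative sketch: exercising the attention helpers above directly. A
# Luong (multiplicative) score is computed between a query and the encoder
# states; shapes follow the comments above the Defun helpers. The argument
# names are hypothetical.
def _example_luong_attention(encoder_outputs, query, num_units):
  # encoder_outputs: [batch_size, attention_length, num_units]
  # query: [batch_size, num_units]
  (keys, values, score_fn, construct_fn) = prepare_attention(
      encoder_outputs, attention_option='luong', num_units=num_units)
  context = score_fn(query, keys, values)        # [batch_size, num_units]
  attention = construct_fn(query, keys, values)  # mixes query and context
  return context, attention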
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
# ZoneoutWrapper.
from regularization import zoneout
FLAGS = tf.app.flags.FLAGS
def discriminator(hparams, sequence, is_training, reuse=None):
"""Define the bidirectional Discriminator graph."""
sequence = tf.cast(sequence, tf.int32)
if FLAGS.dis_share_embedding:
assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
'If you wish to share Discriminator/Generator embeddings, they must be'
' same dimension.')
with tf.variable_scope('gen/rnn', reuse=True):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.gen_rnn_size])
with tf.variable_scope('dis', reuse=reuse):
cell_fwd = tf.contrib.rnn.LayerNormBasicLSTMCell(
hparams.dis_rnn_size, forget_bias=1.0, reuse=reuse)
cell_bwd = tf.contrib.rnn.LayerNormBasicLSTMCell(
hparams.dis_rnn_size, forget_bias=1.0, reuse=reuse)
if FLAGS.zoneout_drop_prob > 0.0:
cell_fwd = zoneout.ZoneoutWrapper(
cell_fwd,
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
cell_bwd = zoneout.ZoneoutWrapper(
cell_bwd,
zoneout_drop_prob=FLAGS.zoneout_drop_prob,
is_training=is_training)
state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
if not FLAGS.dis_share_embedding:
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, hparams.dis_rnn_size])
rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
rnn_inputs = tf.unstack(rnn_inputs, axis=1)
with tf.variable_scope('rnn') as vs:
outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)
# Prediction is linear output for Discriminator.
predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
predictions = tf.transpose(predictions, [1, 0, 2])
return tf.squeeze(predictions, axis=2)
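# Illustrative sketch: building the bidirectional Discriminator on a batch of
# token ids. `hparams` is assumed to carry dis_rnn_size (and gen_rnn_size when
# embeddings are shared), matching the rest of the code base.
def _example_discriminator_scores(hparams, token_ids):
  # token_ids: [batch_size, sequence_length] int32 Tensor of vocabulary ids.
  predictions = discriminator(hparams, token_ids, is_training=True)
  # predictions: [batch_size, sequence_length]; each entry is an unnormalized
  # per-token score of "real" vs. "filled in by the Generator".
  return predictions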