# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to calculate the mean of a pianoroll dataset.
Given a pianoroll pickle file, this script loads the dataset and
calculates the mean of the training set. Then it updates the pickle file
so that the key "train_mean" points to the mean vector.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import numpy as np
import tensorflow as tf
from datasets import sparse_pianoroll_to_dense
tf.app.flags.DEFINE_string('in_file', None,
'Filename of the pickled pianoroll dataset to load.')
tf.app.flags.DEFINE_string('out_file', None,
'Name of the output pickle file. Defaults to in_file, '
'updating the input pickle file.')
tf.app.flags.mark_flag_as_required('in_file')
FLAGS = tf.app.flags.FLAGS
MIN_NOTE = 21
MAX_NOTE = 108
NUM_NOTES = MAX_NOTE - MIN_NOTE + 1
def main(unused_argv):
if FLAGS.out_file is None:
FLAGS.out_file = FLAGS.in_file
with tf.gfile.Open(FLAGS.in_file, 'r') as f:
pianorolls = pickle.load(f)
dense_pianorolls = [sparse_pianoroll_to_dense(p, MIN_NOTE, NUM_NOTES)[0]
for p in pianorolls['train']]
# Concatenate all elements along the time axis.
concatenated = np.concatenate(dense_pianorolls, axis=0)
mean = np.mean(concatenated, axis=0)
pianorolls['train_mean'] = mean
# Write out the whole pickle file, including the train mean.
with tf.gfile.Open(FLAGS.out_file, 'wb') as f:
  pickle.dump(pianorolls, f)
if __name__ == '__main__':
tf.app.run()
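# A minimal sketch of the transformation this script performs, shown on a toy
# two-sequence dataset (the data below is made up for illustration):
#
#   data = {'train': [[(21,), (22, 23)], [(21, 108)]]}
#   dense = [sparse_pianoroll_to_dense(p, MIN_NOTE, NUM_NOTES)[0]
#            for p in data['train']]
#   data['train_mean'] = np.mean(np.concatenate(dense, axis=0), axis=0)
#
# After running, the pickle holds the original splits plus a length-88
# 'train_mean' vector that is later used to mean-center model inputs.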
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import random
import re
import numpy as np
import tensorflow as tf
tf.app.flags.DEFINE_string("raw_timit_dir", None,
"Directory containing TIMIT files.")
tf.app.flags.DEFINE_string("out_dir", None,
"Output directory for TFRecord files.")
tf.app.flags.DEFINE_float("valid_frac", 0.05,
"Fraction of train set to use as valid set. "
"Must be between 0.0 and 1.0.")
tf.app.flags.mark_flag_as_required("raw_timit_dir")
tf.app.flags.mark_flag_as_required("out_dir")
FLAGS = tf.app.flags.FLAGS
NUM_TRAIN_FILES = 4620
NUM_TEST_FILES = 1680
SAMPLES_PER_TIMESTEP = 200
# Regexes for reading SPHERE header files.
SAMPLE_COUNT_REGEX = re.compile(r"sample_count -i (\d+)")
SAMPLE_MIN_REGEX = re.compile(r"sample_min -i (-?\d+)")
SAMPLE_MAX_REGEX = re.compile(r"sample_max -i (-?\d+)")
def get_filenames(split):
"""Get all wav filenames from the TIMIT archive."""
path = os.path.join(FLAGS.raw_timit_dir, "TIMIT", split, "*", "*", "*.WAV")
# Sort the output by name so the order is deterministic.
files = sorted(glob.glob(path))
return files
def load_timit_wav(filename):
"""Loads a TIMIT wavfile into a numpy array.
TIMIT wavfiles include a SPHERE header, detailed in the TIMIT docs. The first
line is the header type and the second is the length of the header in bytes.
After the header, the remaining bytes are actual WAV data.
The header includes information about the WAV data such as the number of
samples and minimum and maximum amplitude. This function asserts that the
loaded wav data matches the header.
Args:
filename: The name of the TIMIT wavfile to load.
Returns:
wav: A numpy array containing the loaded wav data.
"""
wav_file = open(filename, "rb")
header_type = wav_file.readline()
header_length_str = wav_file.readline()
# The header length includes the length of the first two lines.
header_remaining_bytes = (int(header_length_str) - len(header_type) -
len(header_length_str))
header = wav_file.read(header_remaining_bytes)
# Read the relevant header fields.
sample_count = int(SAMPLE_COUNT_REGEX.search(header).group(1))
sample_min = int(SAMPLE_MIN_REGEX.search(header).group(1))
sample_max = int(SAMPLE_MAX_REGEX.search(header).group(1))
wav = np.frombuffer(wav_file.read(), dtype="int16").astype("float32")
# Check that the loaded data conforms to the header description.
assert len(wav) == sample_count
assert wav.min() == sample_min
assert wav.max() == sample_max
return wav
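# For reference, a SPHERE header begins like the following synthetic example
# (field values here are illustrative, not from a real TIMIT file):
#
#   NIST_1A
#      1024
#   sample_count -i 63488
#   sample_min -i -2191
#   sample_max -i 2790
#
# The second line is the total header length in bytes, so
# header_remaining_bytes above is that length minus the first two lines.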
def preprocess(wavs, block_size, mean, std):
"""Normalize the wav data and reshape it into chunks."""
processed_wavs = []
for wav in wavs:
wav = (wav - mean) / std
wav_length = wav.shape[0]
if wav_length % block_size != 0:
pad_width = block_size - (wav_length % block_size)
wav = np.pad(wav, (0, pad_width), "constant")
assert wav.shape[0] % block_size == 0
wav = wav.reshape((-1, block_size))
processed_wavs.append(wav)
return processed_wavs
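# For example, with block_size=4 a length-10 wav is padded with 2 zeros and
# reshaped into 3 timesteps of 4 samples each:
#
#   wav = np.arange(10.0)                    # shape [10]
#   wav = np.pad(wav, (0, 2), "constant")    # shape [12]
#   wav = wav.reshape((-1, 4))               # shape [3, 4]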
def create_tfrecord_from_wavs(wavs, output_file):
"""Writes processed wav files to disk as sharded TFRecord files."""
with tf.python_io.TFRecordWriter(output_file) as builder:
for wav in wavs:
builder.write(wav.astype(np.float32).tobytes())
def main(unused_argv):
train_filenames = get_filenames("TRAIN")
test_filenames = get_filenames("TEST")
num_train_files = len(train_filenames)
num_test_files = len(test_filenames)
num_valid_files = int(num_train_files * FLAGS.valid_frac)
num_train_files -= num_valid_files
print("%d train / %d valid / %d test" % (
num_train_files, num_valid_files, num_test_files))
random.seed(1234)
random.shuffle(train_filenames)
valid_filenames = train_filenames[:num_valid_files]
train_filenames = train_filenames[num_valid_files:]
# Make sure there is no overlap in the train, test, and valid sets.
train_s = set(train_filenames)
test_s = set(test_filenames)
valid_s = set(valid_filenames)
# Disable explicit length testing to make the assertions more readable.
# pylint: disable=g-explicit-length-test
assert len(train_s & test_s) == 0
assert len(train_s & valid_s) == 0
assert len(valid_s & test_s) == 0
# pylint: enable=g-explicit-length-test
train_wavs = [load_timit_wav(f) for f in train_filenames]
valid_wavs = [load_timit_wav(f) for f in valid_filenames]
test_wavs = [load_timit_wav(f) for f in test_filenames]
assert len(train_wavs) + len(valid_wavs) == NUM_TRAIN_FILES
assert len(test_wavs) == NUM_TEST_FILES
# Calculate the mean and standard deviation of the train set.
train_stacked = np.hstack(train_wavs)
train_mean = np.mean(train_stacked)
train_std = np.std(train_stacked)
print("train mean: %f train std: %f" % (train_mean, train_std))
# Process all data, normalizing with the train set statistics.
processed_train_wavs = preprocess(train_wavs, SAMPLES_PER_TIMESTEP,
train_mean, train_std)
processed_valid_wavs = preprocess(valid_wavs, SAMPLES_PER_TIMESTEP,
train_mean, train_std)
processed_test_wavs = preprocess(test_wavs, SAMPLES_PER_TIMESTEP, train_mean,
train_std)
# Write the datasets to disk.
create_tfrecord_from_wavs(
processed_train_wavs,
os.path.join(FLAGS.out_dir, "train"))
create_tfrecord_from_wavs(
processed_valid_wavs,
os.path.join(FLAGS.out_dir, "valid"))
create_tfrecord_from_wavs(
processed_test_wavs,
os.path.join(FLAGS.out_dir, "test"))
if __name__ == "__main__":
tf.app.run()
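# Example invocation, assuming this script is saved as create_timit_dataset.py
# (the script name and paths are illustrative):
#
#   python create_timit_dataset.py --raw_timit_dir=/tmp/timit_raw \
#     --out_dir=/tmp/timit --valid_frac=0.05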
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for creating sequence datasets.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
from scipy.sparse import coo_matrix
import tensorflow as tf
# The default number of threads used to process data in parallel.
DEFAULT_PARALLELISM = 12
def sparse_pianoroll_to_dense(pianoroll, min_note, num_notes):
"""Converts a sparse pianoroll to a dense numpy array.
Given a sparse pianoroll, converts it to a dense numpy array of shape
[num_timesteps, num_notes] where entry i,j is 1.0 if note j is active on
timestep i and 0.0 otherwise.
Args:
pianoroll: A sparse pianoroll object, a list of tuples where the i'th tuple
contains the indices of the notes active at timestep i.
min_note: The minimum note in the pianoroll, subtracted from all notes so
that the minimum note becomes 0.
num_notes: The number of possible different note indices, determines the
second dimension of the resulting dense array.
Returns:
dense_pianoroll: A [num_timesteps, num_notes] numpy array of floats.
num_timesteps: A python int, the number of timesteps in the pianoroll.
"""
num_timesteps = len(pianoroll)
inds = []
for time, chord in enumerate(pianoroll):
# Re-index the notes to start from min_note.
inds.extend((time, note-min_note) for note in chord)
shape = [num_timesteps, num_notes]
values = [1.] * len(inds)
sparse_pianoroll = coo_matrix(
(values, ([x[0] for x in inds], [x[1] for x in inds])),
shape=shape)
return sparse_pianoroll.toarray(), num_timesteps
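# For example, a two-timestep pianoroll with min_note=21 and num_notes=3:
#
#   dense, n = sparse_pianoroll_to_dense([(21, 23), ()], 21, 3)
#   # dense == [[1., 0., 1.],
#   #           [0., 0., 0.]]
#   # n == 2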
def create_pianoroll_dataset(path,
split,
batch_size,
num_parallel_calls=DEFAULT_PARALLELISM,
shuffle=False,
repeat=False,
min_note=21,
max_note=108):
"""Creates a pianoroll dataset.
Args:
path: The path of a pickle file containing the dataset to load.
split: The split to use, can be train, test, or valid.
batch_size: The batch size. If repeat is False then it is not guaranteed
that the true batch size will match for all batches since batch_size
may not necessarily evenly divide the number of elements.
num_parallel_calls: The number of threads to use for parallel processing of
the data.
shuffle: If true, shuffles the order of the dataset.
repeat: If true, repeats the dataset endlessly.
min_note: The minimum note number of the dataset. For all pianoroll datasets
the minimum note is number 21, and changing this affects the dimension of
the data. This is useful mostly for testing.
max_note: The maximum note number of the dataset. For all pianoroll datasets
the maximum note is number 108, and changing this affects the dimension of
the data. This is useful mostly for testing.
Returns:
inputs: A batch of input sequences represented as a dense Tensor of shape
[time, batch_size, data_dimension]. The sequences in inputs are the
sequences in targets shifted one timestep into the future, padded with
zeros. This tensor is mean-centered, with the mean taken from the pickle
file key 'train_mean'.
targets: A batch of target sequences represented as a dense Tensor of
shape [time, batch_size, data_dimension].
lens: An int Tensor of shape [batch_size] representing the lengths of each
sequence in the batch.
mean: A float Tensor of shape [data_dimension] containing the mean loaded
from the pickle file.
"""
# Load the data from disk.
num_notes = max_note - min_note + 1
with tf.gfile.Open(path, "r") as f:
raw_data = pickle.load(f)
pianorolls = raw_data[split]
mean = raw_data["train_mean"]
num_examples = len(pianorolls)
def pianoroll_generator():
for sparse_pianoroll in pianorolls:
yield sparse_pianoroll_to_dense(sparse_pianoroll, min_note, num_notes)
dataset = tf.data.Dataset.from_generator(
pianoroll_generator,
output_types=(tf.float64, tf.int64),
output_shapes=([None, num_notes], []))
if repeat: dataset = dataset.repeat()
if shuffle: dataset = dataset.shuffle(num_examples)
# Batch sequences together, padding them to a common length in time.
dataset = dataset.padded_batch(batch_size,
padded_shapes=([None, num_notes], []))
def process_pianoroll_batch(data, lengths):
"""Create mean-centered and time-major next-step prediction Tensors."""
data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
lengths = tf.to_int32(lengths)
targets = data
# Mean center the inputs.
inputs = data - tf.constant(mean, dtype=tf.float32,
shape=[1, 1, mean.shape[0]])
# Shift the inputs one step forward in time. Also remove the last timestep
# so that targets and inputs are the same length.
inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
# Mask out unused timesteps.
inputs *= tf.expand_dims(tf.transpose(
tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
return inputs, targets, lengths
dataset = dataset.map(process_pianoroll_batch,
num_parallel_calls=num_parallel_calls)
dataset = dataset.prefetch(num_examples)
itr = dataset.make_one_shot_iterator()
inputs, targets, lengths = itr.get_next()
return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
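# A typical call looks like the following (the path is hypothetical):
#
#   inputs, targets, lens, mean = create_pianoroll_dataset(
#       "/tmp/piano-midi.de.pkl", split="train", batch_size=4,
#       shuffle=True, repeat=True)
#
# inputs and targets are [time, 4, 88] float32 Tensors; each sess.run on them
# yields one padded batch.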
def create_speech_dataset(path,
batch_size,
samples_per_timestep=200,
num_parallel_calls=DEFAULT_PARALLELISM,
prefetch_buffer_size=2048,
shuffle=False,
repeat=False):
"""Creates a speech dataset.
Args:
path: The path of a possibly sharded TFRecord file containing the data.
batch_size: The batch size. If repeat is False then it is not guaranteed
that the true batch size will match for all batches since batch_size
may not necessarily evenly divide the number of elements.
samples_per_timestep: The number of audio samples per timestep. Used to
reshape the data into sequences of shape [time, samples_per_timestep].
Should not be changed except for testing -- all speech datasets use 200
samples per timestep.
num_parallel_calls: The number of threads to use for parallel processing of
the data.
prefetch_buffer_size: The size of the prefetch queues to use after reading
and processing the raw data.
shuffle: If true, shuffles the order of the dataset.
repeat: If true, repeats the dataset endlessly.
Returns:
inputs: A batch of input sequences represented as a dense Tensor of shape
[time, batch_size, samples_per_timestep]. The sequences in inputs are the
sequences in targets shifted one timestep into the future, padded with
zeros.
targets: A batch of target sequences represented as a dense Tensor of
shape [time, batch_size, samples_per_timestep].
lens: An int Tensor of shape [batch_size] representing the lengths of each
sequence in the batch.
"""
filenames = [path]
def read_speech_example(value):
"""Parses a single tf.Example from the TFRecord file."""
decoded = tf.decode_raw(value, out_type=tf.float32)
example = tf.reshape(decoded, [-1, samples_per_timestep])
length = tf.shape(example)[0]
return example, length
# Create the dataset from the TFRecord files
dataset = tf.data.TFRecordDataset(filenames).map(
read_speech_example, num_parallel_calls=num_parallel_calls)
dataset = dataset.prefetch(prefetch_buffer_size)
if repeat: dataset = dataset.repeat()
if shuffle: dataset = dataset.shuffle(prefetch_buffer_size)
dataset = dataset.padded_batch(
batch_size, padded_shapes=([None, samples_per_timestep], []))
def process_speech_batch(data, lengths):
"""Creates Tensors for next step prediction."""
data = tf.transpose(data, perm=[1, 0, 2])
lengths = tf.to_int32(lengths)
targets = data
# Shift the inputs one step forward in time. Also remove the last timestep
# so that targets and inputs are the same length.
inputs = tf.pad(data, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
# Mask out unused timesteps.
inputs *= tf.expand_dims(
tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
return inputs, targets, lengths
dataset = dataset.map(process_speech_batch,
num_parallel_calls=num_parallel_calls)
dataset = dataset.prefetch(prefetch_buffer_size)
itr = dataset.make_one_shot_iterator()
inputs, targets, lengths = itr.get_next()
return inputs, targets, lengths
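# A typical call looks like the following (the path is hypothetical):
#
#   inputs, targets, lens = create_speech_dataset(
#       "/tmp/timit/train", batch_size=4, shuffle=True, repeat=True)
#
# inputs and targets are [time, 4, 200] float32 Tensors.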
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to run training for sequential latent variable models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import runners
# Shared flags.
tf.app.flags.DEFINE_string("mode", "train",
"The mode of the binary. Must be 'train' or 'test'.")
tf.app.flags.DEFINE_string("model", "vrnn",
"Model choice. Currently only 'vrnn' is supported.")
tf.app.flags.DEFINE_integer("latent_size", 64,
"The size of the latent state of the model.")
tf.app.flags.DEFINE_string("dataset_type", "pianoroll",
"The type of dataset, either 'pianoroll' or 'speech'.")
tf.app.flags.DEFINE_string("dataset_path", "",
"Path to load the dataset from.")
tf.app.flags.DEFINE_integer("data_dimension", None,
"The dimension of each vector in the data sequence. "
"Defaults to 88 for pianoroll datasets and 200 for speech "
"datasets. Should not need to be changed except for "
"testing.")
tf.app.flags.DEFINE_integer("batch_size", 4,
"Batch size.")
tf.app.flags.DEFINE_integer("num_samples", 4,
"The number of samples (or particles) for multisample "
"algorithms.")
tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi",
"The directory to keep checkpoints and summaries in.")
tf.app.flags.DEFINE_integer("random_seed", None,
"A random seed for seeding the TensorFlow graph.")
# Training flags.
tf.app.flags.DEFINE_string("bound", "fivo",
"The bound to optimize. Can be 'elbo', 'iwae', or 'fivo'.")
tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True,
"If true, normalize the loss by the number of timesteps "
"per sequence.")
tf.app.flags.DEFINE_float("learning_rate", 0.0002,
"The learning rate for ADAM.")
tf.app.flags.DEFINE_integer("max_steps", int(1e9),
"The number of gradient update steps to train for.")
tf.app.flags.DEFINE_integer("summarize_every", 50,
"The number of steps between summaries.")
# Distributed training flags.
tf.app.flags.DEFINE_string("master", "",
"The BNS name of the TensorFlow master to use.")
tf.app.flags.DEFINE_integer("task", 0,
"Task id of the replica running the training.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
"Number of tasks in the ps job. If 0 no ps job is used.")
tf.app.flags.DEFINE_boolean("stagger_workers", True,
"If true, bring one worker online every 1000 steps.")
# Evaluation flags.
tf.app.flags.DEFINE_string("split", "train",
"Split to evaluate the model on. Can be 'train', 'valid', or 'test'.")
FLAGS = tf.app.flags.FLAGS
PIANOROLL_DEFAULT_DATA_DIMENSION = 88
SPEECH_DEFAULT_DATA_DIMENSION = 200
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.data_dimension is None:
if FLAGS.dataset_type == "pianoroll":
FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
elif FLAGS.dataset_type == "speech":
FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
if FLAGS.mode == "train":
runners.run_train(FLAGS)
elif FLAGS.mode == "eval":
runners.run_eval(FLAGS)
if __name__ == "__main__":
tf.app.run()
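# Example invocation (paths are hypothetical):
#
#   python fivo.py --mode=train --model=vrnn --bound=fivo \
#     --dataset_type=pianoroll --dataset_path=/tmp/piano-midi.de.pkl \
#     --batch_size=4 --num_samples=4 --logdir=/tmp/fivo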
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""VRNN classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sonnet as snt
import tensorflow as tf
class VRNNCell(snt.AbstractModule):
"""Implementation of a Variational Recurrent Neural Network (VRNN).
Introduced in "A Recurrent Latent Variable Model for Sequential Data"
by Chung et al. https://arxiv.org/pdf/1506.02216.pdf.
The VRNN is a sequence model similar to an RNN that uses stochastic latent
variables to improve its representational power. It can be thought of as a
sequential analogue to the variational auto-encoder (VAE).
The VRNN has a deterministic RNN as its backbone, represented by the
sequence of RNN hidden states h_t. At each timestep, the RNN hidden state h_t
is conditioned on the previous sequence element, x_{t-1}, as well as the
latent state from the previous timestep, z_{t-1}.
In this implementation of the VRNN the latent state z_t is Gaussian. The
model's prior over z_t is distributed as Normal(mu_t, diag(sigma_t^2)) where
mu_t and sigma_t are the mean and standard deviation output from a fully
connected network that accepts the RNN hidden state h_t as input.
The approximate posterior (also known as q or the encoder in the VAE
framework) is similar to the prior except that it is conditioned on the
current target, x_t, as well as h_t via a fully connected network.
This implementation uses the 'res_q' parameterization of the approximate
posterior, meaning that instead of directly predicting the mean of z_t, the
approximate posterior predicts the 'residual' from the prior's mean. This is
explored more in section 3.3 of https://arxiv.org/pdf/1605.07571.pdf.
During training, the latent state z_t is sampled from the approximate
posterior and the reparameterization trick is used to provide low-variance
gradients.
The generative distribution p(x_t|z_t, h_t) is conditioned on the latent state
z_t as well as the current RNN hidden state h_t via a fully connected network.
To increase the modeling power of the VRNN, two additional networks are
used to extract features from the data and the latent state. Those networks
are called data_feat_extractor and latent_feat_extractor respectively.
There are a few differences between this exposition and the paper.
First, the indexing scheme for h_t is different from the paper's -- what the
paper calls h_t we call h_{t+1}. This is the same notation used by Fraccaro
et al. to describe the VRNN in the paper linked above. Also, the VRNN paper
uses VAE terminology to refer to the different internal networks, so it
refers to the approximate posterior as the encoder and the generative
distribution as the decoder. This implementation also renames the functions
phi_x and phi_z in the paper to data_feat_extractor and latent_feat_extractor.
"""
def __init__(self,
rnn_cell,
data_feat_extractor,
latent_feat_extractor,
prior,
approx_posterior,
generative,
random_seed=None,
name="vrnn"):
"""Creates a VRNN cell.
Args:
rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
deterministic backbone of the VRNN. The inputs to the RNN will be the
encoded latent state of the previous timestep with shape
[batch_size, encoded_latent_size] as well as the encoded input of the
current timestep, a Tensor of shape [batch_size, encoded_data_size].
data_feat_extractor: A callable that accepts a batch of data x_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument the inputs x_t, a Tensor of the shape
[batch_size, data_size] and return a Tensor of shape
[batch_size, encoded_data_size]. This callable will be called multiple
times in the VRNN cell so if scoping is not handled correctly then
multiple copies of the variables in this network could be made. It is
recommended to use a snt.nets.MLP module, which takes care of this for
you.
latent_feat_extractor: A callable that accepts a latent state z_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument a Tensor of shape [batch_size, latent_size] and
return a Tensor of shape [batch_size, encoded_latent_size].
This callable must also have the property 'output_size' defined,
returning encoded_latent_size.
prior: A callable that implements the prior p(z_t|h_t). Must accept as
argument the previous RNN hidden state and return a
tf.contrib.distributions.Normal distribution conditioned on the input.
approx_posterior: A callable that implements the approximate posterior
q(z_t|h_t,x_t). Must accept as arguments the encoded target of the
current timestep and the previous RNN hidden state. Must return
a tf.contrib.distributions.Normal distribution conditioned on the
inputs.
generative: A callable that implements the generative distribution
p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
and the RNN hidden state and return a subclass of
tf.contrib.distributions.Distribution that can be used to evaluate
the logprob of the targets.
random_seed: The seed for the random ops. Used mainly for testing.
name: The name of this VRNN.
"""
super(VRNNCell, self).__init__(name=name)
self.rnn_cell = rnn_cell
self.data_feat_extractor = data_feat_extractor
self.latent_feat_extractor = latent_feat_extractor
self.prior = prior
self.approx_posterior = approx_posterior
self.generative = generative
self.random_seed = random_seed
self.encoded_z_size = latent_feat_extractor.output_size
self.state_size = (self.rnn_cell.state_size, self.encoded_z_size)
def zero_state(self, batch_size, dtype):
"""The initial state of the VRNN.
Contains the initial state of the RNN as well as a vector of zeros
corresponding to z_0.
Args:
batch_size: The batch size.
dtype: The data type of the VRNN.
Returns:
zero_state: The initial state of the VRNN.
"""
return (self.rnn_cell.zero_state(batch_size, dtype),
tf.zeros([batch_size, self.encoded_z_size], dtype=dtype))
def _build(self, observations, state, mask):
"""Computes one timestep of the VRNN.
Args:
observations: The observations at the current timestep, a tuple
containing the model inputs and targets as Tensors of shape
[batch_size, data_size].
state: The current state of the VRNN.
mask: Tensor of shape [batch_size], 1.0 if the current timestep is
active, 0.0 otherwise.
Returns:
log_q_z: The logprob of the latent state according to the approximate
posterior.
log_p_z: The logprob of the latent state according to the prior.
log_p_x_given_z: The conditional log-likelihood, i.e. logprob of the
observation according to the generative distribution.
kl: The analytic KL divergence from q(z) to p(z).
state: The new state of the VRNN.
"""
inputs, targets = observations
rnn_state, prev_latent_encoded = state
# Encode the data.
inputs_encoded = self.data_feat_extractor(inputs)
targets_encoded = self.data_feat_extractor(targets)
# Run the RNN cell.
rnn_inputs = tf.concat([inputs_encoded, prev_latent_encoded], axis=1)
rnn_out, new_rnn_state = self.rnn_cell(rnn_inputs, rnn_state)
# Create the prior and approximate posterior distributions.
latent_dist_prior = self.prior(rnn_out)
latent_dist_q = self.approx_posterior(rnn_out, targets_encoded,
prior_mu=latent_dist_prior.loc)
# Sample the new latent state z and encode it.
latent_state = latent_dist_q.sample(seed=self.random_seed)
latent_encoded = self.latent_feat_extractor(latent_state)
# Calculate probabilities of the latent state according to the prior p
# and approximate posterior q.
log_q_z = tf.reduce_sum(latent_dist_q.log_prob(latent_state), axis=-1)
log_p_z = tf.reduce_sum(latent_dist_prior.log_prob(latent_state), axis=-1)
analytic_kl = tf.reduce_sum(
tf.contrib.distributions.kl_divergence(
latent_dist_q, latent_dist_prior),
axis=-1)
# Create the generative dist. and calculate the logprob of the targets.
generative_dist = self.generative(latent_encoded, rnn_out)
log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(targets), axis=-1)
return (log_q_z, log_p_z, log_p_x_given_z, analytic_kl,
(new_rnn_state, latent_encoded))
_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
"b": tf.zeros_initializer()}
def create_vrnn(
data_size,
latent_size,
generative_class,
rnn_hidden_size=None,
fcnet_hidden_sizes=None,
encoded_data_size=None,
encoded_latent_size=None,
sigma_min=0.0,
raw_sigma_bias=0.25,
generative_bias_init=0.0,
initializers=None,
random_seed=None):
"""A factory method for creating VRNN cells.
Args:
data_size: The dimension of the vectors that make up the data sequences.
latent_size: The size of the stochastic latent state of the VRNN.
generative_class: The class of the generative distribution. Can be either
ConditionalNormalDistribution or ConditionalBernoulliDistribution.
rnn_hidden_size: The hidden state dimension of the RNN that forms the
deterministic part of this VRNN. If None, then it defaults
to latent_size.
fcnet_hidden_sizes: A list of python integers, the size of the hidden
layers of the fully connected networks that parameterize the conditional
distributions of the VRNN. If None, then it defaults to one hidden
layer of size latent_size.
encoded_data_size: The size of the output of the data encoding network. If
None, defaults to latent_size.
encoded_latent_size: The size of the output of the latent state encoding
network. If None, defaults to latent_size.
sigma_min: The minimum value that the standard deviation of the
distribution over the latent state can take.
raw_sigma_bias: A scalar that is added to the raw standard deviation
output from the neural networks that parameterize the prior and
approximate posterior. Useful for preventing standard deviations close
to zero.
generative_bias_init: A bias added to the raw output of the fully
connected network that parameterizes the generative distribution. Useful
for initializing the mean of the distribution to a sensible starting point
such as the mean of the training data. Only used with Bernoulli generative
distributions.
initializers: The variable initializers to use for the fully connected
networks and RNN cell. Must be a dictionary mapping the keys 'w' and 'b'
to the initializers for the weights and biases. Defaults to xavier for
the weights and zeros for the biases when initializers is None.
random_seed: A random seed for the VRNN resampling operations.
Returns:
model: A VRNNCell object.
"""
if rnn_hidden_size is None:
rnn_hidden_size = latent_size
if fcnet_hidden_sizes is None:
fcnet_hidden_sizes = [latent_size]
if encoded_data_size is None:
encoded_data_size = latent_size
if encoded_latent_size is None:
encoded_latent_size = latent_size
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
data_feat_extractor = snt.nets.MLP(
output_sizes=fcnet_hidden_sizes + [encoded_data_size],
initializers=initializers,
name="data_feat_extractor")
latent_feat_extractor = snt.nets.MLP(
output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
initializers=initializers,
name="latent_feat_extractor")
prior = ConditionalNormalDistribution(
size=latent_size,
hidden_layer_sizes=fcnet_hidden_sizes,
sigma_min=sigma_min,
raw_sigma_bias=raw_sigma_bias,
initializers=initializers,
name="prior")
approx_posterior = NormalApproximatePosterior(
size=latent_size,
hidden_layer_sizes=fcnet_hidden_sizes,
sigma_min=sigma_min,
raw_sigma_bias=raw_sigma_bias,
initializers=initializers,
name="approximate_posterior")
if generative_class == ConditionalBernoulliDistribution:
generative = ConditionalBernoulliDistribution(
size=data_size,
hidden_layer_sizes=fcnet_hidden_sizes,
initializers=initializers,
bias_init=generative_bias_init,
name="generative")
else:
generative = ConditionalNormalDistribution(
size=data_size,
hidden_layer_sizes=fcnet_hidden_sizes,
initializers=initializers,
name="generative")
rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
initializer=initializers["w"])
return VRNNCell(rnn_cell, data_feat_extractor, latent_feat_extractor,
prior, approx_posterior, generative, random_seed=random_seed)
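# A minimal sketch of building and stepping the cell (sizes are arbitrary):
#
#   cell = create_vrnn(data_size=88, latent_size=16,
#                      generative_class=ConditionalBernoulliDistribution)
#   inputs = tf.zeros([8, 88])    # one timestep for a batch of 8
#   targets = tf.zeros([8, 88])
#   mask = tf.ones([8])
#   state = cell.zero_state(batch_size=8, dtype=tf.float32)
#   log_q_z, log_p_z, log_p_x, kl, state = cell((inputs, targets), state, mask)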
class ConditionalNormalDistribution(object):
"""A Normal distribution conditioned on Tensor inputs via a fc network."""
def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
initializers=None, name="conditional_normal_distribution"):
"""Creates a conditional Normal distribution.
Args:
size: The dimension of the random variable.
hidden_layer_sizes: The sizes of the hidden layers of the fully connected
network used to condition the distribution on the inputs.
sigma_min: The minimum standard deviation allowed, a scalar.
raw_sigma_bias: A scalar that is added to the raw standard deviation
output from the fully connected network. Set to 0.25 by default to
prevent standard deviations close to 0.
hidden_activation_fn: The activation function to use on the hidden layers
of the fully connected network.
initializers: The variable initializers to use for the fully connected
network. The network is implemented using snt.nets.MLP so it must
be a dictionary mapping the keys 'w' and 'b' to the initializers for
the weights and biases. Defaults to xavier for the weights and zeros
for the biases when initializers is None.
name: The name of this distribution, used for sonnet scoping.
"""
self.sigma_min = sigma_min
self.raw_sigma_bias = raw_sigma_bias
self.name = name
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
self.fcnet = snt.nets.MLP(
output_sizes=hidden_layer_sizes + [2*size],
activation=hidden_activation_fn,
initializers=initializers,
activate_final=False,
use_bias=True,
name=name + "_fcnet")
def condition(self, tensor_list, **unused_kwargs):
"""Computes the parameters of a normal distribution based on the inputs."""
inputs = tf.concat(tensor_list, axis=1)
outs = self.fcnet(inputs)
mu, sigma = tf.split(outs, 2, axis=1)
sigma = tf.maximum(tf.nn.softplus(sigma + self.raw_sigma_bias),
self.sigma_min)
return mu, sigma
def __call__(self, *args, **kwargs):
"""Creates a normal distribution conditioned on the inputs."""
mu, sigma = self.condition(args, **kwargs)
return tf.contrib.distributions.Normal(loc=mu, scale=sigma)
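# Usage sketch (shapes are illustrative):
#
#   dist_fn = ConditionalNormalDistribution(size=16, hidden_layer_sizes=[32])
#   h = tf.zeros([8, 64])
#   dist = dist_fn(h)   # Normal with loc and scale of shape [8, 16]
#   z = dist.sample()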
class ConditionalBernoulliDistribution(object):
"""A Bernoulli distribution conditioned on Tensor inputs via a fc net."""
def __init__(self, size, hidden_layer_sizes, hidden_activation_fn=tf.nn.relu,
initializers=None, bias_init=0.0,
name="conditional_bernoulli_distribution"):
"""Creates a conditional Bernoulli distribution.
Args:
size: The dimension of the random variable.
hidden_layer_sizes: The sizes of the hidden layers of the fully connected
network used to condition the distribution on the inputs.
hidden_activation_fn: The activation function to use on the hidden layers
of the fully connected network.
initializers: The variable initializers to use for the fully connected
network. The network is implemented using snt.nets.MLP so it must
be a dictionary mapping the keys 'w' and 'b' to the initializers for
the weights and biases. Defaults to xavier for the weights and zeros
for the biases when initializers is None.
bias_init: A scalar or vector Tensor that is added to the output of the
fully-connected network that parameterizes the mean of this
distribution.
name: The name of this distribution, used for sonnet scoping.
"""
self.bias_init = bias_init
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
self.fcnet = snt.nets.MLP(
output_sizes=hidden_layer_sizes + [size],
activation=hidden_activation_fn,
initializers=initializers,
activate_final=False,
use_bias=True,
name=name + "_fcnet")
def condition(self, tensor_list):
"""Computes the p parameter of the Bernoulli distribution."""
inputs = tf.concat(tensor_list, axis=1)
return self.fcnet(inputs) + self.bias_init
def __call__(self, *args):
p = self.condition(args)
return tf.contrib.distributions.Bernoulli(logits=p)
class NormalApproximatePosterior(ConditionalNormalDistribution):
"""A Normally-distributed approx. posterior with res_q parameterization."""
def condition(self, tensor_list, prior_mu):
"""Generates the mean and variance of the normal distribution.
Args:
tensor_list: The list of Tensors to condition on. Will be concatenated and
fed through a fully connected network.
prior_mu: The mean of the prior distribution associated with this
approximate posterior. Will be added to the mean produced by
this approximate posterior, in res_q fashion.
Returns:
mu: The mean of the approximate posterior.
sigma: The standard deviation of the approximate posterior.
"""
mu, sigma = super(NormalApproximatePosterior, self).condition(tensor_list)
return mu + prior_mu, sigma
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A set of utils for dealing with nested lists and tuples of Tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.util import nest
def map_nested(map_fn, nested):
"""Executes map_fn on every element in a (potentially) nested structure.
Args:
map_fn: A callable to execute on each element in 'nested'.
nested: A potentially nested combination of sequence objects. Sequence
objects include tuples, lists, namedtuples, and all subclasses of
collections.Sequence except strings. See nest.is_sequence for details.
For example [1, ('hello', 4.3)] is a nested structure containing elements
1, 'hello', and 4.3.
Returns:
out_structure: A potentially nested combination of sequence objects with the
same structure as the 'nested' input argument. out_structure
contains the result of applying map_fn to each element in 'nested'. For
example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
"""
out = map(map_fn, nest.flatten(nested))
return nest.pack_sequence_as(nested, out)
def tile_tensors(tensors, multiples):
"""Tiles a set of Tensors.
Args:
tensors: A potentially nested tuple or list of Tensors with rank
greater than or equal to the length of 'multiples'. The Tensors do not
need to have the same rank, but their rank must not be dynamic.
multiples: A python list of ints indicating how to tile each Tensor
in 'tensors'. Similar to the 'multiples' argument to tf.tile.
Returns:
tiled_tensors: A potentially nested tuple or list of Tensors with the same
structure as the 'tensors' input argument. Contains the result of
applying tf.tile to each Tensor in 'tensors'. When the rank of a Tensor
in 'tensors' is greater than the length of multiples, multiples is padded
at the end with 1s. For example when tiling a 4-dimensional Tensor with
multiples [3, 4], multiples would be padded to [3, 4, 1, 1] before tiling.
"""
def tile_fn(x):
return tf.tile(x, multiples + [1]*(x.shape.ndims - len(multiples)))
return map_nested(tile_fn, tensors)
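# For example, tiling a [2, 3] Tensor with multiples [4] pads multiples to
# [4, 1] and yields an [8, 3] result:
#
#   x = tf.ones([2, 3])
#   tiled = tile_tensors([x], [4])   # tiled[0].shape == [8, 3]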
def gather_tensors(tensors, indices):
"""Performs a tf.gather operation on a set of Tensors.
Args:
tensors: A potentially nested tuple or list of Tensors.
indices: The indices to use for the gather operation.
Returns:
gathered_tensors: A potentially nested tuple or list of Tensors with the
same structure as the 'tensors' input argument. Contains the result of
applying tf.gather(x, indices) on each element x in 'tensors'.
"""
return map_nested(lambda x: tf.gather(x, indices), tensors)
def tas_for_tensors(tensors, length):
"""Unstacks a set of Tensors into TensorArrays.
Args:
tensors: A potentially nested tuple or list of Tensors with length in the
first dimension greater than or equal to the 'length' input argument.
length: The desired length of the TensorArrays.
Returns:
tensorarrays: A potentially nested tuple or list of TensorArrays with the
same structure as 'tensors'. Contains the result of unstacking each Tensor
in 'tensors'.
"""
def map_fn(x):
ta = tf.TensorArray(x.dtype, length, name=x.name.split(':')[0] + '_ta')
return ta.unstack(x[:length, :])
return map_nested(map_fn, tensors)
def read_tas(tas, index):
"""Performs a read operation on a set of TensorArrays.
Args:
tas: A potentially nested tuple or list of TensorArrays with length greater
than 'index'.
index: The location to read from.
Returns:
read_tensors: A potentially nested tuple or list of Tensors with the same
structure as the 'tas' input argument. Contains the result of
performing a read operation at 'index' on each TensorArray in 'tas'.
"""
return map_nested(lambda ta: ta.read(index), tas)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""High-level code for creating and running FIVO-related Tensorflow graphs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import tensorflow as tf
import bounds
from data import datasets
from models import vrnn
def create_dataset_and_model(config, split, shuffle, repeat):
"""Creates the dataset and model for a given config.
Args:
config: A configuration object with config values accessible as properties.
Most likely a FLAGS object. This function expects the properties
batch_size, dataset_path, dataset_type, and latent_size to be defined.
split: The dataset split to load.
shuffle: If true, shuffle the dataset randomly.
repeat: If true, repeat the dataset endlessly.
Returns:
inputs: A batch of input sequences represented as a dense Tensor of shape
[time, batch_size, data_dimension].
targets: A batch of target sequences represented as a dense Tensor of
shape [time, batch_size, data_dimension].
lens: An int Tensor of shape [batch_size] representing the lengths of each
sequence in the batch.
model: A vrnn.VRNNCell model object.
"""
if config.dataset_type == "pianoroll":
inputs, targets, lengths, mean = datasets.create_pianoroll_dataset(
config.dataset_path, split, config.batch_size, shuffle=shuffle,
repeat=repeat)
# Convert the mean of the training set to logit space so it can be used to
# initialize the bias of the generative distribution.
generative_bias_init = -tf.log(
1. / tf.clip_by_value(mean, 0.0001, 0.9999) - 1)
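# Equivalently, generative_bias_init = logit(mean): with the network's raw
# output roughly zero at initialization, sigmoid(0 + logit(mean)) = mean, so
# the Bernoulli generative distribution starts out predicting the train mean.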
generative_distribution_class = vrnn.ConditionalBernoulliDistribution
elif config.dataset_type == "speech":
inputs, targets, lengths = datasets.create_speech_dataset(
config.dataset_path, config.batch_size,
samples_per_timestep=config.data_dimension, prefetch_buffer_size=1,
shuffle=shuffle, repeat=repeat)
generative_bias_init = None
generative_distribution_class = vrnn.ConditionalNormalDistribution
model = vrnn.create_vrnn(inputs.get_shape().as_list()[2],
config.latent_size,
generative_distribution_class,
generative_bias_init=generative_bias_init,
raw_sigma_bias=0.5)
return inputs, targets, lengths, model
def restore_checkpoint_if_exists(saver, sess, logdir):
"""Looks for a checkpoint and restores the session from it if found.
Args:
saver: A tf.train.Saver for restoring the session.
sess: A TensorFlow session.
logdir: The directory to look for checkpoints in.
Returns:
True if a checkpoint was found and restored, False otherwise.
"""
checkpoint = tf.train.get_checkpoint_state(logdir)
if checkpoint:
checkpoint_name = os.path.basename(checkpoint.model_checkpoint_path)
full_checkpoint_path = os.path.join(logdir, checkpoint_name)
saver.restore(sess, full_checkpoint_path)
return True
return False
def wait_for_checkpoint(saver, sess, logdir):
"""Loops until the session is restored from a checkpoint in logdir.
Args:
saver: A tf.train.Saver for restoring the session.
sess: A TensorFlow session.
logdir: The directory to look for checkpoints in.
"""
while True:
if restore_checkpoint_if_exists(saver, sess, logdir):
break
else:
tf.logging.info("Checkpoint not found in %s, sleeping for 60 seconds."
% logdir)
time.sleep(60)
def run_train(config):
"""Runs training for a sequential latent variable model.
Args:
config: A configuration object with config values accessible as properties.
Most likely a FLAGS object. For a list of expected properties and their
meaning see the flags defined in fivo.py.
"""
def create_logging_hook(step, bound_value):
"""Creates a logging hook that prints the bound value periodically."""
bound_label = config.bound + " bound"
if config.normalize_by_seq_len:
bound_label += " per timestep"
else:
bound_label += " per sequence"
def summary_formatter(log_dict):
return "Step %d, %s: %f" % (
log_dict["step"], bound_label, log_dict["bound_value"])
logging_hook = tf.train.LoggingTensorHook(
{"step": step, "bound_value": bound_value},
every_n_iter=config.summarize_every,
formatter=summary_formatter)
return logging_hook
def create_loss():
"""Creates the loss to be optimized.
Returns:
bound: A float Tensor containing the value of the bound that is
being optimized.
loss: A float Tensor that when differentiated yields the gradients
to apply to the model. Should be optimized via gradient descent.
"""
inputs, targets, lengths, model = create_dataset_and_model(
config, split="train", shuffle=True, repeat=True)
# Compute lower bounds on the log likelihood.
if config.bound == "elbo":
ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=1)
elif config.bound == "iwae":
ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=config.num_samples)
elif config.bound == "fivo":
ll_per_seq, _, _, _, _ = bounds.fivo(
model, (inputs, targets), lengths, num_samples=config.num_samples,
resampling_criterion=bounds.ess_criterion)
# Compute loss scaled by number of timesteps.
ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
ll_per_seq = tf.reduce_mean(ll_per_seq)
tf.summary.scalar("train_ll_per_seq", ll_per_seq)
tf.summary.scalar("train_ll_per_t", ll_per_t)
if config.normalize_by_seq_len:
return ll_per_t, -ll_per_t
else:
return ll_per_seq, -ll_per_seq
def create_graph():
"""Creates the training graph."""
global_step = tf.train.get_or_create_global_step()
bound, loss = create_loss()
opt = tf.train.AdamOptimizer(config.learning_rate)
grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
train_op = opt.apply_gradients(grads, global_step=global_step)
return bound, train_op, global_step
device = tf.train.replica_device_setter(ps_tasks=config.ps_tasks)
with tf.Graph().as_default():
if config.random_seed: tf.set_random_seed(config.random_seed)
with tf.device(device):
bound, train_op, global_step = create_graph()
log_hook = create_logging_hook(global_step, bound)
start_training = not config.stagger_workers
with tf.train.MonitoredTrainingSession(
master=config.master,
is_chief=config.task == 0,
hooks=[log_hook],
checkpoint_dir=config.logdir,
save_checkpoint_secs=120,
save_summaries_steps=config.summarize_every,
log_step_count_steps=config.summarize_every) as sess:
cur_step = -1
while True:
if sess.should_stop() or cur_step > config.max_steps: break
if config.task > 0 and not start_training:
cur_step = sess.run(global_step)
tf.logging.info("task %d not active yet, sleeping at step %d" %
(config.task, cur_step))
time.sleep(30)
if cur_step >= config.task * 1000:
start_training = True
else:
_, cur_step = sess.run([train_op, global_step])
def run_eval(config):
"""Runs evaluation for a sequential latent variable model.
This method runs only one evaluation over the dataset, writes summaries to
disk, and then terminates. It does not loop indefinitely.
Args:
config: A configuration object with config values accessible as properties.
Most likely a FLAGS object. For a list of expected properties and their
meaning see the flags defined in fivo.py.
"""
def create_graph():
"""Creates the evaluation graph.
Returns:
lower_bounds: A tuple of float Tensors containing the values of the 3
evidence lower bounds, summed across the batch.
total_batch_length: The total number of timesteps in the batch, summed
across batch examples.
batch_size: The batch size.
global_step: The global step the checkpoint was loaded from.
"""
global_step = tf.train.get_or_create_global_step()
inputs, targets, lengths, model = create_dataset_and_model(
config, split=config.split, shuffle=False, repeat=False)
# Compute lower bounds on the log likelihood.
elbo_ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=1)
iwae_ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=config.num_samples)
fivo_ll_per_seq, _, _, _, _ = bounds.fivo(
model, (inputs, targets), lengths, num_samples=config.num_samples,
resampling_criterion=bounds.ess_criterion)
elbo_ll = tf.reduce_sum(elbo_ll_per_seq)
iwae_ll = tf.reduce_sum(iwae_ll_per_seq)
fivo_ll = tf.reduce_sum(fivo_ll_per_seq)
batch_size = tf.shape(lengths)[0]
total_batch_length = tf.reduce_sum(lengths)
return ((elbo_ll, iwae_ll, fivo_ll), total_batch_length, batch_size,
global_step)
def average_bounds_over_dataset(lower_bounds, total_batch_length, batch_size,
sess):
"""Computes the values of the bounds, averaged over the datset.
Args:
lower_bounds: Tuple of float Tensors containing the values of the bounds
evaluated on a single batch.
total_batch_length: Integer Tensor that represents the total number of
timesteps in the current batch.
batch_size: Integer Tensor containing the batch size. This can vary if the
requested batch_size does not evenly divide the size of the dataset.
sess: A TensorFlow Session object.
Returns:
ll_per_t: A length 3 numpy array of floats containing each bound's average
value, normalized by the total number of timesteps in the dataset. Can
be interpreted as a lower bound on the average log likelihood per
timestep in the dataset.
ll_per_seq: A length 3 numpy array of floats containing each bound's
average value, normalized by the number of sequences in the dataset.
Can be interpreted as a lower bound on the average log likelihood per
sequence in the dataset.
"""
total_ll = np.zeros(3, dtype=np.float64)
total_n_elems = 0.0
total_length = 0.0
while True:
try:
outs = sess.run([lower_bounds, batch_size, total_batch_length])
except tf.errors.OutOfRangeError:
break
total_ll += outs[0]
total_n_elems += outs[1]
total_length += outs[2]
ll_per_t = total_ll / total_length
ll_per_seq = total_ll / total_n_elems
return ll_per_t, ll_per_seq
def summarize_lls(lls_per_t, lls_per_seq, summary_writer, step):
"""Creates log-likelihood lower bound summaries and writes them to disk.
Args:
lls_per_t: An array of 3 python floats, contains the values of the
evaluated bounds normalized by the number of timesteps.
lls_per_seq: An array of 3 python floats, contains the values of the
evaluated bounds normalized by the number of sequences.
summary_writer: A tf.summary.FileWriter.
step: The current global step.
"""
def scalar_summary(name, value):
value = tf.Summary.Value(tag=name, simple_value=value)
return tf.Summary(value=[value])
for i, bound in enumerate(["elbo", "iwae", "fivo"]):
per_t_summary = scalar_summary("%s/%s_ll_per_t" % (config.split, bound),
lls_per_t[i])
per_seq_summary = scalar_summary("%s/%s_ll_per_seq" %
(config.split, bound),
lls_per_seq[i])
summary_writer.add_summary(per_t_summary, global_step=step)
summary_writer.add_summary(per_seq_summary, global_step=step)
summary_writer.flush()
with tf.Graph().as_default():
if config.random_seed: tf.set_random_seed(config.random_seed)
lower_bounds, total_batch_length, batch_size, global_step = create_graph()
summary_dir = config.logdir + "/" + config.split
summary_writer = tf.summary.FileWriter(
summary_dir, flush_secs=15, max_queue=100)
saver = tf.train.Saver()
with tf.train.SingularMonitoredSession() as sess:
wait_for_checkpoint(saver, sess, config.logdir)
step = sess.run(global_step)
tf.logging.info("Model restored from step %d, evaluating." % step)
ll_per_t, ll_per_seq = average_bounds_over_dataset(
lower_bounds, total_batch_length, batch_size, sess)
summarize_lls(ll_per_t, ll_per_seq, summary_writer, step)
tf.logging.info("%s elbo ll/t: %f, iwae ll/t: %f fivo ll/t: %f",
config.split, ll_per_t[0], ll_per_t[1], ll_per_t[2])
tf.logging.info("%s elbo ll/seq: %f, iwae ll/seq: %f fivo ll/seq: %f",
config.split, ll_per_seq[0], ll_per_seq[1], ll_per_seq[2])
# TFGAN Examples
[TFGAN](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan) is a lightweight library for training and evaluating Generative
Adversarial Networks (GANs). GANs have been used in a wide range of tasks
including [image translation](https://arxiv.org/abs/1703.10593), [superresolution](https://arxiv.org/abs/1609.04802), and [data augmentation](https://arxiv.org/abs/1612.07828). This directory contains fully-working examples
that demonstrate the ease and flexibility of TFGAN. Each subdirectory contains a
different working example. The sub-sections below describe each of the problems,
and include some sample outputs. We've also included a [jupyter notebook](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb), which
provides a walkthrough of TFGAN.
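As a quick orientation, a typical TFGAN training setup looks like the sketch
below. This is illustrative only: `generator_fn`, `discriminator_fn`,
`real_images`, `batch_size`, and `noise_dims` are placeholders you would
define yourself (for instance with the networks linked in each example).

```python
import tensorflow as tf
tfgan = tf.contrib.gan

# Wire the generator and discriminator into a GANModel tuple.
gan_model = tfgan.gan_model(
    generator_fn,
    discriminator_fn,
    real_data=real_images,
    generator_inputs=tf.random_normal([batch_size, noise_dims]))

# Build losses and per-network train ops, then run the alternating loop.
gan_loss = tfgan.gan_loss(gan_model)
train_ops = tfgan.gan_train_ops(
    gan_model, gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-3, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))
tfgan.gan_train(train_ops, logdir='/tmp/tfgan_example')
```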
## Contacts
Maintainers of TFGAN:
* Joel Shor,
github: [joel-shor](https://github.com/joel-shor)
## Table of contents
1. [MNIST](#mnist)
1. [MNIST with GANEstimator](#mnist_estimator)
1. [CIFAR10](#cifar10)
1. [Image compression](#compression)
## MNIST
<a id='mnist'></a>
We train a simple generator to produce [MNIST digits](http://yann.lecun.com/exdb/mnist/).
The unconditional case maps noise to MNIST digits. The conditional case maps
noise and digit class to MNIST digits. [InfoGAN](https://arxiv.org/abs/1606.03657) learns to produce
digits of a given class without labels, as well as controlling style. The
network architectures are defined [here](https://github.com/tensorflow/models/tree/master/research/gan/mnist/networks.py).
We use a classifier trained on MNIST digit classification for evaluation.
### Unconditional MNIST
<img src="g3doc/mnist_unconditional_gan.png" title="Unconditional GAN" width="330" />
### Conditional MNIST
<img src="g3doc/mnist_conditional_gan.png" title="Conditional GAN" width="330" />
### InfoGAN MNIST
<img src="g3doc/mnist_infogan.png" title="InfoGAN" width="330" />
## MNIST with GANEstimator
<a id='mnist_estimator'></a>
This setup is exactly the same as in the [unconditional MNIST example](#mnist), but
uses the `tf.Learn` `GANEstimator`.
<img src="g3doc/mnist_estimator_unconditional_gan.png" title="Unconditional GAN" width="330" />
## CIFAR10
<a id='cifar10'></a>
We train a [DCGAN generator](https://arxiv.org/abs/1511.06434) to produce [CIFAR10 images](https://www.cs.toronto.edu/~kriz/cifar.html).
The unconditional case maps noise to CIFAR10 images. The conditional case maps
noise and image class to CIFAR10 images. The
network architectures are defined [here](https://github.com/tensorflow/models/tree/master/research/gan/cifar/networks.py).
We use the [Inception Score](https://arxiv.org/abs/1606.03498) to evaluate the images.
### Unconditional CIFAR10
<img src="g3doc/cifar_unconditional_gan.png" title="Unconditional GAN" width="330" />
### Conditional CIFAR10
<img src="g3doc/cifar_conditional_gan.png" title="Conditional GAN" width="330" />
## Image compression
<a id='compression'></a>
In neural image compression, we attempt to reduce an image to a smaller representation
such that we can recreate the original image as closely as possible. See [`Full Resolution Image Compression with Recurrent Neural Networks`](https://arxiv.org/abs/1608.05148) for more details on using neural networks
for image compression.
In this example, we train an encoder to compress images to a compressed binary
representation and a decoder to map the binary representation back to the image.
We treat both systems together (encoder -> decoder) as the generator.
A typical image compression network trained on an L1 pixel loss will decode into
blurry images. We use an adversarial loss to force the outputs to be more
plausible.
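One way to realize this is sketched below, under the assumption that `gan_model` wraps the encoder/decoder pair as the generator (so `gan_model.generated_data` is the reconstruction) and that `images` and `adversarial_weight` are supplied elsewhere.
```python
import tensorflow as tf

tfgan = tf.contrib.gan

# Standard adversarial loss on the reconstructions.
gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)

# The L1 pixel loss keeps the reconstruction close to the original; the
# adversarial term pushes it toward the natural-image manifold.
pixel_loss = tf.reduce_mean(tf.abs(images - gan_model.generated_data))
generator_loss = pixel_loss + adversarial_weight * gan_loss.generator_loss
```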
This example also highlights the following infrastructure challenge:
* Handling networks where you have custom code to keep track of your variables
Some other notes on the problem:
* Since the network is fully convolutional, we train on image patches.
* The bottleneck layer is floating point during training and binarized during
evaluation.
### Results
#### No adversarial loss
<img src="g3doc/compression_wf0.png" title="No adversarial loss" width="500" />
#### Adversarial loss
<img src="g3doc/compression_wf10000.png" title="With adversarial loss" width="500" />
### Architectures
#### Compression Network
The compression network uses a DCGAN discriminator for the encoder and a DCGAN
generator for the decoder, both from [`Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks`](https://arxiv.org/abs/1511.06434).
The binarizer adds uniform noise during training, then binarizes during eval, as in
[`End-to-end Optimized Image Compression`](https://arxiv.org/abs/1611.01704).
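A minimal sketch of that trick, with assumed names:
```python
import tensorflow as tf

def binarize(codes, is_training):
  """Noise-then-binarize bottleneck (assumed helper, not the example's code).

  Adding uniform noise during training keeps the bottleneck differentiable
  while simulating quantization error; at eval time the codes are hard
  binarized to +/-1.
  """
  if is_training:
    return codes + tf.random_uniform(tf.shape(codes), -0.5, 0.5)
  return tf.sign(codes)
```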
#### Discriminator
The discriminator looks at 70x70 patches, as in
[`Image-to-Image Translation with Conditional Adversarial Networks`](https://arxiv.org/abs/1611.07004).
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the CIFAR data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
slim = tf.contrib.slim
def provide_data(batch_size, dataset_dir, dataset_name='cifar10',
split_name='train', one_hot=True):
"""Provides batches of CIFAR data.
Args:
batch_size: The number of images in each batch.
dataset_dir: The directory where the CIFAR10 data can be found. If `None`,
use default.
dataset_name: Name of the dataset.
split_name: Should be either 'train' or 'test'.
one_hot: Output one hot vector instead of int32 label.
Returns:
images: A `Tensor` of size [batch_size, 32, 32, 3]. Output pixel values are
in [-1, 1].
labels: Either (1) one_hot_labels if `one_hot` is `True`
A `Tensor` of size [batch_size, num_classes], where each row has a
single element set to one and the rest set to zeros.
Or (2) labels if `one_hot` is `False`
A `Tensor` of size [batch_size], holding the labels as integers.
num_samples: The number of total samples in the dataset.
num_classes: The number of classes in the dataset.
Raises:
ValueError: if the split_name is not either 'train' or 'test'.
"""
dataset = datasets.get_dataset(
dataset_name, split_name, dataset_dir=dataset_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
common_queue_capacity=5 * batch_size,
common_queue_min=batch_size,
shuffle=(split_name == 'train'))
[image, label] = provider.get(['image', 'label'])
# Preprocess the images.
image = (tf.to_float(image) - 128.0) / 128.0
# Creates a QueueRunner for the pre-fetching operation.
images, labels = tf.train.batch(
[image, label],
batch_size=batch_size,
num_threads=32,
capacity=5 * batch_size)
labels = tf.reshape(labels, [-1])
if one_hot:
labels = tf.one_hot(labels, dataset.num_classes)
return images, labels, dataset.num_samples, dataset.num_classes
def float_image_to_uint8(image):
"""Convert float image in [-1, 1) to [0, 255] uint8.
Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.
Args:
image: An image tensor. Values should be in [-1, 1).
Returns:
Input image cast to uint8 and with integer values in [0, 255].
"""
image = (image * 128.0) + 128.0
return tf.cast(image, tf.uint8)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):
def test_cifar10_train_set(self):
dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cifar/testdata')
batch_size = 4
images, labels, num_samples, num_classes = data_provider.provide_data(
batch_size, dataset_dir)
self.assertEqual(50000, num_samples)
self.assertEqual(10, num_classes)
with self.test_session(use_gpu=True) as sess:
with tf.contrib.slim.queues.QueueRunners(sess):
images_out, labels_out = sess.run([images, labels])
self.assertEqual(images_out.shape, (batch_size, 32, 32, 3))
expected_label_shape = (batch_size, 10)
self.assertEqual(expected_label_shape, labels_out.shape)
# Check range.
self.assertTrue(np.all(np.abs(images_out) <= 1))
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained CIFAR model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
import util
flags = tf.flags
FLAGS = tf.flags.FLAGS
tfgan = tf.contrib.gan
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/cifar10/',
'Directory where the results are saved to.')
flags.DEFINE_string('dataset_dir', None, 'Location of data.')
flags.DEFINE_integer('num_images_generated', 100,
'Number of images to generate at once.')
flags.DEFINE_integer('num_inception_images', 10,
'The number of images to run through Inception at once.')
flags.DEFINE_boolean('eval_real_images', False,
'If `True`, run Inception network on real images.')
flags.DEFINE_boolean('conditional_eval', False,
'If `True`, set up a conditional GAN.')
flags.DEFINE_boolean('eval_frechet_inception_distance', True,
'If `True`, compute Frechet Inception distance using real '
'images and generated images.')
flags.DEFINE_integer('num_images_per_class', 10,
'When a conditional generator is used, this is the number '
'of images to display per class.')
flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
'forever.')
def main(_, run_eval_loop=True):
# Fetch and generate images to run through Inception.
with tf.name_scope('inputs'):
real_data, num_classes = _get_real_data(
FLAGS.num_images_generated, FLAGS.dataset_dir)
generated_data = _get_generated_data(
FLAGS.num_images_generated, FLAGS.conditional_eval, num_classes)
# Compute Frechet Inception Distance.
if FLAGS.eval_frechet_inception_distance:
fid = util.get_frechet_inception_distance(
real_data, generated_data, FLAGS.num_images_generated,
FLAGS.num_inception_images)
tf.summary.scalar('frechet_inception_distance', fid)
# Compute normal Inception scores.
if FLAGS.eval_real_images:
inc_score = util.get_inception_scores(
real_data, FLAGS.num_images_generated, FLAGS.num_inception_images)
else:
inc_score = util.get_inception_scores(
generated_data, FLAGS.num_images_generated, FLAGS.num_inception_images)
tf.summary.scalar('inception_score', inc_score)
  # If conditional, display an image grid of different classes.
if FLAGS.conditional_eval and not FLAGS.eval_real_images:
reshaped_imgs = util.get_image_grid(
generated_data, FLAGS.num_images_generated, num_classes,
FLAGS.num_images_per_class)
tf.summary.image('generated_data', reshaped_imgs, max_outputs=1)
# Create ops that write images to disk.
image_write_ops = None
if FLAGS.conditional_eval:
reshaped_imgs = util.get_image_grid(
generated_data, FLAGS.num_images_generated, num_classes,
FLAGS.num_images_per_class)
uint8_images = data_provider.float_image_to_uint8(reshaped_imgs)
image_write_ops = tf.write_file(
'%s/%s'% (FLAGS.eval_dir, 'conditional_cifar10.png'),
tf.image.encode_png(uint8_images[0]))
else:
if FLAGS.num_images_generated >= 100:
reshaped_imgs = tfgan.eval.image_reshaper(
generated_data[:100], num_cols=FLAGS.num_images_per_class)
uint8_images = data_provider.float_image_to_uint8(reshaped_imgs)
image_write_ops = tf.write_file(
'%s/%s'% (FLAGS.eval_dir, 'unconditional_cifar10.png'),
tf.image.encode_png(uint8_images[0]))
# For unit testing, use `run_eval_loop=False`.
if not run_eval_loop: return
tf.contrib.training.evaluate_repeatedly(
FLAGS.checkpoint_dir,
master=FLAGS.master,
hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
tf.contrib.training.StopAfterNEvalsHook(1)],
eval_ops=image_write_ops,
max_number_of_evaluations=FLAGS.max_number_of_evaluations)
def _get_real_data(num_images_generated, dataset_dir):
"""Get real images."""
data, _, _, num_classes = data_provider.provide_data(
num_images_generated, dataset_dir)
return data, num_classes
def _get_generated_data(num_images_generated, conditional_eval, num_classes):
"""Get generated images."""
noise = tf.random_normal([num_images_generated, 64])
# If conditional, generate class-specific images.
if conditional_eval:
conditioning = util.get_generator_conditioning(
num_images_generated, num_classes)
generator_inputs = (noise, conditioning)
generator_fn = networks.conditional_generator
else:
generator_inputs = noise
generator_fn = networks.generator
# In order for variables to load, use the same variable scope as in the
# train job.
with tf.variable_scope('Generator'):
data = generator_fn(generator_inputs)
return data
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.cifar.eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import eval # pylint:disable=redefined-builtin
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
class EvalTest(tf.test.TestCase):
def _test_build_graph_helper(self, eval_real_images, conditional_eval):
FLAGS.eval_real_images = eval_real_images
FLAGS.conditional_eval = conditional_eval
# Mock `frechet_inception_distance` and `inception_score`, which are
# expensive.
with mock.patch.object(
eval.util, 'get_frechet_inception_distance') as mock_fid:
with mock.patch.object(eval.util, 'get_inception_scores') as mock_iscore:
mock_fid.return_value = 1.0
mock_iscore.return_value = 1.0
eval.main(None, run_eval_loop=False)
def test_build_graph_realdata(self):
self._test_build_graph_helper(True, False)
def test_build_graph_generateddata(self):
self._test_build_graph_helper(False, False)
def test_build_graph_generateddataconditional(self):
self._test_build_graph_helper(False, True)
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the CIFAR dataset.
# 2. Trains an unconditional or conditional model on the CIFAR training set.
# 3. Evaluates the models and writes sample images to disk.
#
#
# With the default batch size and number of steps, train times are:
#
# Usage:
# cd models/research/gan/cifar
# ./launch_jobs.sh ${gan_type} ${git_repo}
set -e
# Type of GAN to run. Right now, options are `unconditional` or `conditional`.
gan_type=$1
if ! [[ "$gan_type" =~ ^(unconditional|conditional) ]]; then
echo "'gan_type' must be one of: 'unconditional', 'conditional'."
exit
fi
# Location of the git repository.
git_repo=$2
if [[ "$git_repo" == "" ]]; then
echo "'git_repo' must not be empty."
  exit 1
fi
# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/cifar-model
# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/cifar-model/eval
# Where the dataset is saved to.
DATASET_DIR=/tmp/cifar-data
export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/gan:$git_repo/research/slim
# A helper function for printing pretty output.
Banner () {
local text=$1
local green='\033[0;32m'
local nc='\033[0m' # No color.
echo -e "${green}${text}${nc}"
}
# Download the dataset.
python "${git_repo}/research/slim/download_and_convert_data.py" \
--dataset_name=cifar10 \
--dataset_dir=${DATASET_DIR}
# Run unconditional GAN.
if [[ "$gan_type" == "unconditional" ]]; then
UNCONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/unconditional"
UNCONDITIONAL_EVAL_DIR="${EVAL_DIR}/unconditional"
NUM_STEPS=10000
# Run training.
Banner "Starting training unconditional GAN for ${NUM_STEPS} steps..."
python "${git_repo}/research/gan/cifar/train.py" \
--train_log_dir=${UNCONDITIONAL_TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--max_number_of_steps=${NUM_STEPS} \
--gan_type="unconditional" \
--alsologtostderr
Banner "Finished training unconditional GAN ${NUM_STEPS} steps."
# Run evaluation.
Banner "Starting evaluation of unconditional GAN..."
python "${git_repo}/research/gan/cifar/eval.py" \
--checkpoint_dir=${UNCONDITIONAL_TRAIN_DIR} \
--eval_dir=${UNCONDITIONAL_EVAL_DIR} \
--dataset_dir=${DATASET_DIR} \
--eval_real_images=false \
--conditional_eval=false \
  --max_number_of_evaluations=1
Banner "Finished unconditional evaluation. See ${UNCONDITIONAL_EVAL_DIR} for output images."
fi
# Run conditional GAN.
if [[ "$gan_type" == "conditional" ]]; then
CONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/conditional"
CONDITIONAL_EVAL_DIR="${EVAL_DIR}/conditional"
NUM_STEPS=10000
# Run training.
Banner "Starting training conditional GAN for ${NUM_STEPS} steps..."
python "${git_repo}/research/gan/cifar/train.py" \
--train_log_dir=${CONDITIONAL_TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--max_number_of_steps=${NUM_STEPS} \
--gan_type="conditional" \
--alsologtostderr
Banner "Finished training conditional GAN ${NUM_STEPS} steps."
# Run evaluation.
Banner "Starting evaluation of conditional GAN..."
python "${git_repo}/research/gan/cifar/eval.py" \
--checkpoint_dir=${CONDITIONAL_TRAIN_DIR} \
--eval_dir=${CONDITIONAL_EVAL_DIR} \
--dataset_dir=${DATASET_DIR} \
--eval_real_images=false \
--conditional_eval=true \
  --max_number_of_evaluations=1
Banner "Finished conditional evaluation. See ${CONDITIONAL_EVAL_DIR} for output images."
fi
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN CIFAR example using TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.nets import dcgan
tfgan = tf.contrib.gan
def _last_conv_layer(end_points):
""""Returns the last convolutional layer from an endpoints dictionary."""
conv_list = [k if k[:4] == 'conv' else None for k in end_points.keys()]
conv_list.sort()
return end_points[conv_list[-1]]
def generator(noise):
"""Generator to produce CIFAR images.
Args:
noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
does not use conditioning, this Tensor represents a noise vector of some
kind that will be reshaped by the generator into CIFAR examples.
Returns:
A single Tensor with a batch of generated CIFAR images.
"""
images, _ = dcgan.generator(noise)
# Make sure output lies between [-1, 1].
return tf.tanh(images)
def conditional_generator(inputs):
"""Generator to produce CIFAR images.
Args:
    inputs: A 2-tuple of Tensors (noise, one_hot_labels) used to construct a
      conditional generator.
Returns:
A single Tensor with a batch of generated CIFAR images.
"""
noise, one_hot_labels = inputs
noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
images, _ = dcgan.generator(noise)
# Make sure output lies between [-1, 1].
return tf.tanh(images)
def discriminator(img, unused_conditioning):
"""Discriminator for CIFAR images.
Args:
img: A Tensor of shape [batch size, width, height, channels], that can be
either real or generated. It is the discriminator's goal to distinguish
between the two.
unused_conditioning: The TFGAN API can help with conditional GANs, which
would require extra `condition` information to both the generator and the
discriminator. Since this example is not conditional, we do not use this
argument.
Returns:
A 1D Tensor of shape [batch size] representing the confidence that the
images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real.
"""
logits, _ = dcgan.discriminator(img)
return logits
# TODO(joelshor): This discriminator creates variables that aren't used, and
# causes logging warnings. Improve `dcgan` nets to accept a target end layer,
# so extraneous variables aren't created.
def conditional_discriminator(img, conditioning):
"""Discriminator for CIFAR images.
Args:
img: A Tensor of shape [batch size, width, height, channels], that can be
either real or generated. It is the discriminator's goal to distinguish
between the two.
conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
Returns:
A 1D Tensor of shape [batch size] representing the confidence that the
images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real.
"""
logits, end_points = dcgan.discriminator(img)
# Condition the last convolution layer.
_, one_hot_labels = conditioning
net = _last_conv_layer(end_points)
net = tfgan.features.condition_tensor_from_onehot(
tf.contrib.layers.flatten(net), one_hot_labels)
logits = tf.contrib.layers.linear(net, 1)
return logits
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.cifar.networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import networks
class NetworksTest(tf.test.TestCase):
def test_generator(self):
tf.set_random_seed(1234)
batch_size = 100
noise = tf.random_normal([batch_size, 64])
image = networks.generator(noise)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
image_np = image.eval()
self.assertAllEqual([batch_size, 32, 32, 3], image_np.shape)
self.assertTrue(np.all(np.abs(image_np) <= 1))
def test_generator_conditional(self):
tf.set_random_seed(1234)
batch_size = 100
noise = tf.random_normal([batch_size, 64])
conditioning = tf.one_hot([0] * batch_size, 10)
image = networks.conditional_generator((noise, conditioning))
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
image_np = image.eval()
self.assertAllEqual([batch_size, 32, 32, 3], image_np.shape)
self.assertTrue(np.all(np.abs(image_np) <= 1))
def test_discriminator(self):
batch_size = 5
image = tf.random_uniform([batch_size, 32, 32, 3], -1, 1)
dis_output = networks.discriminator(image, None)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
dis_output_np = dis_output.eval()
self.assertAllEqual([batch_size, 1], dis_output_np.shape)
def test_discriminator_conditional(self):
batch_size = 5
image = tf.random_uniform([batch_size, 32, 32, 3], -1, 1)
conditioning = (None, tf.one_hot([0] * batch_size, 10))
dis_output = networks.conditional_discriminator(image, conditioning)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
dis_output_np = dis_output.eval()
self.assertAllEqual([batch_size, 1], dis_output_np.shape)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a generator on CIFAR data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
tfgan = tf.contrib.gan
flags = tf.flags
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_string('train_log_dir', '/tmp/cifar/',
'Directory where to write event logs.')
flags.DEFINE_string('dataset_dir', None, 'Location of data.')
flags.DEFINE_integer('max_number_of_steps', 1000000,
'The maximum number of gradient steps.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_boolean(
'conditional', False,
'If `True`, set up a conditional GAN. If False, it is unconditional.')
# Sync replicas flags.
flags.DEFINE_boolean(
'use_sync_replicas', True,
'If `True`, use sync replicas. Otherwise use async.')
flags.DEFINE_integer(
'worker_replicas', 10,
'The number of gradients to collect before updating params. Only used '
'with sync replicas.')
flags.DEFINE_integer(
'backup_workers', 1,
'Number of workers to be kept as backup in the sync replicas case.')
FLAGS = flags.FLAGS
def main(_):
if not tf.gfile.Exists(FLAGS.train_log_dir):
tf.gfile.MakeDirs(FLAGS.train_log_dir)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
# Force all input processing onto CPU in order to reserve the GPU for
# the forward inference and back-propagation.
with tf.name_scope('inputs'):
with tf.device('/cpu:0'):
images, one_hot_labels, _, _ = data_provider.provide_data(
FLAGS.batch_size, FLAGS.dataset_dir)
# Define the GANModel tuple.
noise = tf.random_normal([FLAGS.batch_size, 64])
if FLAGS.conditional:
generator_fn = networks.conditional_generator
discriminator_fn = networks.conditional_discriminator
generator_inputs = (noise, one_hot_labels)
else:
generator_fn = networks.generator
discriminator_fn = networks.discriminator
generator_inputs = noise
gan_model = tfgan.gan_model(
generator_fn,
discriminator_fn,
real_data=images,
generator_inputs=generator_inputs)
tfgan.eval.add_gan_model_image_summaries(gan_model)
# Get the GANLoss tuple. Use the selected GAN loss functions.
    # TODO(joelshor): Put this block in `with tf.name_scope('loss'):` when
# cl/171610946 goes into the opensource release.
gan_loss = tfgan.gan_loss(gan_model,
gradient_penalty_weight=1.0,
add_summaries=True)
# Get the GANTrain ops using the custom optimizers and optional
# discriminator weight clipping.
with tf.name_scope('train'):
gen_lr, dis_lr = _learning_rate()
gen_opt, dis_opt = _optimizer(gen_lr, dis_lr, FLAGS.use_sync_replicas)
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=gen_opt,
discriminator_optimizer=dis_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
tf.summary.scalar('generator_lr', gen_lr)
tf.summary.scalar('discriminator_lr', dis_lr)
# Run the alternating training loop. Skip it if no steps should be taken
# (used for graph construction tests).
sync_hooks = ([gen_opt.make_session_run_hook(FLAGS.task == 0),
dis_opt.make_session_run_hook(FLAGS.task == 0)]
if FLAGS.use_sync_replicas else [])
status_message = tf.string_join(
['Starting train step: ',
tf.as_string(tf.train.get_or_create_global_step())],
name='status_message')
if FLAGS.max_number_of_steps == 0: return
tfgan.gan_train(
train_ops,
hooks=(
[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
tf.train.LoggingTensorHook([status_message], every_n_iter=10)] +
sync_hooks),
logdir=FLAGS.train_log_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0)
def _learning_rate():
generator_lr = tf.train.exponential_decay(
learning_rate=0.0001,
global_step=tf.train.get_or_create_global_step(),
decay_steps=100000,
decay_rate=0.9,
staircase=True)
discriminator_lr = 0.001
return generator_lr, discriminator_lr
def _optimizer(gen_lr, dis_lr, use_sync_replicas):
"""Get an optimizer, that's optionally synchronous."""
generator_opt = tf.train.RMSPropOptimizer(gen_lr, decay=.9, momentum=0.1)
discriminator_opt = tf.train.RMSPropOptimizer(dis_lr, decay=.95, momentum=0.1)
def _make_sync(opt):
return tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=FLAGS.worker_replicas-FLAGS.backup_workers,
total_num_replicas=FLAGS.worker_replicas)
if use_sync_replicas:
generator_opt = _make_sync(generator_opt)
discriminator_opt = _make_sync(discriminator_opt)
return generator_opt, discriminator_opt
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import train
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase):
def _test_build_graph_helper(self, conditional, use_sync_replicas):
FLAGS.max_number_of_steps = 0
FLAGS.conditional = conditional
FLAGS.use_sync_replicas = use_sync_replicas
FLAGS.batch_size = 16
# Mock input pipeline.
mock_imgs = np.zeros([FLAGS.batch_size, 32, 32, 3], dtype=np.float32)
mock_lbls = np.concatenate(
(np.ones([FLAGS.batch_size, 1], dtype=np.int32),
np.zeros([FLAGS.batch_size, 9], dtype=np.int32)), axis=1)
with mock.patch.object(train, 'data_provider') as mock_data_provider:
mock_data_provider.provide_data.return_value = (
mock_imgs, mock_lbls, None, None)
train.main(None)
def test_build_graph_unconditional(self):
self._test_build_graph_helper(False, False)
def test_build_graph_conditional(self):
self._test_build_graph_helper(True, False)
def test_build_graph_syncreplicas(self):
self._test_build_graph_helper(False, True)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for training and evaluating a TFGAN CIFAR example."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
tfgan = tf.contrib.gan
def get_generator_conditioning(batch_size, num_classes):
"""Generates TFGAN conditioning inputs for evaluation.
Args:
batch_size: A Python integer. The desired batch size.
num_classes: A Python integer. The number of classes.
Returns:
A Tensor of one-hot vectors corresponding to an even distribution over
classes.
Raises:
ValueError: If `batch_size` isn't evenly divisible by `num_classes`.
"""
if batch_size % num_classes != 0:
raise ValueError('`batch_size` %i must be evenly divisible by '
'`num_classes` %i.' % (batch_size, num_classes))
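  # For example, batch_size=6 and num_classes=3 yields labels [0, 0, 1, 1, 2, 2].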
labels = [lbl for lbl in xrange(num_classes)
for _ in xrange(batch_size // num_classes)]
return tf.one_hot(tf.constant(labels), num_classes)
def get_image_grid(images, batch_size, num_classes, num_images_per_class):
"""Combines images from each class in a single summary image.
Args:
images: Tensor of images that are arranged by class. The first
`batch_size / num_classes` images belong to the first class, the second
group belong to the second class, etc. Shape is
[batch, width, height, channels].
batch_size: Python integer. Batch dimension.
num_classes: Number of classes to show.
num_images_per_class: Number of image examples per class to show.
Raises:
    ValueError: If the batch dimension of `images` is known at graph
      construction time and it isn't `batch_size`.
ValueError: If there aren't enough images to show
`num_classes * num_images_per_class` images.
ValueError: If `batch_size` isn't divisible by `num_classes`.
Returns:
A single image.
"""
# Validate inputs.
images.shape[0:1].assert_is_compatible_with([batch_size])
if batch_size < num_classes * num_images_per_class:
raise ValueError('Not enough images in batch to show the desired number of '
'images.')
if batch_size % num_classes != 0:
raise ValueError('`batch_size` must be divisible by `num_classes`.')
# Only get a certain number of images per class.
num_batches = batch_size // num_classes
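  # E.g. batch_size=100, num_classes=10, num_images_per_class=2 selects indices
  # [0, 1, 10, 11, 20, 21, ...], i.e. the first two images of each class.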
indices = [i * num_batches + j for i in xrange(num_classes)
for j in xrange(num_images_per_class)]
sampled_images = tf.gather(images, indices)
return tfgan.eval.image_reshaper(
sampled_images, num_cols=num_images_per_class)
def get_inception_scores(images, batch_size, num_inception_images):
"""Get Inception score for some images.
Args:
images: Image minibatch. Shape [batch size, width, height, channels]. Values
are in [-1, 1].
batch_size: Python integer. Batch dimension.
num_inception_images: Number of images to run through Inception at once.
Returns:
Inception scores. Tensor shape is [batch size].
Raises:
ValueError: If `batch_size` is incompatible with the first dimension of
`images`.
ValueError: If `batch_size` isn't divisible by `num_inception_images`.
"""
# Validate inputs.
images.shape[0:1].assert_is_compatible_with([batch_size])
if batch_size % num_inception_images != 0:
raise ValueError(
'`batch_size` must be divisible by `num_inception_images`.')
# Resize images.
size = 299
resized_images = tf.image.resize_bilinear(images, [size, size])
# Run images through Inception.
num_batches = batch_size // num_inception_images
inc_score = tfgan.eval.inception_score(
resized_images, num_batches=num_batches)
return inc_score
def get_frechet_inception_distance(real_images, generated_images, batch_size,
num_inception_images):
"""Get Frechet Inception Distance between real and generated images.
Args:
    real_images: Real images minibatch. Shape [batch size, width, height,
      channels]. Values are in [-1, 1].
generated_images: Generated images minibatch. Shape [batch size, width,
height, channels]. Values are in [-1, 1].
batch_size: Python integer. Batch dimension.
num_inception_images: Number of images to run through Inception at once.
Returns:
Frechet Inception distance. A floating-point scalar.
Raises:
ValueError: If the minibatch size is known at graph construction time, and
      doesn't match `batch_size`.
"""
# Validate input dimensions.
real_images.shape[0:1].assert_is_compatible_with([batch_size])
generated_images.shape[0:1].assert_is_compatible_with([batch_size])
# Resize input images.
size = 299
resized_real_images = tf.image.resize_bilinear(real_images, [size, size])
resized_generated_images = tf.image.resize_bilinear(
generated_images, [size, size])
# Compute Frechet Inception Distance.
num_batches = batch_size // num_inception_images
fid = tfgan.eval.frechet_inception_distance(
resized_real_images, resized_generated_images, num_batches=num_batches)
return fid