Unverified Commit 00fa8b12 authored by cclauss's avatar cclauss Committed by GitHub
Browse files

Merge branch 'master' into patch-13

parents 6d257a4f 1f34fcaf
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to run training for sequential latent variable models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import runners
# Shared flags.
# Shared flags.
# NOTE: the mode help text previously said "'train' or 'test'", but main()
# dispatches on 'train'/'eval'; the help string is corrected to match.
tf.app.flags.DEFINE_string("mode", "train",
                           "The mode of the binary. Must be 'train' or 'eval'.")
tf.app.flags.DEFINE_string("model", "vrnn",
                           "Model choice. Currently only 'vrnn' is supported.")
tf.app.flags.DEFINE_integer("latent_size", 64,
                            "The size of the latent state of the model.")
tf.app.flags.DEFINE_string("dataset_type", "pianoroll",
                           "The type of dataset, either 'pianoroll' or 'speech'.")
tf.app.flags.DEFINE_string("dataset_path", "",
                           "Path to load the dataset from.")
tf.app.flags.DEFINE_integer("data_dimension", None,
                            "The dimension of each vector in the data sequence. "
                            "Defaults to 88 for pianoroll datasets and 200 for speech "
                            "datasets. Should not need to be changed except for "
                            "testing.")
tf.app.flags.DEFINE_integer("batch_size", 4,
                            "Batch size.")
tf.app.flags.DEFINE_integer("num_samples", 4,
                            "The number of samples (or particles) for multisample "
                            "algorithms.")
tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi",
                           "The directory to keep checkpoints and summaries in.")
tf.app.flags.DEFINE_integer("random_seed", None,
                            "A random seed for seeding the TensorFlow graph.")
# Training flags.
tf.app.flags.DEFINE_string("bound", "fivo",
                           "The bound to optimize. Can be 'elbo', 'iwae', or 'fivo'.")
tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True,
                            "If true, normalize the loss by the number of timesteps "
                            "per sequence.")
tf.app.flags.DEFINE_float("learning_rate", 0.0002,
                          "The learning rate for ADAM.")
tf.app.flags.DEFINE_integer("max_steps", int(1e9),
                            "The number of gradient update steps to train for.")
tf.app.flags.DEFINE_integer("summarize_every", 50,
                            "The number of steps between summaries.")
# Distributed training flags.
tf.app.flags.DEFINE_string("master", "",
                           "The BNS name of the TensorFlow master to use.")
tf.app.flags.DEFINE_integer("task", 0,
                            "Task id of the replica running the training.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
                            "Number of tasks in the ps job. If 0 no ps job is used.")
tf.app.flags.DEFINE_boolean("stagger_workers", True,
                            "If true, bring one worker online every 1000 steps.")
# Evaluation flags.
tf.app.flags.DEFINE_string("split", "train",
                           "Split to evaluate the model on. Can be 'train', 'valid', or 'test'.")

FLAGS = tf.app.flags.FLAGS

# Default data dimensionality used when --data_dimension is not supplied.
PIANOROLL_DEFAULT_DATA_DIMENSION = 88
SPEECH_DEFAULT_DATA_DIMENSION = 200
def main(unused_argv):
  """Entry point: fills in data-dimension defaults and dispatches on mode.

  Args:
    unused_argv: Unused positional arguments supplied by tf.app.run().

  Raises:
    ValueError: If FLAGS.mode is not 'train' or 'eval'.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.data_dimension is None:
    # Pick the dataset-specific default dimensionality.
    if FLAGS.dataset_type == "pianoroll":
      FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
    elif FLAGS.dataset_type == "speech":
      FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
  if FLAGS.mode == "train":
    runners.run_train(FLAGS)
  elif FLAGS.mode == "eval":
    runners.run_eval(FLAGS)
  else:
    # Previously an unrecognized mode exited silently without doing any
    # work; fail fast with a clear message instead.
    raise ValueError(
        "Unknown mode: %s. Must be 'train' or 'eval'." % FLAGS.mode)


if __name__ == "__main__":
  tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""VRNN classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sonnet as snt
import tensorflow as tf
class VRNNCell(snt.AbstractModule):
  """Implementation of a Variational Recurrent Neural Network (VRNN).

  Introduced in "A Recurrent Latent Variable Model for Sequential data"
  by Chung et al. https://arxiv.org/pdf/1506.02216.pdf.

  The VRNN is a sequence model similar to an RNN that uses stochastic latent
  variables to improve its representational power. It can be thought of as a
  sequential analogue to the variational auto-encoder (VAE).

  The VRNN has a deterministic RNN as its backbone, represented by the
  sequence of RNN hidden states h_t. At each timestep, the RNN hidden state
  h_t is conditioned on the previous sequence element, x_{t-1}, as well as
  the latent state from the previous timestep, z_{t-1}.

  In this implementation of the VRNN the latent state z_t is Gaussian. The
  model's prior over z_t is distributed as Normal(mu_t, diag(sigma_t^2))
  where mu_t and sigma_t are the mean and standard deviation output from a
  fully connected network that accepts the rnn hidden state h_t as input.

  The approximate posterior (also known as q or the encoder in the VAE
  framework) is similar to the prior except that it is conditioned on the
  current target, x_t, as well as h_t via a fully connected network.

  This implementation uses the 'res_q' parameterization of the approximate
  posterior, meaning that instead of directly predicting the mean of z_t, the
  approximate posterior predicts the 'residual' from the prior's mean. This
  is explored more in section 3.3 of https://arxiv.org/pdf/1605.07571.pdf.

  During training, the latent state z_t is sampled from the approximate
  posterior and the reparameterization trick is used to provide low-variance
  gradients.

  The generative distribution p(x_t|z_t, h_t) is conditioned on the latent
  state z_t as well as the current RNN hidden state h_t via a fully
  connected network.

  To increase the modeling power of the VRNN, two additional networks are
  used to extract features from the data and the latent state. Those
  networks are called data_feat_extractor and latent_feat_extractor
  respectively.

  There are a few differences between this exposition and the paper.
  First, the indexing scheme for h_t is different than the paper's -- what
  the paper calls h_t we call h_{t+1}. This is the same notation used by
  Fraccaro et al. to describe the VRNN in the paper linked above. Also, the
  VRNN paper uses VAE terminology to refer to the different internal
  networks, so it refers to the approximate posterior as the encoder and the
  generative distribution as the decoder. This implementation also renamed
  the functions phi_x and phi_z in the paper to data_feat_extractor and
  latent_feat_extractor.
  """

  def __init__(self,
               rnn_cell,
               data_feat_extractor,
               latent_feat_extractor,
               prior,
               approx_posterior,
               generative,
               random_seed=None,
               name="vrnn"):
    """Creates a VRNN cell.

    Args:
      rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
        deterministic backbone of the VRNN. The inputs to the RNN will be the
        encoded latent state of the previous timestep with shape
        [batch_size, encoded_latent_size] as well as the encoded input of the
        current timestep, a Tensor of shape [batch_size, encoded_data_size].
      data_feat_extractor: A callable that accepts a batch of data x_t and
        'encodes' it, e.g. runs it through a fully connected network. Must
        accept as argument the inputs x_t, a Tensor of the shape
        [batch_size, data_size] and return a Tensor of shape
        [batch_size, encoded_data_size]. This callable will be called
        multiple times in the VRNN cell so if scoping is not handled
        correctly then multiple copies of the variables in this network
        could be made. It is recommended to use a snt.nets.MLP module,
        which takes care of this for you.
      latent_feat_extractor: A callable that accepts a latent state z_t and
        'encodes' it, e.g. runs it through a fully connected network. Must
        accept as argument a Tensor of shape [batch_size, latent_size] and
        return a Tensor of shape [batch_size, encoded_latent_size].
        This callable must also have the property 'output_size' defined,
        returning encoded_latent_size.
      prior: A callable that implements the prior p(z_t|h_t). Must accept as
        argument the previous RNN hidden state and return a
        tf.contrib.distributions.Normal distribution conditioned on the
        input.
      approx_posterior: A callable that implements the approximate posterior
        q(z_t|h_t,x_t). Must accept as arguments the encoded target of the
        current timestep and the previous RNN hidden state. Must return
        a tf.contrib.distributions.Normal distribution conditioned on the
        inputs.
      generative: A callable that implements the generative distribution
        p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
        and the RNN hidden state and return a subclass of
        tf.contrib.distributions.Distribution that can be used to evaluate
        the logprob of the targets.
      random_seed: The seed for the random ops. Used mainly for testing.
      name: The name of this VRNN.
    """
    super(VRNNCell, self).__init__(name=name)
    self.rnn_cell = rnn_cell
    self.data_feat_extractor = data_feat_extractor
    self.latent_feat_extractor = latent_feat_extractor
    self.prior = prior
    self.approx_posterior = approx_posterior
    self.generative = generative
    self.random_seed = random_seed
    # Taken from the extractor so the cell and the network always agree on
    # the encoded latent size.
    self.encoded_z_size = latent_feat_extractor.output_size
    # The cell's state is (RNN state, encoded z from the previous timestep).
    self.state_size = (self.rnn_cell.state_size, self.encoded_z_size)

  def zero_state(self, batch_size, dtype):
    """The initial state of the VRNN.

    Contains the initial state of the RNN as well as a vector of zeros
    corresponding to z_0.

    Args:
      batch_size: The batch size.
      dtype: The data type of the VRNN.

    Returns:
      zero_state: The initial state of the VRNN.
    """
    return (self.rnn_cell.zero_state(batch_size, dtype),
            tf.zeros([batch_size, self.encoded_z_size], dtype=dtype))

  def _build(self, observations, state, mask):
    """Computes one timestep of the VRNN.

    Args:
      observations: The observations at the current timestep, a tuple
        containing the model inputs and targets as Tensors of shape
        [batch_size, data_size].
      state: The current state of the VRNN.
      mask: Tensor of shape [batch_size], 1.0 if the current timestep is
        active, 0.0 if it is not active.

    Returns:
      log_q_z: The logprob of the latent state according to the approximate
        posterior.
      log_p_z: The logprob of the latent state according to the prior.
      log_p_x_given_z: The conditional log-likelihood, i.e. logprob of the
        observation according to the generative distribution.
      kl: The analytic kl divergence from q(z) to p(z).
      state: The new state of the VRNN.
    """
    # NOTE(review): mask is not referenced in this method body — presumably
    # the caller (the bound computations) applies it to the returned
    # log-probabilities; confirm before relying on per-step masking here.
    inputs, targets = observations
    rnn_state, prev_latent_encoded = state
    # Encode the data.
    inputs_encoded = self.data_feat_extractor(inputs)
    targets_encoded = self.data_feat_extractor(targets)
    # Run the RNN cell. The RNN sees the encoded previous latent state
    # alongside the encoded current input.
    rnn_inputs = tf.concat([inputs_encoded, prev_latent_encoded], axis=1)
    rnn_out, new_rnn_state = self.rnn_cell(rnn_inputs, rnn_state)
    # Create the prior and approximate posterior distributions. The prior's
    # mean is passed to the posterior for the res_q parameterization.
    latent_dist_prior = self.prior(rnn_out)
    latent_dist_q = self.approx_posterior(rnn_out, targets_encoded,
                                          prior_mu=latent_dist_prior.loc)
    # Sample the new latent state z and encode it. Sampling from q uses the
    # reparameterization trick (Normal.sample is reparameterized).
    latent_state = latent_dist_q.sample(seed=self.random_seed)
    latent_encoded = self.latent_feat_extractor(latent_state)
    # Calculate probabilities of the latent state according to the prior p
    # and approximate posterior q, summing over the latent dimensions.
    log_q_z = tf.reduce_sum(latent_dist_q.log_prob(latent_state), axis=-1)
    log_p_z = tf.reduce_sum(latent_dist_prior.log_prob(latent_state), axis=-1)
    analytic_kl = tf.reduce_sum(
        tf.contrib.distributions.kl_divergence(
            latent_dist_q, latent_dist_prior),
        axis=-1)
    # Create the generative dist. and calculate the logprob of the targets.
    generative_dist = self.generative(latent_encoded, rnn_out)
    log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(targets), axis=-1)
    return (log_q_z, log_p_z, log_p_x_given_z, analytic_kl,
            (new_rnn_state, latent_encoded))
# Default variable initializers shared by the fully connected networks and
# the RNN cell below: Xavier for weights ('w'), zeros for biases ('b').
_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
                         "b": tf.zeros_initializer()}
def create_vrnn(
    data_size,
    latent_size,
    generative_class,
    rnn_hidden_size=None,
    fcnet_hidden_sizes=None,
    encoded_data_size=None,
    encoded_latent_size=None,
    sigma_min=0.0,
    raw_sigma_bias=0.25,
    generative_bias_init=0.0,
    initializers=None,
    random_seed=None):
  """A factory method for creating VRNN cells.

  Args:
    data_size: The dimension of the vectors that make up the data sequences.
    latent_size: The size of the stochastic latent state of the VRNN.
    generative_class: The class of the generative distribution. Can be either
      ConditionalNormalDistribution or ConditionalBernoulliDistribution.
    rnn_hidden_size: The hidden state dimension of the RNN that forms the
      deterministic part of this VRNN. If None, then it defaults
      to latent_size.
    fcnet_hidden_sizes: A list of python integers, the size of the hidden
      layers of the fully connected networks that parameterize the
      conditional distributions of the VRNN. If None, then it defaults to one
      hidden layer of size latent_size.
    encoded_data_size: The size of the output of the data encoding network.
      If None, defaults to latent_size.
    encoded_latent_size: The size of the output of the latent state encoding
      network. If None, defaults to latent_size.
    sigma_min: The minimum value that the standard deviation of the
      distribution over the latent state can take.
    raw_sigma_bias: A scalar that is added to the raw standard deviation
      output from the neural networks that parameterize the prior and
      approximate posterior. Useful for preventing standard deviations close
      to zero.
    generative_bias_init: A bias to be added to the raw output of the fully
      connected network that parameterizes the generative distribution.
      Useful for initializing the mean of the distribution to a sensible
      starting point such as the mean of the training data. Only used with
      Bernoulli generative distributions.
    initializers: The variable initializers to use for the fully connected
      networks and RNN cell. Must be a dictionary mapping the keys 'w' and
      'b' to the initializers for the weights and biases. Defaults to xavier
      for the weights and zeros for the biases when initializers is None.
    random_seed: A random seed for the VRNN resampling operations.

  Returns:
    model: A VRNNCell object.
  """
  # Fill in size defaults — everything defaults to latent_size.
  if rnn_hidden_size is None:
    rnn_hidden_size = latent_size
  if fcnet_hidden_sizes is None:
    fcnet_hidden_sizes = [latent_size]
  if encoded_data_size is None:
    encoded_data_size = latent_size
  if encoded_latent_size is None:
    encoded_latent_size = latent_size
  if initializers is None:
    initializers = _DEFAULT_INITIALIZERS
  data_feat_extractor = snt.nets.MLP(
      output_sizes=fcnet_hidden_sizes + [encoded_data_size],
      initializers=initializers,
      name="data_feat_extractor")
  latent_feat_extractor = snt.nets.MLP(
      output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
      initializers=initializers,
      name="latent_feat_extractor")
  prior = ConditionalNormalDistribution(
      size=latent_size,
      hidden_layer_sizes=fcnet_hidden_sizes,
      sigma_min=sigma_min,
      raw_sigma_bias=raw_sigma_bias,
      initializers=initializers,
      name="prior")
  approx_posterior = NormalApproximatePosterior(
      size=latent_size,
      hidden_layer_sizes=fcnet_hidden_sizes,
      sigma_min=sigma_min,
      raw_sigma_bias=raw_sigma_bias,
      initializers=initializers,
      name="approximate_posterior")
  # NOTE(review): any generative_class other than
  # ConditionalBernoulliDistribution silently falls through to a conditional
  # Normal generative distribution; confirm that is the intended contract.
  if generative_class == ConditionalBernoulliDistribution:
    generative = ConditionalBernoulliDistribution(
        size=data_size,
        hidden_layer_sizes=fcnet_hidden_sizes,
        initializers=initializers,
        bias_init=generative_bias_init,
        name="generative")
  else:
    generative = ConditionalNormalDistribution(
        size=data_size,
        hidden_layer_sizes=fcnet_hidden_sizes,
        initializers=initializers,
        name="generative")
  rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
                                     initializer=initializers["w"])
  return VRNNCell(rnn_cell, data_feat_extractor, latent_feat_extractor,
                  prior, approx_posterior, generative, random_seed=random_seed)
class ConditionalNormalDistribution(object):
  """A Normal distribution conditioned on Tensor inputs via a fc network."""

  def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
               raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
               initializers=None, name="conditional_normal_distribution"):
    """Creates a conditional Normal distribution.

    Args:
      size: The dimension of the random variable.
      hidden_layer_sizes: The sizes of the hidden layers of the fully
        connected network used to condition the distribution on the inputs.
      sigma_min: The minimum standard deviation allowed, a scalar.
      raw_sigma_bias: A scalar that is added to the raw standard deviation
        output from the fully connected network. Set to 0.25 by default to
        prevent standard deviations close to 0.
      hidden_activation_fn: The activation function to use on the hidden
        layers of the fully connected network.
      initializers: The variable initializers to use for the fully connected
        network. The network is implemented using snt.nets.MLP so it must
        be a dictionary mapping the keys 'w' and 'b' to the initializers for
        the weights and biases. Defaults to xavier for the weights and zeros
        for the biases when initializers is None.
      name: The name of this distribution, used for sonnet scoping.
    """
    self.sigma_min = sigma_min
    self.raw_sigma_bias = raw_sigma_bias
    self.name = name
    if initializers is None:
      initializers = _DEFAULT_INITIALIZERS
    # The final layer is 2*size wide: the first half parameterizes the mean,
    # the second half the (raw) standard deviation.
    self.fcnet = snt.nets.MLP(
        output_sizes=hidden_layer_sizes + [2*size],
        activation=hidden_activation_fn,
        initializers=initializers,
        activate_final=False,
        use_bias=True,
        name=name + "_fcnet")

  def condition(self, tensor_list, **unused_kwargs):
    """Computes the parameters of a normal distribution based on the inputs."""
    # unused_kwargs lets callers pass posterior-only arguments (e.g.
    # prior_mu) uniformly; subclasses may consume them.
    inputs = tf.concat(tensor_list, axis=1)
    outs = self.fcnet(inputs)
    mu, sigma = tf.split(outs, 2, axis=1)
    # Softplus maps the raw output to positive values; raw_sigma_bias shifts
    # it away from zero and sigma_min clamps it from below.
    sigma = tf.maximum(tf.nn.softplus(sigma + self.raw_sigma_bias),
                       self.sigma_min)
    return mu, sigma

  def __call__(self, *args, **kwargs):
    """Creates a normal distribution conditioned on the inputs."""
    mu, sigma = self.condition(args, **kwargs)
    return tf.contrib.distributions.Normal(loc=mu, scale=sigma)
class ConditionalBernoulliDistribution(object):
  """A Bernoulli distribution conditioned on Tensor inputs via a fc net."""

  def __init__(self, size, hidden_layer_sizes, hidden_activation_fn=tf.nn.relu,
               initializers=None, bias_init=0.0,
               name="conditional_bernoulli_distribution"):
    """Creates a conditional Bernoulli distribution.

    Args:
      size: The dimension of the random variable.
      hidden_layer_sizes: The sizes of the hidden layers of the fully
        connected network used to condition the distribution on the inputs.
      hidden_activation_fn: The activation function to use on the hidden
        layers of the fully connected network.
      initializers: The variable initializers to use for the fully connected
        network. The network is implemented using snt.nets.MLP so it must
        be a dictionary mapping the keys 'w' and 'b' to the initializers for
        the weights and biases. Defaults to xavier for the weights and zeros
        for the biases when initializers is None.
      bias_init: A scalar or vector Tensor that is added to the output of the
        fully-connected network that parameterizes the mean of this
        distribution.
      name: The name of this distribution, used for sonnet scoping.
    """
    self.bias_init = bias_init
    if initializers is None:
      initializers = _DEFAULT_INITIALIZERS
    self.fcnet = snt.nets.MLP(
        output_sizes=hidden_layer_sizes + [size],
        activation=hidden_activation_fn,
        initializers=initializers,
        activate_final=False,
        use_bias=True,
        name=name + "_fcnet")

  def condition(self, tensor_list):
    """Computes the p parameter of the Bernoulli distribution."""
    inputs = tf.concat(tensor_list, axis=1)
    # bias_init shifts the raw logits, e.g. towards the training-set mean.
    return self.fcnet(inputs) + self.bias_init

  def __call__(self, *args):
    # The network output is interpreted as logits, not probabilities.
    p = self.condition(args)
    return tf.contrib.distributions.Bernoulli(logits=p)
class NormalApproximatePosterior(ConditionalNormalDistribution):
  """A Normally-distributed approx. posterior with res_q parameterization."""

  def condition(self, tensor_list, prior_mu):
    """Generates the mean and variance of the normal distribution.

    Args:
      tensor_list: The list of Tensors to condition on. Will be concatenated
        and fed through the parent class's fully connected network.
      prior_mu: The mean of the prior distribution associated with this
        approximate posterior. The network output is treated as a residual
        that is added to this mean, in res_q fashion.

    Returns:
      mu: The mean of the approximate posterior.
      sigma: The standard deviation of the approximate posterior.
    """
    residual_mu, sigma = super(NormalApproximatePosterior, self).condition(
        tensor_list)
    # res_q: the network predicts an offset from the prior's mean.
    return residual_mu + prior_mu, sigma
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A set of utils for dealing with nested lists and tuples of Tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.util import nest
def map_nested(map_fn, nested):
  """Executes map_fn on every element in a (potentially) nested structure.

  Args:
    map_fn: A callable to execute on each element in 'nested'.
    nested: A potentially nested combination of sequence objects. Sequence
      objects include tuples, lists, namedtuples, and all subclasses of
      collections.Sequence except strings. See nest.is_sequence for details.
      For example [1, ('hello', 4.3)] is a nested structure containing
      elements 1, 'hello', and 4.3.

  Returns:
    out_structure: A potentially nested combination of sequence objects with
      the same structure as the 'nested' input argument. out_structure
      contains the result of applying map_fn to each element in 'nested'. For
      example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  # list() is required for Python 3 compatibility: map() returns a lazy
  # iterator there, and nest.pack_sequence_as needs a concrete flat sequence.
  out = list(map(map_fn, nest.flatten(nested)))
  return nest.pack_sequence_as(nested, out)
def tile_tensors(tensors, multiples):
  """Tiles a set of Tensors.

  Args:
    tensors: A potentially nested tuple or list of Tensors with rank
      greater than or equal to the length of 'multiples'. The Tensors do not
      need to have the same rank, but their rank must not be dynamic.
    multiples: A python list of ints indicating how to tile each Tensor
      in 'tensors'. Similar to the 'multiples' argument to tf.tile.

  Returns:
    tiled_tensors: A potentially nested tuple or list of Tensors with the
      same structure as the 'tensors' input argument. Contains the result of
      applying tf.tile to each Tensor in 'tensors'. When the rank of a Tensor
      in 'tensors' is greater than the length of multiples, multiples is
      padded at the end with 1s. For example when tiling a 4-dimensional
      Tensor with multiples [3, 4], multiples would be padded to [3, 4, 1, 1]
      before tiling.
  """
  def _tile_one(tensor):
    # Pad 'multiples' with 1s so the trailing dimensions are left untouched.
    padding = [1] * (tensor.shape.ndims - len(multiples))
    return tf.tile(tensor, multiples + padding)
  return map_nested(_tile_one, tensors)
def gather_tensors(tensors, indices):
  """Performs a tf.gather operation on a set of Tensors.

  Args:
    tensors: A potentially nested tuple or list of Tensors.
    indices: The indices to use for the gather operation.

  Returns:
    gathered_tensors: A potentially nested tuple or list of Tensors with the
      same structure as the 'tensors' input argument. Contains the result of
      applying tf.gather(x, indices) on each element x in 'tensors'.
  """
  def _gather_one(tensor):
    return tf.gather(tensor, indices)
  return map_nested(_gather_one, tensors)
def tas_for_tensors(tensors, length):
  """Unstacks a set of Tensors into TensorArrays.

  Args:
    tensors: A potentially nested tuple or list of Tensors with length in the
      first dimension greater than or equal to the 'length' input argument.
    length: The desired length of the TensorArrays.

  Returns:
    tensorarrays: A potentially nested tuple or list of TensorArrays with the
      same structure as 'tensors'. Contains the result of unstacking each
      Tensor in 'tensors'.
  """
  def map_fn(x):
    # Name each TensorArray after its source Tensor's op name (dropping the
    # ':0' output suffix) to keep the graph readable.
    ta = tf.TensorArray(x.dtype, length, name=x.name.split(':')[0] + '_ta')
    # NOTE(review): the x[:length, :] slice assumes every Tensor has rank >=
    # 2 (time-major with at least one trailing dim) — confirm with callers.
    return ta.unstack(x[:length, :])
  return map_nested(map_fn, tensors)
def read_tas(tas, index):
  """Performs a read operation on a set of TensorArrays.

  Args:
    tas: A potentially nested tuple or list of TensorArrays with length
      greater than 'index'.
    index: The location to read from.

  Returns:
    read_tensors: A potentially nested tuple or list of Tensors with the same
      structure as the 'tas' input argument. Contains the result of
      performing a read operation at 'index' on each TensorArray in 'tas'.
  """
  def _read_one(ta):
    return ta.read(index)
  return map_nested(_read_one, tas)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""High-level code for creating and running FIVO-related Tensorflow graphs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import tensorflow as tf
import bounds
from data import datasets
from models import vrnn
def create_dataset_and_model(config, split, shuffle, repeat):
  """Creates the dataset and model for a given config.

  Args:
    config: A configuration object with config values accessible as
      properties. Most likely a FLAGS object. This function expects the
      properties batch_size, dataset_path, dataset_type, and latent_size to
      be defined.
    split: The dataset split to load.
    shuffle: If true, shuffle the dataset randomly.
    repeat: If true, repeat the dataset endlessly.

  Returns:
    inputs: A batch of input sequences represented as a dense Tensor of shape
      [time, batch_size, data_dimension].
    targets: A batch of target sequences represented as a dense Tensor of
      shape [time, batch_size, data_dimension].
    lengths: An int Tensor of shape [batch_size] representing the lengths of
      each sequence in the batch.
    model: A vrnn.VRNNCell model object.

  Raises:
    ValueError: If config.dataset_type is not 'pianoroll' or 'speech'.
  """
  if config.dataset_type == "pianoroll":
    inputs, targets, lengths, mean = datasets.create_pianoroll_dataset(
        config.dataset_path, split, config.batch_size, shuffle=shuffle,
        repeat=repeat)
    # Convert the mean of the training set to logit space so it can be used
    # to initialize the bias of the generative distribution.
    generative_bias_init = -tf.log(
        1. / tf.clip_by_value(mean, 0.0001, 0.9999) - 1)
    generative_distribution_class = vrnn.ConditionalBernoulliDistribution
  elif config.dataset_type == "speech":
    # NOTE(review): split, shuffle, and repeat are not forwarded for speech
    # data — shuffle/repeat are hard-coded to False here; confirm intended.
    inputs, targets, lengths = datasets.create_speech_dataset(
        config.dataset_path, config.batch_size,
        samples_per_timestep=config.data_dimension, prefetch_buffer_size=1,
        shuffle=False, repeat=False)
    generative_bias_init = None
    generative_distribution_class = vrnn.ConditionalNormalDistribution
  else:
    # Previously an unknown dataset_type fell through to a confusing
    # UnboundLocalError below; fail fast with a clear message instead.
    raise ValueError("Unknown dataset_type: %s. Must be 'pianoroll' or "
                     "'speech'." % config.dataset_type)
  # The data dimension is the size of the trailing axis of the inputs.
  model = vrnn.create_vrnn(inputs.get_shape().as_list()[2],
                           config.latent_size,
                           generative_distribution_class,
                           generative_bias_init=generative_bias_init,
                           raw_sigma_bias=0.5)
  return inputs, targets, lengths, model
def restore_checkpoint_if_exists(saver, sess, logdir):
  """Looks for a checkpoint and restores the session from it if found.

  Args:
    saver: A tf.train.Saver for restoring the session.
    sess: A TensorFlow session.
    logdir: The directory to look for checkpoints in.

  Returns:
    True if a checkpoint was found and restored, False otherwise.
  """
  checkpoint = tf.train.get_checkpoint_state(logdir)
  if not checkpoint:
    return False
  # Re-root the checkpoint path under logdir in case the recorded path is
  # stale (e.g. the directory was moved).
  checkpoint_name = os.path.basename(checkpoint.model_checkpoint_path)
  saver.restore(sess, os.path.join(logdir, checkpoint_name))
  return True
def wait_for_checkpoint(saver, sess, logdir):
  """Loops until the session is restored from a checkpoint in logdir.

  Args:
    saver: A tf.train.Saver for restoring the session.
    sess: A TensorFlow session.
    logdir: The directory to look for checkpoints in.
  """
  # Poll once a minute until a checkpoint appears and restores cleanly.
  while not restore_checkpoint_if_exists(saver, sess, logdir):
    tf.logging.info("Checkpoint not found in %s, sleeping for 60 seconds."
                    % logdir)
    time.sleep(60)
def run_train(config):
  """Runs training for a sequential latent variable model.

  Args:
    config: A configuration object with config values accessible as
      properties. Most likely a FLAGS object. For a list of expected
      properties and their meaning see the flags defined in fivo.py.
  """

  def create_logging_hook(step, bound_value):
    """Creates a logging hook that prints the bound value periodically."""
    bound_label = config.bound + " bound"
    if config.normalize_by_seq_len:
      bound_label += " per timestep"
    else:
      bound_label += " per sequence"
    def summary_formatter(log_dict):
      return "Step %d, %s: %f" % (
          log_dict["step"], bound_label, log_dict["bound_value"])
    logging_hook = tf.train.LoggingTensorHook(
        {"step": step, "bound_value": bound_value},
        every_n_iter=config.summarize_every,
        formatter=summary_formatter)
    return logging_hook

  def create_loss():
    """Creates the loss to be optimized.

    Returns:
      bound: A float Tensor containing the value of the bound that is
        being optimized.
      loss: A float Tensor that when differentiated yields the gradients
        to apply to the model. Should be optimized via gradient descent.
    """
    inputs, targets, lengths, model = create_dataset_and_model(
        config, split="train", shuffle=True, repeat=True)
    # Compute lower bounds on the log likelihood.
    # NOTE(review): a config.bound outside {'elbo', 'iwae', 'fivo'} leaves
    # ll_per_seq unbound and raises an UnboundLocalError below — confirm
    # whether an explicit ValueError is preferred.
    if config.bound == "elbo":
      # ELBO is IWAE with a single sample.
      ll_per_seq, _, _, _ = bounds.iwae(
          model, (inputs, targets), lengths, num_samples=1)
    elif config.bound == "iwae":
      ll_per_seq, _, _, _ = bounds.iwae(
          model, (inputs, targets), lengths, num_samples=config.num_samples)
    elif config.bound == "fivo":
      ll_per_seq, _, _, _, _ = bounds.fivo(
          model, (inputs, targets), lengths, num_samples=config.num_samples,
          resampling_criterion=bounds.ess_criterion)
    # Compute loss scaled by number of timesteps.
    ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
    ll_per_seq = tf.reduce_mean(ll_per_seq)
    tf.summary.scalar("train_ll_per_seq", ll_per_seq)
    tf.summary.scalar("train_ll_per_t", ll_per_t)
    # The loss is the negated bound (we maximize the bound by minimizing).
    if config.normalize_by_seq_len:
      return ll_per_t, -ll_per_t
    else:
      return ll_per_seq, -ll_per_seq

  def create_graph():
    """Creates the training graph."""
    global_step = tf.train.get_or_create_global_step()
    bound, loss = create_loss()
    opt = tf.train.AdamOptimizer(config.learning_rate)
    grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
    train_op = opt.apply_gradients(grads, global_step=global_step)
    return bound, train_op, global_step

  device = tf.train.replica_device_setter(ps_tasks=config.ps_tasks)
  with tf.Graph().as_default():
    if config.random_seed: tf.set_random_seed(config.random_seed)
    with tf.device(device):
      bound, train_op, global_step = create_graph()
      log_hook = create_logging_hook(global_step, bound)
      # With stagger_workers, non-chief workers idle until the global step
      # reaches task_id * 1000.
      start_training = not config.stagger_workers
      with tf.train.MonitoredTrainingSession(
          master=config.master,
          is_chief=config.task == 0,
          hooks=[log_hook],
          checkpoint_dir=config.logdir,
          save_checkpoint_secs=120,
          save_summaries_steps=config.summarize_every,
          log_step_count_steps=config.summarize_every) as sess:
        cur_step = -1
        while True:
          if sess.should_stop() or cur_step > config.max_steps: break
          if config.task > 0 and not start_training:
            # Staggered worker: poll the step, sleep, and check whether this
            # task's activation threshold has been reached.
            cur_step = sess.run(global_step)
            tf.logging.info("task %d not active yet, sleeping at step %d" %
                            (config.task, cur_step))
            time.sleep(30)
            if cur_step >= config.task * 1000:
              start_training = True
          else:
            _, cur_step = sess.run([train_op, global_step])
def run_eval(config):
  """Runs evaluation for a sequential latent variable model.

  This method runs only one evaluation over the dataset, writes summaries to
  disk, and then terminates. It does not loop indefinitely.

  Args:
    config: A configuration object with config values accessible as properties.
      Most likely a FLAGS object. For a list of expected properties and their
      meaning see the flags defined in fivo.py.
  """

  def create_graph():
    """Creates the evaluation graph.

    Returns:
      lower_bounds: A tuple of float Tensors containing the values of the 3
        evidence lower bounds (elbo, iwae, fivo), summed across the batch.
      total_batch_length: The total number of timesteps in the batch, summed
        across batch examples.
      batch_size: The batch size.
      global_step: The global step the checkpoint was loaded from.
    """
    global_step = tf.train.get_or_create_global_step()
    # Evaluation makes a single deterministic pass over the data, so disable
    # shuffling and repetition.
    inputs, targets, lengths, model = create_dataset_and_model(
        config, split=config.split, shuffle=False, repeat=False)
    # Compute lower bounds on the log likelihood.
    # The ELBO is the single-sample special case of the IWAE bound.
    elbo_ll_per_seq, _, _, _ = bounds.iwae(
        model, (inputs, targets), lengths, num_samples=1)
    iwae_ll_per_seq, _, _, _ = bounds.iwae(
        model, (inputs, targets), lengths, num_samples=config.num_samples)
    fivo_ll_per_seq, _, _, _, _ = bounds.fivo(
        model, (inputs, targets), lengths, num_samples=config.num_samples,
        resampling_criterion=bounds.ess_criterion)
    # Sum each bound over the batch; normalization over the whole dataset is
    # performed later by average_bounds_over_dataset.
    elbo_ll = tf.reduce_sum(elbo_ll_per_seq)
    iwae_ll = tf.reduce_sum(iwae_ll_per_seq)
    fivo_ll = tf.reduce_sum(fivo_ll_per_seq)
    batch_size = tf.shape(lengths)[0]
    total_batch_length = tf.reduce_sum(lengths)
    return ((elbo_ll, iwae_ll, fivo_ll), total_batch_length, batch_size,
            global_step)

  def average_bounds_over_dataset(lower_bounds, total_batch_length, batch_size,
                                  sess):
    """Computes the values of the bounds, averaged over the dataset.

    Args:
      lower_bounds: Tuple of float Tensors containing the values of the bounds
        evaluated on a single batch.
      total_batch_length: Integer Tensor that represents the total number of
        timesteps in the current batch.
      batch_size: Integer Tensor containing the batch size. This can vary if the
        requested batch_size does not evenly divide the size of the dataset.
      sess: A TensorFlow Session object.

    Returns:
      ll_per_t: A length 3 numpy array of floats containing each bound's average
        value, normalized by the total number of timesteps in the dataset. Can
        be interpreted as a lower bound on the average log likelihood per
        timestep in the dataset.
      ll_per_seq: A length 3 numpy array of floats containing each bound's
        average value, normalized by the number of sequences in the dataset.
        Can be interpreted as a lower bound on the average log likelihood per
        sequence in the dataset.
    """
    # Accumulators for the 3 bounds, the number of sequences seen, and the
    # total number of timesteps seen.
    total_ll = np.zeros(3, dtype=np.float64)
    total_n_elems = 0.0
    total_length = 0.0
    # The dataset does not repeat, so run until the iterator is exhausted.
    while True:
      try:
        outs = sess.run([lower_bounds, batch_size, total_batch_length])
      except tf.errors.OutOfRangeError:
        break
      total_ll += outs[0]
      total_n_elems += outs[1]
      total_length += outs[2]
    ll_per_t = total_ll / total_length
    ll_per_seq = total_ll / total_n_elems
    return ll_per_t, ll_per_seq

  def summarize_lls(lls_per_t, lls_per_seq, summary_writer, step):
    """Creates log-likelihood lower bound summaries and writes them to disk.

    Args:
      lls_per_t: An array of 3 python floats, contains the values of the
        evaluated bounds normalized by the number of timesteps.
      lls_per_seq: An array of 3 python floats, contains the values of the
        evaluated bounds normalized by the number of sequences.
      summary_writer: A tf.SummaryWriter.
      step: The current global step.
    """

    def scalar_summary(name, value):
      # Build the summary proto directly because these values were computed
      # in python, not by an op in the graph.
      value = tf.Summary.Value(tag=name, simple_value=value)
      return tf.Summary(value=[value])

    for i, bound in enumerate(["elbo", "iwae", "fivo"]):
      per_t_summary = scalar_summary("%s/%s_ll_per_t" % (config.split, bound),
                                     lls_per_t[i])
      per_seq_summary = scalar_summary("%s/%s_ll_per_seq" %
                                       (config.split, bound),
                                       lls_per_seq[i])
      summary_writer.add_summary(per_t_summary, global_step=step)
      summary_writer.add_summary(per_seq_summary, global_step=step)
    summary_writer.flush()

  with tf.Graph().as_default():
    if config.random_seed: tf.set_random_seed(config.random_seed)
    lower_bounds, total_batch_length, batch_size, global_step = create_graph()
    # Summaries for each split go to their own subdirectory of logdir.
    summary_dir = config.logdir + "/" + config.split
    summary_writer = tf.summary.FileWriter(
        summary_dir, flush_secs=15, max_queue=100)
    saver = tf.train.Saver()
    with tf.train.SingularMonitoredSession() as sess:
      # Block until a checkpoint is available in logdir, then restore it.
      wait_for_checkpoint(saver, sess, config.logdir)
      step = sess.run(global_step)
      tf.logging.info("Model restored from step %d, evaluating." % step)
      ll_per_t, ll_per_seq = average_bounds_over_dataset(
          lower_bounds, total_batch_length, batch_size, sess)
      summarize_lls(ll_per_t, ll_per_seq, summary_writer, step)
      tf.logging.info("%s elbo ll/t: %f, iwae ll/t: %f fivo ll/t: %f",
                      config.split, ll_per_t[0], ll_per_t[1], ll_per_t[2])
      tf.logging.info("%s elbo ll/seq: %f, iwae ll/seq: %f fivo ll/seq: %f",
                      config.split, ll_per_seq[0], ll_per_seq[1], ll_per_seq[2])
...@@ -25,7 +25,8 @@ Maintainers of TFGAN: ...@@ -25,7 +25,8 @@ Maintainers of TFGAN:
1. [Image compression (coming soon)](#compression) 1. [Image compression (coming soon)](#compression)
## MNIST {#mnist} ## MNIST
<a id='mnist'></a>
We train a simple generator to produce [MNIST digits](http://yann.lecun.com/exdb/mnist/). We train a simple generator to produce [MNIST digits](http://yann.lecun.com/exdb/mnist/).
The unconditional case maps noise to MNIST digits. The conditional case maps The unconditional case maps noise to MNIST digits. The conditional case maps
...@@ -36,22 +37,24 @@ network architectures are defined [here](https://github.com/tensorflow/models/tr ...@@ -36,22 +37,24 @@ network architectures are defined [here](https://github.com/tensorflow/models/tr
We use a classifier trained on MNIST digit classification for evaluation. We use a classifier trained on MNIST digit classification for evaluation.
### Unconditional MNIST ### Unconditional MNIST
![Unconditional GAN](g3doc/mnist_unconditional_gan.png "unconditional GAN") <img src="g3doc/mnist_unconditional_gan.png" title="Unconditional GAN" width="330" />
### Conditional MNIST ### Conditional MNIST
![Conditional GAN](g3doc/mnist_conditional_gan.png "conditional GAN") <img src="g3doc/mnist_conditional_gan.png" title="Conditional GAN" width="330" />
### InfoGAN MNIST ### InfoGAN MNIST
![InfoGAN](g3doc/mnist_infogan.png "InfoGAN") <img src="g3doc/mnist_infogan.png" title="InfoGAN" width="330" />
## MNIST with GANEstimator {#mnist_estimator} ## MNIST with GANEstimator
<a id='mnist_estimator'></a>
This setup is exactly the same as in the [unconditional MNIST example](#mnist), but This setup is exactly the same as in the [unconditional MNIST example](#mnist), but
uses the `tf.Learn` `GANEstimator`. uses the `tf.Learn` `GANEstimator`.
![Unconditional GAN](g3doc/mnist_estimator_unconditional_gan.png "unconditional GAN") <img src="g3doc/mnist_estimator_unconditional_gan.png" title="Unconditional GAN" width="330" />
## CIFAR10 {#cifar10} ## CIFAR10
<a id='cifar10'></a>
We train a [DCGAN generator](https://arxiv.org/abs/1511.06434) to produce [CIFAR10 images](https://www.cs.toronto.edu/~kriz/cifar.html). We train a [DCGAN generator](https://arxiv.org/abs/1511.06434) to produce [CIFAR10 images](https://www.cs.toronto.edu/~kriz/cifar.html).
The unconditional case maps noise to CIFAR10 images. The conditional case maps The unconditional case maps noise to CIFAR10 images. The conditional case maps
...@@ -61,12 +64,13 @@ network architectures are defined [here](https://github.com/tensorflow/models/tr ...@@ -61,12 +64,13 @@ network architectures are defined [here](https://github.com/tensorflow/models/tr
We use the [Inception Score](https://arxiv.org/abs/1606.03498) to evaluate the images. We use the [Inception Score](https://arxiv.org/abs/1606.03498) to evaluate the images.
### Unconditional CIFAR10 ### Unconditional CIFAR10
![Unconditional GAN](g3doc/cifar_unconditional_gan.png "unconditional GAN") <img src="g3doc/cifar_unconditional_gan.png" title="Unconditional GAN" width="330" />
### Conditional CIFAR10 ### Conditional CIFAR10
![Unconditional GAN](g3doc/cifar_conditional_gan.png "unconditional GAN"){width="330"} <img src="g3doc/cifar_conditional_gan.png" title="Conditional GAN" width="330" />
## Image compression {#compression} ## Image compression
<a id='compression'></a>
In neural image compression, we attempt to reduce an image to a smaller representation In neural image compression, we attempt to reduce an image to a smaller representation
such that we can recreate the original image as closely as possible. See [`Full Resolution Image Compression with Recurrent Neural Networks`](https://arxiv.org/abs/1608.05148) for more details on using neural networks such that we can recreate the original image as closely as possible. See [`Full Resolution Image Compression with Recurrent Neural Networks`](https://arxiv.org/abs/1608.05148) for more details on using neural networks
...@@ -93,12 +97,10 @@ Some other notes on the problem: ...@@ -93,12 +97,10 @@ Some other notes on the problem:
### Results ### Results
#### No adversarial loss #### No adversarial loss
<img src="g3doc/compression_wf0.png" title="No adversarial loss" width="500" />
![compression_no_adversarial](g3doc/compression_wf0.png "no adversarial loss")
#### Adversarial loss #### Adversarial loss
<img src="g3doc/compression_wf10000.png" title="With adversarial loss" width="500" />
![compression_no_adversarial](g3doc/compression_wf10000.png "with adversarial loss")
### Architectures ### Architectures
......
...@@ -18,6 +18,7 @@ from __future__ import absolute_import ...@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from six.moves import xrange
import tensorflow as tf import tensorflow as tf
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
......
...@@ -37,26 +37,25 @@ class UtilTest(tf.test.TestCase): ...@@ -37,26 +37,25 @@ class UtilTest(tf.test.TestCase):
num_classes=3, num_classes=3,
num_images_per_class=1) num_images_per_class=1)
def test_get_inception_scores(self): # Mock `inception_score` which is expensive.
# Mock `inception_score` which is expensive. @mock.patch.object(util.tfgan.eval, 'inception_score', autospec=True)
with mock.patch.object( def test_get_inception_scores(self, mock_inception_score):
util.tfgan.eval, 'inception_score') as mock_inception_score: mock_inception_score.return_value = 1.0
mock_inception_score.return_value = 1.0 util.get_inception_scores(
util.get_inception_scores( tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]), batch_size=100,
batch_size=100, num_inception_images=10)
num_inception_images=10)
# Mock `frechet_inception_distance` which is expensive.
def test_get_frechet_inception_distance(self): @mock.patch.object(util.tfgan.eval, 'frechet_inception_distance',
# Mock `frechet_inception_distance` which is expensive. autospec=True)
with mock.patch.object( def test_get_frechet_inception_distance(self, mock_fid):
util.tfgan.eval, 'frechet_inception_distance') as mock_fid: mock_fid.return_value = 1.0
mock_fid.return_value = 1.0 util.get_frechet_inception_distance(
util.get_frechet_inception_distance( tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]), tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]), batch_size=100,
batch_size=100, num_inception_images=10)
num_inception_images=10)
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the compression image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
slim = tf.contrib.slim
def provide_data(split_name, batch_size, dataset_dir,
                 dataset_name='imagenet', num_readers=1, num_threads=1,
                 patch_size=128):
  """Provides batches of image data for compression.

  Args:
    split_name: Either 'train' or 'validation'.
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the data can be found. If `None`, use
      default.
    dataset_name: Name of the dataset.
    num_readers: Number of dataset readers.
    num_threads: Number of prefetching threads. Only used for the 'train'
      split; the validation pipeline is single-threaded for determinism.
    patch_size: Size of the patch to extract from the image.

  Returns:
    images: A `Tensor` of size [batch_size, patch_size, patch_size, channels]
  """
  # Only the training split is shuffled; validation is read deterministically.
  randomize = split_name == 'train'
  dataset = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      shuffle=randomize)
  [image] = provider.get(['image'])

  # Sample a patch of fixed size.
  patch = tf.image.resize_image_with_crop_or_pad(image, patch_size, patch_size)
  patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])

  # Preprocess the images. Make the range lie in a strictly smaller range than
  # [-1, 1], so that network outputs aren't forced to the extreme ranges.
  patch = (tf.to_float(patch) - 128.0) / 142.0

  if randomize:
    image_batch = tf.train.shuffle_batch(
        [patch],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size,
        min_after_dequeue=batch_size)
  else:
    image_batch = tf.train.batch(
        [patch],
        batch_size=batch_size,
        num_threads=1,  # no threads so it's deterministic
        capacity=5 * batch_size)
  return image_batch
def float_image_to_uint8(image):
  """Convert float image in ~[-0.9, 0.9) to [0, 255] uint8.

  Args:
    image: An image tensor. Values should be in [-0.9, 0.9).

  Returns:
    Input image cast to uint8 and with integer values in [0, 255].
  """
  # Undo the normalization applied in `provide_data`: rescale by the same
  # factor and re-center around 128, then truncate to uint8.
  rescaled = image * 142.0 + 128.0
  return tf.cast(rescaled, tf.uint8)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):
  """Smoke tests for `data_provider.provide_data` on both dataset splits."""

  def _test_data_provider_helper(self, split_name):
    """Builds a small batch for `split_name`, checks its shape and range."""
    dataset_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/image_compression/testdata/')
    batch_size, patch_size = 3, 8
    images = data_provider.provide_data(
        split_name, batch_size, dataset_dir, patch_size=8)
    expected_shape = [batch_size, patch_size, patch_size, 3]

    # The static shape should be fully determined at graph-construction time.
    self.assertListEqual(expected_shape, images.shape.as_list())

    with self.test_session(use_gpu=True) as sess:
      # Queue runners must be started for the slim data provider to produce.
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out = sess.run(images)
        self.assertEqual(tuple(expected_shape), images_out.shape)
        # Check range.
        self.assertTrue(np.all(np.abs(images_out) <= 1.0))

  def test_data_provider_train(self):
    self._test_data_provider_helper('train')

  def test_data_provider_validation(self):
    self._test_data_provider_helper('validation')
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained compression model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
import summaries
# Command-line flags for the compression evaluation job.
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('checkpoint_dir', '/tmp/compression/',
                    'Directory where the model was written to.')

flags.DEFINE_string('eval_dir', '/tmp/compression/',
                    'Directory where the results are saved to.')

flags.DEFINE_integer('max_number_of_evaluations', None,
                     'Number of times to run evaluation. If `None`, run '
                     'forever.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')

# Compression-specific flags.
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')

flags.DEFINE_integer('bits_per_patch', 1230,
                     'The number of bits to produce per patch.')

flags.DEFINE_integer('model_depth', 64,
                     'Number of filters for compression model')
def main(_, run_eval_loop=True):
  """Builds the compression eval graph and repeatedly evaluates checkpoints.

  Args:
    _: Unused positional argument supplied by tf.app.run.
    run_eval_loop: If False, only construct the graph and return without
      evaluating any checkpoints. Used by unit tests to verify the graph
      builds.
  """
  # Load validation image patches.
  with tf.name_scope('inputs'):
    images = data_provider.provide_data(
        'validation', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
        patch_size=FLAGS.patch_size)

  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('generator'):
    reconstructions, _, prebinary = networks.compression_model(
        images,
        num_bits=FLAGS.bits_per_patch,
        depth=FLAGS.model_depth,
        is_training=False)
  summaries.add_reconstruction_summaries(images, reconstructions, prebinary)

  # Visualize losses.
  pixel_loss_per_example = tf.reduce_mean(
      tf.abs(images - reconstructions), axis=[1, 2, 3])
  pixel_loss = tf.reduce_mean(pixel_loss_per_example)
  tf.summary.histogram('pixel_l1_loss_hist', pixel_loss_per_example)
  tf.summary.scalar('pixel_l1_loss', pixel_loss)

  # Create ops to write images to disk. Originals and reconstructions are
  # stacked side by side and saved as a single PNG.
  uint8_images = data_provider.float_image_to_uint8(images)
  uint8_reconstructions = data_provider.float_image_to_uint8(reconstructions)
  uint8_reshaped = summaries.stack_images(uint8_images, uint8_reconstructions)
  image_write_ops = tf.write_file(
      '%s/%s'% (FLAGS.eval_dir, 'compression.png'),
      tf.image.encode_png(uint8_reshaped[0]))

  # For unit testing, use `run_eval_loop=False`.
  if not run_eval_loop: return
  tf.contrib.training.evaluate_repeatedly(
      FLAGS.checkpoint_dir,
      master=FLAGS.master,
      hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
             tf.contrib.training.StopAfterNEvalsHook(1)],
      eval_ops=image_write_ops,
      max_number_of_evaluations=FLAGS.max_number_of_evaluations)
# Parse flags and dispatch to main() when executed directly.
if __name__ == '__main__':
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import eval # pylint:disable=redefined-builtin
class EvalTest(tf.test.TestCase):

  def test_build_graph(self):
    # Smoke test: constructing the eval graph should not raise. Passing
    # run_eval_loop=False skips checkpoint evaluation, so no data or
    # checkpoints are needed.
    eval.main(None, run_eval_loop=False)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment