Commit 68af8a99 authored by Dieterich Lawson's avatar Dieterich Lawson

Adding FIVO folder to research

parent 6e4bbb74
......@@ -35,6 +35,7 @@ research/syntaxnet/* @calberti @andorardo @bogatyy @markomernick
research/textsum/* @panyx0718 @peterjliu
research/transformer/* @daviddao
research/video_prediction/* @cbfinn
research/fivo/* @dieterichlawson
samples/* @MarkDaoust
tutorials/embedding/* @zffchen78 @a-dai
tutorials/image/* @sherrym @shlens
......
# Filtering Variational Objectives
This folder contains a TensorFlow implementation of the algorithms from
Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, and Yee Whye Teh. "Filtering Variational Objectives." NIPS 2017.
[https://arxiv.org/abs/1705.09279](https://arxiv.org/abs/1705.09279)
This code implements three bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
Additionally, it contains an implementation of the variational recurrent neural network (VRNN), a sequential latent variable model that can be trained with any of these three objectives. This repo provides code for training a VRNN to do sequence modeling of pianoroll and speech data.
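The `--bound` flag of the training script selects among these objectives, all of which are implemented in `bounds.py`. The sketch below is a rough illustration of how they relate; it assumes a `cell`, `inputs`, and `seq_lengths` that are built elsewhere (e.g. by `runners.py`).
```
# Illustrative only -- assumes `cell`, `inputs`, and `seq_lengths` have been
# constructed elsewhere (e.g. by runners.py).
import bounds

# The ELBO is the IWAE bound with a single sample.
elbo, _, _, _ = bounds.iwae(cell, inputs, seq_lengths, num_samples=1)

# IWAE averages multiple importance weights.
iwae_bound, _, _, _ = bounds.iwae(cell, inputs, seq_lengths, num_samples=4)

# FIVO is a particle filter estimate; with never_resample_criterion instead
# of ess_criterion it reduces to IWAE.
fivo_bound, _, _, _, _ = bounds.fivo(
    cell, inputs, seq_lengths, num_samples=4,
    resampling_criterion=bounds.ess_criterion)
```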
#### Directory Structure
The important parts of the code are organized as follows.
```
fivo.py # main script, contains flag definitions
runners.py # graph construction code for training and evaluation
bounds.py # code for computing each bound
data
├── datasets.py # readers for pianoroll and speech datasets
├── calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
└── create_timit_dataset.py # preprocesses the TIMIT dataset
models
└── vrnn.py # variational RNN implementation
bin
├── run_train.sh # an example script that runs training
├── run_eval.sh # an example script that runs evaluation
└── download_pianorolls.sh # a script that downloads the pianoroll files
```
### Training on Pianorolls
Requirements before we start:
* TensorFlow (see [tensorflow.org](http://tensorflow.org) for how to install)
* [scipy](https://www.scipy.org/)
* [sonnet](https://github.com/deepmind/sonnet)
#### Download the Data
The pianoroll datasets are encoded as pickled sparse arrays and are available at [http://www-etud.iro.umontreal.ca/~boulanni/icml2012](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). You can use the script `bin/download_pianorolls.sh` to download the files into a directory of your choosing.
```
export PIANOROLL_DIR=~/pianorolls
mkdir $PIANOROLL_DIR
sh bin/download_pianorolls.sh $PIANOROLL_DIR
```
#### Preprocess the Data
The script `calculate_pianoroll_mean.py` loads a pianoroll pickle file, calculates the mean, updates the pickle file to include the mean under the key `train_mean`, and writes the file back to disk in-place. You should do this for all pianoroll datasets you wish to train on.
```
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/piano-midi.de.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/nottingham.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/musedata.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
```
#### Training
Now we can train a model. Here is a standard training run, taken from `bin/run_train.sh`:
```
python fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
```
You should see output that looks something like this (with a lot of extra logging cruft):
```
Step 1, fivo bound per timestep: -11.801050
global_step/sec: 9.89825
Step 101, fivo bound per timestep: -11.198309
global_step/sec: 9.55475
Step 201, fivo bound per timestep: -11.287262
global_step/sec: 9.68146
step 301, fivo bound per timestep: -11.316490
global_step/sec: 9.94295
Step 401, fivo bound per timestep: -11.151743
```
You will also see lines saying `Out of range: exceptions.StopIteration: Iteration finished`. These are expected and are not errors; they can be safely ignored.
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example, here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
```
python fivo.py \
--mode=eval \
--split=test \
--alsologtostderr \
--logdir=/tmp/fivo \
--model=vrnn \
--batch_size=4 \
--num_samples=4 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
```
You should see output like this:
```
Model restored from step 1, evaluating.
test elbo ll/t: -12.299635, iwae ll/t: -12.128336 fivo ll/t: -11.656939
test elbo ll/seq: -754.750312, iwae ll/seq: -744.238773 fivo ll/seq: -715.3121490
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
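As a rough back-of-the-envelope check (assuming both quantities are averages over the same dataset), dividing the per-sequence numbers by the per-timestep numbers above recovers the average test sequence length:
```
# Rough check using the sample output above.
print(-715.312149 / -11.656939)  # roughly 61 timesteps per sequence
```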
### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
#### Preprocess TIMIT
We preprocess TIMIT (as described in our paper) and write it out to a series of TFRecord files. To prepare the TIMIT dataset, use the script `create_timit_dataset.py`:
```
export TIMIT_DIR=~/timit_dataset
mkdir $TIMIT_DIR
python data/create_timit_dataset.py \
--raw_timit_dir=$RAW_TIMIT_DIR \
--out_dir=$TIMIT_DIR
```
You should see this exact output:
```
4389 train / 231 valid / 1680 test
train mean: 0.006060 train std: 548.136169
```
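The script writes three TFRecord files into `$TIMIT_DIR` (see `data/create_timit_dataset.py`), which the training command below references by path:
```
$TIMIT_DIR/train   # training set
$TIMIT_DIR/valid   # validation set, split off from the TIMIT train files
$TIMIT_DIR/test    # test set
```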
#### Training on TIMIT
This is very similar to training on pianoroll datasets, with just a few flags switched.
```
python fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$TIMIT_DIR/train" \
--dataset_type="speech"
```
### Contact
This codebase is maintained by Dieterich Lawson, reachable via email at dieterichl@google.com. For questions and issues, please open an issue on the tensorflow/models issue tracker and assign it to @dieterich.lawson.
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# A script to download the pianoroll datasets.
# Accepts one argument, the directory to put the files in
if [ -z "$1" ]
then
echo "Error, must provide a directory to download the files to."
exit 1
fi
echo "Downloading datasets into $1"
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Piano-midi.de.pickle" > "$1/piano-midi.de.pkl"
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.pickle" > "$1/nottingham.pkl"
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/MuseData.pickle" > "$1/musedata.pkl"
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle" > "$1/jsb.pkl"
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of running evaluation.
PIANOROLL_DIR=$HOME/pianorolls
python fivo.py \
--mode=eval \
--logdir=/tmp/fivo \
--model=vrnn \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of running training.
PIANOROLL_DIR=$HOME/pianorolls
python fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of objectives for training stochastic latent variable models.
Contains implementations of the Importance Weighted Autoencoder objective (IWAE)
and the Filtering Variational objective (FIVO).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import nested_utils as nested
def iwae(cell,
inputs,
seq_lengths,
num_samples=1,
parallel_iterations=30,
swap_memory=True):
"""Computes the IWAE lower bound on the log marginal probability.
This method accepts a stochastic latent variable model and some observations
and computes a stochastic lower bound on the log marginal probability of the
observations. The IWAE estimator is defined by averaging multiple importance
weights. For more details see "Importance Weighted Autoencoders" by Burda
et al. https://arxiv.org/abs/1509.00519.
When num_samples = 1, this bound becomes the evidence lower bound (ELBO).
Args:
cell: A callable that implements one timestep of the model. See
models/vrnn.py for an example.
inputs: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively. At each
timestep 'cell' will be called with a slice of the Tensors in inputs.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
num_samples: The number of samples to use.
parallel_iterations: The number of parallel iterations to use for the
internal while loop.
swap_memory: Whether GPU-CPU memory swapping should be enabled for the
internal while loop.
Returns:
log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of the
log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the kl divergence
from q(z|x) to p(z), averaged over samples.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep. Will not be valid for
timesteps past the end of a sequence.
log_ess: A Tensor of shape [max_seq_len, batch_size] containing the log
effective sample size at each timestep. Will not be valid for timesteps
past the end of a sequence.
"""
batch_size = tf.shape(seq_lengths)[0]
max_seq_len = tf.reduce_max(seq_lengths)
seq_mask = tf.transpose(
tf.sequence_mask(seq_lengths, maxlen=max_seq_len, dtype=tf.float32),
perm=[1, 0])
if num_samples > 1:
inputs, seq_mask = nested.tile_tensors([inputs, seq_mask], [1, num_samples])
inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask], max_seq_len)
t0 = tf.constant(0, tf.int32)
init_states = cell.zero_state(batch_size * num_samples, tf.float32)
ta_names = ['log_weights', 'log_ess']
tas = [tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % n)
for n in ta_names]
log_weights_acc = tf.zeros([num_samples, batch_size], dtype=tf.float32)
kl_acc = tf.zeros([num_samples * batch_size], dtype=tf.float32)
accs = (log_weights_acc, kl_acc)
def while_predicate(t, *unused_args):
return t < max_seq_len
def while_step(t, rnn_state, tas, accs):
"""Implements one timestep of IWAE computation."""
log_weights_acc, kl_acc = accs
cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
# Run the cell for one step.
log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
cur_inputs,
rnn_state,
cur_mask,
)
# Compute the incremental weight and use it to update the current
# accumulated weight.
kl_acc += kl * cur_mask
log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
log_weights_acc += log_alpha
# Calculate the effective sample size.
ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
log_ess = ess_num - ess_denom
# Update the Tensorarrays and accumulators.
ta_updates = [log_weights_acc, log_ess]
new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
new_accs = (log_weights_acc, kl_acc)
return t + 1, new_state, new_tas, new_accs
_, _, tas, accs = tf.while_loop(
while_predicate,
while_step,
loop_vars=(t0, init_states, tas, accs),
parallel_iterations=parallel_iterations,
swap_memory=swap_memory)
log_weights, log_ess = [x.stack() for x in tas]
final_log_weights, kl = accs
log_p_hat = (tf.reduce_logsumexp(final_log_weights, axis=0) -
tf.log(tf.to_float(num_samples)))
kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
return log_p_hat, kl, log_weights, log_ess
def ess_criterion(num_samples, log_ess, unused_t):
"""A criterion that resamples based on effective sample size."""
return log_ess <= tf.log(num_samples / 2.0)
def never_resample_criterion(unused_num_samples, log_ess, unused_t):
"""A criterion that never resamples."""
return tf.cast(tf.zeros_like(log_ess), tf.bool)
def always_resample_criterion(unused_num_samples, log_ess, unused_t):
"""A criterion resamples at every timestep."""
return tf.cast(tf.ones_like(log_ess), tf.bool)
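# A note on the effective sample size (ESS) used by ess_criterion above: for
# unnormalized weights w_1, ..., w_n, ESS = (sum_i w_i)^2 / sum_i w_i^2, so in
# log space log_ess = 2 * logsumexp(log_w) - logsumexp(2 * log_w), which is how
# the bounds in this file compute it. With equal weights ESS equals the number
# of samples and ess_criterion does not resample; as a single weight dominates,
# ESS approaches 1 and a resample is triggered once ESS <= num_samples / 2.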
def fivo(cell,
inputs,
seq_lengths,
num_samples=1,
resampling_criterion=ess_criterion,
parallel_iterations=30,
swap_memory=True,
random_seed=None):
"""Computes the FIVO lower bound on the log marginal probability.
This method accepts a stochastic latent variable model and some observations
and computes a stochastic lower bound on the log marginal probability of the
observations. The lower bound is defined by a particle filter's unbiased
estimate of the marginal probability of the observations. For more details see
"Filtering Variational Objectives" by Maddison et al.
https://arxiv.org/abs/1705.09279.
When the resampling criterion is "never resample", this bound becomes IWAE.
Args:
cell: A callable that implements one timestep of the model. See
models/vrnn.py for an example.
inputs: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively. At each
timestep 'cell' will be called with a slice of the Tensors in inputs.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
num_samples: The number of particles to use in each particle filter.
resampling_criterion: The resampling criterion to use for this particle
filter. Must accept the number of samples, the effective sample size,
and the current timestep and return a boolean Tensor of shape [batch_size]
indicating whether each particle filter should resample. See
ess_criterion and related functions defined in this file for examples.
parallel_iterations: The number of parallel iterations to use for the
internal while loop. Note that values greater than 1 can introduce
non-determinism even when random_seed is provided.
swap_memory: Whether GPU-CPU memory swapping should be enabled for the
internal while loop.
random_seed: The random seed to pass to the resampling operations in
the particle filter. Mainly useful for testing.
Returns:
log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the sum over time of the kl
divergence from q_t(z_t|x) to p_t(z_t), averaged over particles. Note that
this includes kl terms from trajectories that are culled during resampling
steps.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep of the particle filter. Note
that on timesteps when a resampling operation is performed the log weights
are reset to 0. Will not be valid for timesteps past the end of a
sequence.
log_ess: A Tensor of shape [max_seq_len, batch_size] containing the log
effective sample size of each particle filter at each timestep. Will not
be valid for timesteps past the end of a sequence.
resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
particle filters resampled. Will be 1.0 on timesteps when resampling
occurred and 0.0 on timesteps when it did not.
"""
# batch_size represents the number of particle filters running in parallel.
batch_size = tf.shape(seq_lengths)[0]
max_seq_len = tf.reduce_max(seq_lengths)
seq_mask = tf.transpose(
tf.sequence_mask(seq_lengths, maxlen=max_seq_len, dtype=tf.float32),
perm=[1, 0])
# Each sequence in the batch will be the input data for a different
# particle filter. The batch will be laid out as:
# particle 1 of particle filter 1
# particle 1 of particle filter 2
# ...
# particle 1 of particle filter batch_size
# particle 2 of particle filter 1
# ...
# particle num_samples of particle filter batch_size
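# For example, with batch_size=2 and num_samples=3 the flattened batch has 6
# rows ordered [pf1 particle 1, pf2 particle 1, pf1 particle 2, pf2 particle 2,
# pf1 particle 3, pf2 particle 3], so particle j of filter i (zero-indexed)
# sits at row j * batch_size + i.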
if num_samples > 1:
inputs, seq_mask = nested.tile_tensors([inputs, seq_mask], [1, num_samples])
inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask], max_seq_len)
t0 = tf.constant(0, tf.int32)
init_states = cell.zero_state(batch_size * num_samples, tf.float32)
ta_names = ['log_weights', 'log_ess', 'resampled']
tas = [tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % n)
for n in ta_names]
log_weights_acc = tf.zeros([num_samples, batch_size], dtype=tf.float32)
log_p_hat_acc = tf.zeros([batch_size], dtype=tf.float32)
kl_acc = tf.zeros([num_samples * batch_size], dtype=tf.float32)
accs = (log_weights_acc, log_p_hat_acc, kl_acc)
def while_predicate(t, *unused_args):
return t < max_seq_len
def while_step(t, rnn_state, tas, accs):
"""Implements one timestep of FIVO computation."""
log_weights_acc, log_p_hat_acc, kl_acc = accs
cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
# Run the cell for one step.
log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
cur_inputs,
rnn_state,
cur_mask,
)
# Compute the incremental weight and use it to update the current
# accumulated weight.
kl_acc += kl * cur_mask
log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
log_weights_acc += log_alpha
# Calculate the effective sample size.
ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
log_ess = ess_num - ess_denom
# Calculate the ancestor indices via resampling. Because we maintain the
# log unnormalized weights, we pass the weights in as logits, allowing
# the distribution object to apply a softmax and normalize them.
resampling_dist = tf.contrib.distributions.Categorical(
logits=tf.transpose(log_weights_acc, perm=[1, 0]))
ancestor_inds = tf.stop_gradient(
resampling_dist.sample(sample_shape=num_samples, seed=random_seed))
# Because the batch is flattened and laid out as discussed
# above, we must modify ancestor_inds to index the proper samples.
# The particles in the ith filter are distributed every batch_size rows
# in the batch, and offset i rows from the top. So, to correct the indices
# we multiply by the batch_size and add the proper offset. Crucially,
# when ancestor_inds is flattened the layout of the batch is maintained.
offset = tf.expand_dims(tf.range(batch_size), 0)
ancestor_inds = tf.reshape(ancestor_inds * batch_size + offset, [-1])
noresample_inds = tf.range(num_samples * batch_size)
# Decide whether or not we should resample; don't resample if we are past
# the end of a sequence.
should_resample = resampling_criterion(num_samples, log_ess, t)
should_resample = tf.logical_and(should_resample,
cur_mask[:batch_size] > 0.)
float_should_resample = tf.to_float(should_resample)
ancestor_inds = tf.where(
tf.tile(should_resample, [num_samples]),
ancestor_inds,
noresample_inds)
new_state = nested.gather_tensors(new_state, ancestor_inds)
# Update the TensorArrays before we reset the weights so that we capture
# the incremental weights and not zeros.
ta_updates = [log_weights_acc, log_ess, float_should_resample]
new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
# For the particle filters that resampled, update log_p_hat and
# reset weights to zero.
log_p_hat_update = tf.reduce_logsumexp(
log_weights_acc, axis=0) - tf.log(tf.to_float(num_samples))
log_p_hat_acc += log_p_hat_update * float_should_resample
log_weights_acc *= (1. - tf.tile(float_should_resample[tf.newaxis, :],
[num_samples, 1]))
new_accs = (log_weights_acc, log_p_hat_acc, kl_acc)
return t + 1, new_state, new_tas, new_accs
_, _, tas, accs = tf.while_loop(
while_predicate,
while_step,
loop_vars=(t0, init_states, tas, accs),
parallel_iterations=parallel_iterations,
swap_memory=swap_memory)
log_weights, log_ess, resampled = [x.stack() for x in tas]
final_log_weights, log_p_hat, kl = accs
# Add in the final weight update to log_p_hat.
log_p_hat += (tf.reduce_logsumexp(final_log_weights, axis=0) -
tf.log(tf.to_float(num_samples)))
kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
return log_p_hat, kl, log_weights, log_ess, resampled
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to calculate the mean of a pianoroll dataset.
Given a pianoroll pickle file, this script loads the dataset and
calculates the mean of the training set. Then it updates the pickle file
so that the key "train_mean" points to the mean vector.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import numpy as np
import tensorflow as tf
from datasets import sparse_pianoroll_to_dense
tf.app.flags.DEFINE_string('in_file', None,
'Filename of the pickled pianoroll dataset to load.')
tf.app.flags.DEFINE_string('out_file', None,
'Name of the output pickle file. Defaults to in_file, '
'updating the input pickle file.')
tf.app.flags.mark_flag_as_required('in_file')
FLAGS = tf.app.flags.FLAGS
MIN_NOTE = 21
MAX_NOTE = 108
NUM_NOTES = MAX_NOTE - MIN_NOTE + 1
def main(unused_argv):
if FLAGS.out_file is None:
FLAGS.out_file = FLAGS.in_file
with tf.gfile.Open(FLAGS.in_file, 'r') as f:
pianorolls = pickle.load(f)
dense_pianorolls = [sparse_pianoroll_to_dense(p, MIN_NOTE, NUM_NOTES)[0]
for p in pianorolls['train']]
# Concatenate all elements along the time axis.
concatenated = np.concatenate(dense_pianorolls, axis=0)
mean = np.mean(concatenated, axis=0)
pianorolls['train_mean'] = mean
# Write out the whole pickle file, including the train mean.
pickle.dump(pianorolls, open(FLAGS.out_file, 'wb'))
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import random
import re
import numpy as np
import tensorflow as tf
tf.app.flags.DEFINE_string("raw_timit_dir", None,
"Directory containing TIMIT files.")
tf.app.flags.DEFINE_string("out_dir", None,
"Output directory for TFRecord files.")
tf.app.flags.DEFINE_float("valid_frac", 0.05,
"Fraction of train set to use as valid set. "
"Must be between 0.0 and 1.0.")
tf.app.flags.mark_flag_as_required("raw_timit_dir")
tf.app.flags.mark_flag_as_required("out_dir")
FLAGS = tf.app.flags.FLAGS
NUM_TRAIN_FILES = 4620
NUM_TEST_FILES = 1680
SAMPLES_PER_TIMESTEP = 200
# Regexes for reading SPHERE header files.
SAMPLE_COUNT_REGEX = re.compile(r"sample_count -i (\d+)")
SAMPLE_MIN_REGEX = re.compile(r"sample_min -i (-?\d+)")
SAMPLE_MAX_REGEX = re.compile(r"sample_max -i (-?\d+)")
def get_filenames(split):
"""Get all wav filenames from the TIMIT archive."""
path = os.path.join(FLAGS.raw_timit_dir, "TIMIT", split, "*", "*", "*.WAV")
# Sort the output by name so the order is deterministic.
files = sorted(glob.glob(path))
return files
def load_timit_wav(filename):
"""Loads a TIMIT wavfile into a numpy array.
TIMIT wavfiles include a SPHERE header, detailed in the TIMIT docs. The first
line is the header type and the second is the length of the header in bytes.
After the header, the remaining bytes are actual WAV data.
The header includes information about the WAV data such as the number of
samples and minimum and maximum amplitude. This function asserts that the
loaded wav data matches the header.
Args:
filename: The name of the TIMIT wavfile to load.
Returns:
wav: A numpy array containing the loaded wav data.
"""
wav_file = open(filename, "rb")
header_type = wav_file.readline()
header_length_str = wav_file.readline()
# The header length includes the length of the first two lines.
header_remaining_bytes = (int(header_length_str) - len(header_type) -
len(header_length_str))
header = wav_file.read(header_remaining_bytes)
# Read the relevant header fields.
sample_count = int(SAMPLE_COUNT_REGEX.search(header).group(1))
sample_min = int(SAMPLE_MIN_REGEX.search(header).group(1))
sample_max = int(SAMPLE_MAX_REGEX.search(header).group(1))
wav = np.fromstring(wav_file.read(), dtype="int16").astype("float32")
# Check that the loaded data conforms to the header description.
assert len(wav) == sample_count
assert wav.min() == sample_min
assert wav.max() == sample_max
return wav
def preprocess(wavs, block_size, mean, std):
"""Normalize the wav data and reshape it into chunks."""
processed_wavs = []
for wav in wavs:
wav = (wav - mean) / std
wav_length = wav.shape[0]
if wav_length % block_size != 0:
pad_width = block_size - (wav_length % block_size)
wav = np.pad(wav, (0, pad_width), "constant")
assert wav.shape[0] % block_size == 0
wav = wav.reshape((-1, block_size))
processed_wavs.append(wav)
return processed_wavs
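# For example, with block_size=200 a 1000-sample wav becomes a [5, 200] array,
# while a 1001-sample wav is zero-padded to 1200 samples and becomes [6, 200].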
def create_tfrecord_from_wavs(wavs, output_file):
"""Writes processed wav files to disk as sharded TFRecord files."""
with tf.python_io.TFRecordWriter(output_file) as builder:
for wav in wavs:
builder.write(wav.astype(np.float32).tobytes())
def main(unused_argv):
train_filenames = get_filenames("TRAIN")
test_filenames = get_filenames("TEST")
num_train_files = len(train_filenames)
num_test_files = len(test_filenames)
num_valid_files = int(num_train_files * FLAGS.valid_frac)
num_train_files -= num_valid_files
print("%d train / %d valid / %d test" % (
num_train_files, num_valid_files, num_test_files))
random.seed(1234)
random.shuffle(train_filenames)
valid_filenames = train_filenames[:num_valid_files]
train_filenames = train_filenames[num_valid_files:]
# Make sure there is no overlap in the train, test, and valid sets.
train_s = set(train_filenames)
test_s = set(test_filenames)
valid_s = set(valid_filenames)
# Disable explicit length testing to make the assertions more readable.
# pylint: disable=g-explicit-length-test
assert len(train_s & test_s) == 0
assert len(train_s & valid_s) == 0
assert len(valid_s & test_s) == 0
# pylint: enable=g-explicit-length-test
train_wavs = [load_timit_wav(f) for f in train_filenames]
valid_wavs = [load_timit_wav(f) for f in valid_filenames]
test_wavs = [load_timit_wav(f) for f in test_filenames]
assert len(train_wavs) + len(valid_wavs) == NUM_TRAIN_FILES
assert len(test_wavs) == NUM_TEST_FILES
# Calculate the mean and standard deviation of the train set.
train_stacked = np.hstack(train_wavs)
train_mean = np.mean(train_stacked)
train_std = np.std(train_stacked)
print("train mean: %f train std: %f" % (train_mean, train_std))
# Process all data, normalizing with the train set statistics.
processed_train_wavs = preprocess(train_wavs, SAMPLES_PER_TIMESTEP,
train_mean, train_std)
processed_valid_wavs = preprocess(valid_wavs, SAMPLES_PER_TIMESTEP,
train_mean, train_std)
processed_test_wavs = preprocess(test_wavs, SAMPLES_PER_TIMESTEP, train_mean,
train_std)
# Write the datasets to disk.
create_tfrecord_from_wavs(
processed_train_wavs,
os.path.join(FLAGS.out_dir, "train"))
create_tfrecord_from_wavs(
processed_valid_wavs,
os.path.join(FLAGS.out_dir, "valid"))
create_tfrecord_from_wavs(
processed_test_wavs,
os.path.join(FLAGS.out_dir, "test"))
if __name__ == "__main__":
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for creating sequence datasets.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
from scipy.sparse import coo_matrix
import tensorflow as tf
# The default number of threads used to process data in parallel.
DEFAULT_PARALLELISM = 12
def sparse_pianoroll_to_dense(pianoroll, min_note, num_notes):
"""Converts a sparse pianoroll to a dense numpy array.
Given a sparse pianoroll, converts it to a dense numpy array of shape
[num_timesteps, num_notes] where entry i,j is 1.0 if note j is active on
timestep i and 0.0 otherwise.
Args:
pianoroll: A sparse pianoroll object, a list of tuples where the i'th tuple
contains the indices of the notes active at timestep i.
min_note: The minimum note in the pianoroll, subtracted from all notes so
that the minimum note becomes 0.
num_notes: The number of possible different note indices, determines the
second dimension of the resulting dense array.
Returns:
dense_pianoroll: A [num_timesteps, num_notes] numpy array of floats.
num_timesteps: A python int, the number of timesteps in the pianoroll.
"""
num_timesteps = len(pianoroll)
inds = []
for time, chord in enumerate(pianoroll):
# Re-index the notes to start from min_note.
inds.extend((time, note-min_note) for note in chord)
shape = [num_timesteps, num_notes]
values = [1.] * len(inds)
sparse_pianoroll = coo_matrix(
(values, ([x[0] for x in inds], [x[1] for x in inds])),
shape=shape)
return sparse_pianoroll.toarray(), num_timesteps
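# For example, the sparse pianoroll [(60, 64), (), (60,)] with min_note=21 and
# num_notes=88 becomes a [3, 88] array of zeros with ones at (0, 39), (0, 43),
# and (2, 39), and num_timesteps is 3.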
def create_pianoroll_dataset(path,
split,
batch_size,
num_parallel_calls=DEFAULT_PARALLELISM,
shuffle=False,
repeat=False,
min_note=21,
max_note=108):
"""Creates a pianoroll dataset.
Args:
path: The path of a pickle file containing the dataset to load.
split: The split to use, can be train, test, or valid.
batch_size: The batch size. If repeat is False then it is not guaranteed
that the true batch size will match for all batches since batch_size
may not necessarily evenly divide the number of elements.
num_parallel_calls: The number of threads to use for parallel processing of
the data.
shuffle: If true, shuffles the order of the dataset.
repeat: If true, repeats the dataset endlessly.
min_note: The minimum note number of the dataset. For all pianoroll datasets
the minimum note is number 21, and changing this affects the dimension of
the data. This is useful mostly for testing.
max_note: The maximum note number of the dataset. For all pianoroll datasets
the maximum note is number 108, and changing this affects the dimension of
the data. This is useful mostly for testing.
Returns:
inputs: A batch of input sequences represented as a dense Tensor of shape
[time, batch_size, data_dimension]. The sequences in inputs are the
sequences in targets shifted one timestep into the future, padded with
zeros. This tensor is mean-centered, with the mean taken from the pickle
file key 'train_mean'.
targets: A batch of target sequences represented as a dense Tensor of
shape [time, batch_size, data_dimension].
lens: An int Tensor of shape [batch_size] representing the lengths of each
sequence in the batch.
mean: A float Tensor of shape [data_dimension] containing the mean loaded
from the pickle file.
"""
# Load the data from disk.
num_notes = max_note - min_note + 1
with tf.gfile.Open(path, "r") as f:
raw_data = pickle.load(f)
pianorolls = raw_data[split]
mean = raw_data["train_mean"]
num_examples = len(pianorolls)
def pianoroll_generator():
for sparse_pianoroll in pianorolls:
yield sparse_pianoroll_to_dense(sparse_pianoroll, min_note, num_notes)
dataset = tf.data.Dataset.from_generator(
pianoroll_generator,
output_types=(tf.float64, tf.int64),
output_shapes=([None, num_notes], []))
if repeat: dataset = dataset.repeat()
if shuffle: dataset = dataset.shuffle(num_examples)
# Batch sequences together, padding them to a common length in time.
dataset = dataset.padded_batch(batch_size,
padded_shapes=([None, num_notes], []))
def process_pianoroll_batch(data, lengths):
"""Create mean-centered and time-major next-step prediction Tensors."""
data = tf.to_float(tf.transpose(data, perm=[1, 0, 2]))
lengths = tf.to_int32(lengths)
targets = data
# Mean center the inputs.
inputs = data - tf.constant(mean, dtype=tf.float32,
shape=[1, 1, mean.shape[0]])
# Shift the inputs one step forward in time. Also remove the last timestep
# so that targets and inputs are the same length.
inputs = tf.pad(inputs, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
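# For example, a length-3 sequence [x1, x2, x3] yields mean-centered inputs
# [0, x1 - mean, x2 - mean] and targets [x1, x2, x3].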
# Mask out unused timesteps.
inputs *= tf.expand_dims(tf.transpose(
tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
return inputs, targets, lengths
dataset = dataset.map(process_pianoroll_batch,
num_parallel_calls=num_parallel_calls)
dataset = dataset.prefetch(num_examples)
itr = dataset.make_one_shot_iterator()
inputs, targets, lengths = itr.get_next()
return inputs, targets, lengths, tf.constant(mean, dtype=tf.float32)
def create_speech_dataset(path,
batch_size,
samples_per_timestep=200,
num_parallel_calls=DEFAULT_PARALLELISM,
prefetch_buffer_size=2048,
shuffle=False,
repeat=False):
"""Creates a speech dataset.
Args:
path: The path of a possibly sharded TFRecord file containing the data.
batch_size: The batch size. If repeat is False then it is not guaranteed
that the true batch size will match for all batches since batch_size
may not necessarily evenly divide the number of elements.
samples_per_timestep: The number of audio samples per timestep. Used to
reshape the data into sequences of shape [time, samples_per_timestep].
Should not change except for testing -- in all speech datasets 200 is the
number of samples per timestep.
num_parallel_calls: The number of threads to use for parallel processing of
the data.
prefetch_buffer_size: The size of the prefetch queues to use after reading
and processing the raw data.
shuffle: If true, shuffles the order of the dataset.
repeat: If true, repeats the dataset endlessly.
Returns:
inputs: A batch of input sequences represented as a dense Tensor of shape
[time, batch_size, samples_per_timestep]. The sequences in inputs are the
sequences in targets shifted one timestep into the future, padded with
zeros.
targets: A batch of target sequences represented as a dense Tensor of
shape [time, batch_size, samples_per_timestep].
lens: An int Tensor of shape [batch_size] representing the lengths of each
sequence in the batch.
"""
filenames = [path]
def read_speech_example(value):
"""Parses a single tf.Example from the TFRecord file."""
decoded = tf.decode_raw(value, out_type=tf.float32)
example = tf.reshape(decoded, [-1, samples_per_timestep])
length = tf.shape(example)[0]
return example, length
# Create the dataset from the TFRecord files
dataset = tf.data.TFRecordDataset(filenames).map(
read_speech_example, num_parallel_calls=num_parallel_calls)
dataset = dataset.prefetch(prefetch_buffer_size)
if repeat: dataset = dataset.repeat()
if shuffle: dataset = dataset.shuffle(prefetch_buffer_size)
dataset = dataset.padded_batch(
batch_size, padded_shapes=([None, samples_per_timestep], []))
def process_speech_batch(data, lengths):
"""Creates Tensors for next step prediction."""
data = tf.transpose(data, perm=[1, 0, 2])
lengths = tf.to_int32(lengths)
targets = data
# Shift the inputs one step forward in time. Also remove the last timestep
# so that targets and inputs are the same length.
inputs = tf.pad(data, [[1, 0], [0, 0], [0, 0]], mode="CONSTANT")[:-1]
# Mask out unused timesteps.
inputs *= tf.expand_dims(
tf.transpose(tf.sequence_mask(lengths, dtype=inputs.dtype)), 2)
return inputs, targets, lengths
dataset = dataset.map(process_speech_batch,
num_parallel_calls=num_parallel_calls)
dataset = dataset.prefetch(prefetch_buffer_size)
itr = dataset.make_one_shot_iterator()
inputs, targets, lengths = itr.get_next()
return inputs, targets, lengths
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to run training for sequential latent variable models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import runners
# Shared flags.
tf.app.flags.DEFINE_string("mode", "train",
"The mode of the binary. Must be 'train' or 'test'.")
tf.app.flags.DEFINE_string("model", "vrnn",
"Model choice. Currently only 'vrnn' is supported.")
tf.app.flags.DEFINE_integer("latent_size", 64,
"The size of the latent state of the model.")
tf.app.flags.DEFINE_string("dataset_type", "pianoroll",
"The type of dataset, either 'pianoroll' or 'speech'.")
tf.app.flags.DEFINE_string("dataset_path", "",
"Path to load the dataset from.")
tf.app.flags.DEFINE_integer("data_dimension", None,
"The dimension of each vector in the data sequence. "
"Defaults to 88 for pianoroll datasets and 200 for speech "
"datasets. Should not need to be changed except for "
"testing.")
tf.app.flags.DEFINE_integer("batch_size", 4,
"Batch size.")
tf.app.flags.DEFINE_integer("num_samples", 4,
"The number of samples (or particles) for multisample "
"algorithms.")
tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi",
"The directory to keep checkpoints and summaries in.")
tf.app.flags.DEFINE_integer("random_seed", None,
"A random seed for seeding the TensorFlow graph.")
# Training flags.
tf.app.flags.DEFINE_string("bound", "fivo",
"The bound to optimize. Can be 'elbo', 'iwae', or 'fivo'.")
tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True,
"If true, normalize the loss by the number of timesteps "
"per sequence.")
tf.app.flags.DEFINE_float("learning_rate", 0.0002,
"The learning rate for ADAM.")
tf.app.flags.DEFINE_integer("max_steps", int(1e9),
"The number of gradient update steps to train for.")
tf.app.flags.DEFINE_integer("summarize_every", 50,
"The number of steps between summaries.")
# Distributed training flags.
tf.app.flags.DEFINE_string("master", "",
"The BNS name of the TensorFlow master to use.")
tf.app.flags.DEFINE_integer("task", 0,
"Task id of the replica running the training.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
"Number of tasks in the ps job. If 0 no ps job is used.")
tf.app.flags.DEFINE_boolean("stagger_workers", True,
"If true, bring one worker online every 1000 steps.")
# Evaluation flags.
tf.app.flags.DEFINE_string("split", "train",
"Split to evaluate the model on. Can be 'train', 'valid', or 'test'.")
FLAGS = tf.app.flags.FLAGS
PIANOROLL_DEFAULT_DATA_DIMENSION = 88
SPEECH_DEFAULT_DATA_DIMENSION = 200
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.data_dimension is None:
if FLAGS.dataset_type == "pianoroll":
FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
elif FLAGS.dataset_type == "speech":
FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
if FLAGS.mode == "train":
runners.run_train(FLAGS)
elif FLAGS.mode == "eval":
runners.run_eval(FLAGS)
if __name__ == "__main__":
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""VRNN classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sonnet as snt
import tensorflow as tf
class VRNNCell(snt.AbstractModule):
"""Implementation of a Variational Recurrent Neural Network (VRNN).
Introduced in "A Recurrent Latent Variable Model for Sequential data"
by Chung et al. https://arxiv.org/pdf/1506.02216.pdf.
The VRNN is a sequence model similar to an RNN that uses stochastic latent
variables to improve its representational power. It can be thought of as a
sequential analogue to the variational auto-encoder (VAE).
The VRNN has a deterministic RNN as its backbone, represented by the
sequence of RNN hidden states h_t. At each timestep, the RNN hidden state h_t
is conditioned on the previous sequence element, x_{t-1}, as well as the
latent state from the previous timestep, z_{t-1}.
In this implementation of the VRNN the latent state z_t is Gaussian. The
model's prior over z_t is distributed as Normal(mu_t, diag(sigma_t^2)) where
mu_t and sigma_t are the mean and standard deviation output from a fully
connected network that accepts the rnn hidden state h_t as input.
The approximate posterior (also known as q or the encoder in the VAE
framework) is similar to the prior except that it is conditioned on the
current target, x_t, as well as h_t via a fully connected network.
This implementation uses the 'res_q' parameterization of the approximate
posterior, meaning that instead of directly predicting the mean of z_t, the
approximate posterior predicts the 'residual' from the prior's mean. This is
explored more in section 3.3 of https://arxiv.org/pdf/1605.07571.pdf.
During training, the latent state z_t is sampled from the approximate
posterior and the reparameterization trick is used to provide low-variance
gradients.
The generative distribution p(x_t|z_t, h_t) is conditioned on the latent state
z_t as well as the current RNN hidden state h_t via a fully connected network.
To increase the modeling power of the VRNN, two additional networks are
used to extract features from the data and the latent state. Those networks
are called data_feat_extractor and latent_feat_extractor respectively.
There are a few differences between this exposition and the paper.
First, the indexing scheme for h_t differs from the paper's -- what the
paper calls h_t we call h_{t+1}. This is the same notation used by Fraccaro
et al. to describe the VRNN in the paper linked above. Also, the VRNN paper
uses VAE terminology to refer to the different internal networks, so it
refers to the approximate posterior as the encoder and the generative
distribution as the decoder. This implementation also renames the functions
phi_x and phi_z from the paper to data_feat_extractor and latent_feat_extractor.
"""
def __init__(self,
rnn_cell,
data_feat_extractor,
latent_feat_extractor,
prior,
approx_posterior,
generative,
random_seed=None,
name="vrnn"):
"""Creates a VRNN cell.
Args:
rnn_cell: A subclass of tf.nn.rnn_cell.RNNCell that will form the
deterministic backbone of the VRNN. The inputs to the RNN will be the
encoded latent state of the previous timestep with shape
[batch_size, encoded_latent_size] as well as the encoded input of the
current timestep, a Tensor of shape [batch_size, encoded_data_size].
data_feat_extractor: A callable that accepts a batch of data x_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument the inputs x_t, a Tensor of the shape
[batch_size, data_size] and return a Tensor of shape
[batch_size, encoded_data_size]. This callable will be called multiple
times in the VRNN cell so if scoping is not handled correctly then
multiple copies of the variables in this network could be made. It is
recommended to use a snt.nets.MLP module, which takes care of this for
you.
latent_feat_extractor: A callable that accepts a latent state z_t and
'encodes' it, e.g. runs it through a fully connected network. Must
accept as argument a Tensor of shape [batch_size, latent_size] and
return a Tensor of shape [batch_size, encoded_latent_size].
This callable must also have the property 'output_size' defined,
returning encoded_latent_size.
prior: A callable that implements the prior p(z_t|h_t). Must accept as
argument the previous RNN hidden state and return a
tf.contrib.distributions.Normal distribution conditioned on the input.
approx_posterior: A callable that implements the approximate posterior
q(z_t|h_t,x_t). Must accept as arguments the encoded target of the
current timestep and the previous RNN hidden state. Must return
a tf.contrib.distributions.Normal distribution conditioned on the
inputs.
generative: A callable that implements the generative distribution
p(x_t|z_t, h_t). Must accept as arguments the encoded latent state
and the RNN hidden state and return a subclass of
tf.contrib.distributions.Distribution that can be used to evaluate
the logprob of the targets.
random_seed: The seed for the random ops. Used mainly for testing.
name: The name of this VRNN.
"""
super(VRNNCell, self).__init__(name=name)
self.rnn_cell = rnn_cell
self.data_feat_extractor = data_feat_extractor
self.latent_feat_extractor = latent_feat_extractor
self.prior = prior
self.approx_posterior = approx_posterior
self.generative = generative
self.random_seed = random_seed
self.encoded_z_size = latent_feat_extractor.output_size
self.state_size = (self.rnn_cell.state_size, self.encoded_z_size)
def zero_state(self, batch_size, dtype):
"""The initial state of the VRNN.
Contains the initial state of the RNN as well as a vector of zeros
corresponding to z_0.
Args:
batch_size: The batch size.
dtype: The data type of the VRNN.
Returns:
zero_state: The initial state of the VRNN.
"""
return (self.rnn_cell.zero_state(batch_size, dtype),
tf.zeros([batch_size, self.encoded_z_size], dtype=dtype))
def _build(self, observations, state, mask):
"""Computes one timestep of the VRNN.
Args:
observations: The observations at the current timestep, a tuple
containing the model inputs and targets as Tensors of shape
[batch_size, data_size].
state: The current state of the VRNN.
mask: Tensor of shape [batch_size], 1.0 if the current timestep is
active, 0.0 if it is not.
Returns:
log_q_z: The logprob of the latent state according to the approximate
posterior.
log_p_z: The logprob of the latent state according to the prior.
log_p_x_given_z: The conditional log-likelihood, i.e. logprob of the
observation according to the generative distribution.
kl: The analytic kl divergence from q(z) to p(z).
state: The new state of the VRNN.
"""
inputs, targets = observations
rnn_state, prev_latent_encoded = state
# Encode the data.
inputs_encoded = self.data_feat_extractor(inputs)
targets_encoded = self.data_feat_extractor(targets)
# Run the RNN cell.
rnn_inputs = tf.concat([inputs_encoded, prev_latent_encoded], axis=1)
rnn_out, new_rnn_state = self.rnn_cell(rnn_inputs, rnn_state)
# Create the prior and approximate posterior distributions.
latent_dist_prior = self.prior(rnn_out)
latent_dist_q = self.approx_posterior(rnn_out, targets_encoded,
prior_mu=latent_dist_prior.loc)
# Sample the new latent state z and encode it.
latent_state = latent_dist_q.sample(seed=self.random_seed)
latent_encoded = self.latent_feat_extractor(latent_state)
# Calculate probabilities of the latent state according to the prior p
# and approximate posterior q.
log_q_z = tf.reduce_sum(latent_dist_q.log_prob(latent_state), axis=-1)
log_p_z = tf.reduce_sum(latent_dist_prior.log_prob(latent_state), axis=-1)
analytic_kl = tf.reduce_sum(
tf.contrib.distributions.kl_divergence(
latent_dist_q, latent_dist_prior),
axis=-1)
# Create the generative dist. and calculate the logprob of the targets.
generative_dist = self.generative(latent_encoded, rnn_out)
log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(targets), axis=-1)
return (log_q_z, log_p_z, log_p_x_given_z, analytic_kl,
(new_rnn_state, latent_encoded))
_DEFAULT_INITIALIZERS = {"w": tf.contrib.layers.xavier_initializer(),
"b": tf.zeros_initializer()}
def create_vrnn(
data_size,
latent_size,
generative_class,
rnn_hidden_size=None,
fcnet_hidden_sizes=None,
encoded_data_size=None,
encoded_latent_size=None,
sigma_min=0.0,
raw_sigma_bias=0.25,
generative_bias_init=0.0,
initializers=None,
random_seed=None):
"""A factory method for creating VRNN cells.
Args:
data_size: The dimension of the vectors that make up the data sequences.
latent_size: The size of the stochastic latent state of the VRNN.
generative_class: The class of the generative distribution. Can be either
ConditionalNormalDistribution or ConditionalBernoulliDistribution.
rnn_hidden_size: The hidden state dimension of the RNN that forms the
deterministic part of this VRNN. If None, then it defaults
to latent_size.
fcnet_hidden_sizes: A list of python integers, the size of the hidden
layers of the fully connected networks that parameterize the conditional
distributions of the VRNN. If None, then it defaults to one hidden
layer of size latent_size.
encoded_data_size: The size of the output of the data encoding network. If
None, defaults to latent_size.
encoded_latent_size: The size of the output of the latent state encoding
network. If None, defaults to latent_size.
sigma_min: The minimum value that the standard deviation of the
distribution over the latent state can take.
raw_sigma_bias: A scalar that is added to the raw standard deviation
output from the neural networks that parameterize the prior and
approximate posterior. Useful for preventing standard deviations close
to zero.
generative_bias_init: A bias added to the raw output of the fully
connected network that parameterizes the generative distribution. Useful
for initializing the mean of the distribution to a sensible starting point
such as the mean of the training data. Only used with Bernoulli generative
distributions.
initializers: The variable initializers to use for the fully connected
networks and RNN cell. Must be a dictionary mapping the keys 'w' and 'b'
to the initializers for the weights and biases. Defaults to xavier for
the weights and zeros for the biases when initializers is None.
random_seed: A random seed for the VRNN resampling operations.
Returns:
model: A VRNNCell object.
"""
if rnn_hidden_size is None:
rnn_hidden_size = latent_size
if fcnet_hidden_sizes is None:
fcnet_hidden_sizes = [latent_size]
if encoded_data_size is None:
encoded_data_size = latent_size
if encoded_latent_size is None:
encoded_latent_size = latent_size
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
data_feat_extractor = snt.nets.MLP(
output_sizes=fcnet_hidden_sizes + [encoded_data_size],
initializers=initializers,
name="data_feat_extractor")
latent_feat_extractor = snt.nets.MLP(
output_sizes=fcnet_hidden_sizes + [encoded_latent_size],
initializers=initializers,
name="latent_feat_extractor")
prior = ConditionalNormalDistribution(
size=latent_size,
hidden_layer_sizes=fcnet_hidden_sizes,
sigma_min=sigma_min,
raw_sigma_bias=raw_sigma_bias,
initializers=initializers,
name="prior")
approx_posterior = NormalApproximatePosterior(
size=latent_size,
hidden_layer_sizes=fcnet_hidden_sizes,
sigma_min=sigma_min,
raw_sigma_bias=raw_sigma_bias,
initializers=initializers,
name="approximate_posterior")
if generative_class == ConditionalBernoulliDistribution:
generative = ConditionalBernoulliDistribution(
size=data_size,
hidden_layer_sizes=fcnet_hidden_sizes,
initializers=initializers,
bias_init=generative_bias_init,
name="generative")
else:
generative = ConditionalNormalDistribution(
size=data_size,
hidden_layer_sizes=fcnet_hidden_sizes,
initializers=initializers,
name="generative")
rnn_cell = tf.nn.rnn_cell.LSTMCell(rnn_hidden_size,
initializer=initializers["w"])
return VRNNCell(rnn_cell, data_feat_extractor, latent_feat_extractor,
prior, approx_posterior, generative, random_seed=random_seed)
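# Illustrative example (hypothetical values; runners.py builds the real model):
# a VRNN for 88-dimensional pianoroll data with the default 64-dimensional
# latent state and a Bernoulli generative distribution could be created with
#   vrnn_cell = create_vrnn(data_size=88, latent_size=64,
#                           generative_class=ConditionalBernoulliDistribution)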
class ConditionalNormalDistribution(object):
"""A Normal distribution conditioned on Tensor inputs via a fc network."""
def __init__(self, size, hidden_layer_sizes, sigma_min=0.0,
raw_sigma_bias=0.25, hidden_activation_fn=tf.nn.relu,
initializers=None, name="conditional_normal_distribution"):
"""Creates a conditional Normal distribution.
Args:
size: The dimension of the random variable.
hidden_layer_sizes: The sizes of the hidden layers of the fully connected
network used to condition the distribution on the inputs.
sigma_min: The minimum standard deviation allowed, a scalar.
raw_sigma_bias: A scalar that is added to the raw standard deviation
output from the fully connected network. Set to 0.25 by default to
prevent standard deviations close to 0.
hidden_activation_fn: The activation function to use on the hidden layers
of the fully connected network.
      initializers: The variable initializers to use for the fully connected
network. The network is implemented using snt.nets.MLP so it must
be a dictionary mapping the keys 'w' and 'b' to the initializers for
the weights and biases. Defaults to xavier for the weights and zeros
for the biases when initializers is None.
name: The name of this distribution, used for sonnet scoping.
"""
self.sigma_min = sigma_min
self.raw_sigma_bias = raw_sigma_bias
self.name = name
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
self.fcnet = snt.nets.MLP(
output_sizes=hidden_layer_sizes + [2*size],
activation=hidden_activation_fn,
initializers=initializers,
activate_final=False,
use_bias=True,
name=name + "_fcnet")
def condition(self, tensor_list, **unused_kwargs):
"""Computes the parameters of a normal distribution based on the inputs."""
inputs = tf.concat(tensor_list, axis=1)
outs = self.fcnet(inputs)
mu, sigma = tf.split(outs, 2, axis=1)
sigma = tf.maximum(tf.nn.softplus(sigma + self.raw_sigma_bias),
self.sigma_min)
return mu, sigma
def __call__(self, *args, **kwargs):
"""Creates a normal distribution conditioned on the inputs."""
mu, sigma = self.condition(args, **kwargs)
return tf.contrib.distributions.Normal(loc=mu, scale=sigma)
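# Illustrative sketch (not part of the library): the standard deviation
# produced by condition() is softplus(raw_output + raw_sigma_bias) floored at
# sigma_min, so with the defaults a raw network output of 0.0 yields
# softplus(0.25) ~= 0.83 rather than a value near zero.
def _example_sigma_parameterization(raw_sigma, raw_sigma_bias=0.25,
                                    sigma_min=0.0):
  """Mirrors the sigma transform used in ConditionalNormalDistribution."""
  return tf.maximum(tf.nn.softplus(raw_sigma + raw_sigma_bias), sigma_min)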
class ConditionalBernoulliDistribution(object):
"""A Bernoulli distribution conditioned on Tensor inputs via a fc net."""
def __init__(self, size, hidden_layer_sizes, hidden_activation_fn=tf.nn.relu,
initializers=None, bias_init=0.0,
name="conditional_bernoulli_distribution"):
"""Creates a conditional Bernoulli distribution.
Args:
size: The dimension of the random variable.
hidden_layer_sizes: The sizes of the hidden layers of the fully connected
network used to condition the distribution on the inputs.
hidden_activation_fn: The activation function to use on the hidden layers
of the fully connected network.
      initializers: The variable initializers to use for the fully connected
network. The network is implemented using snt.nets.MLP so it must
be a dictionary mapping the keys 'w' and 'b' to the initializers for
the weights and biases. Defaults to xavier for the weights and zeros
for the biases when initializers is None.
bias_init: A scalar or vector Tensor that is added to the output of the
fully-connected network that parameterizes the mean of this
distribution.
name: The name of this distribution, used for sonnet scoping.
"""
self.bias_init = bias_init
if initializers is None:
initializers = _DEFAULT_INITIALIZERS
self.fcnet = snt.nets.MLP(
output_sizes=hidden_layer_sizes + [size],
activation=hidden_activation_fn,
initializers=initializers,
activate_final=False,
use_bias=True,
name=name + "_fcnet")
def condition(self, tensor_list):
"""Computes the p parameter of the Bernoulli distribution."""
inputs = tf.concat(tensor_list, axis=1)
return self.fcnet(inputs) + self.bias_init
def __call__(self, *args):
p = self.condition(args)
return tf.contrib.distributions.Bernoulli(logits=p)
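# Illustrative sketch (not part of the library): conditioning a Bernoulli
# observation model on a feature Tensor. The shapes and the bias value below
# are assumptions chosen for the example; bias_init shifts the logits, which
# is how the training-data mean is injected in runners.create_dataset_and_model.
def _example_conditional_bernoulli():
  obs_dist_fn = ConditionalBernoulliDistribution(
      size=88, hidden_layer_sizes=[64], bias_init=0.0,
      name="example_bernoulli")
  features = tf.zeros([4, 64])     # assumed [batch_size, feature_dim]
  dist = obs_dist_fn(features)     # Bernoulli with logits of shape [4, 88]
  return dist.log_prob(tf.zeros([4, 88]))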
class NormalApproximatePosterior(ConditionalNormalDistribution):
"""A Normally-distributed approx. posterior with res_q parameterization."""
def condition(self, tensor_list, prior_mu):
"""Generates the mean and variance of the normal distribution.
Args:
tensor_list: The list of Tensors to condition on. Will be concatenated and
fed through a fully connected network.
prior_mu: The mean of the prior distribution associated with this
approximate posterior. Will be added to the mean produced by
this approximate posterior, in res_q fashion.
Returns:
mu: The mean of the approximate posterior.
sigma: The standard deviation of the approximate posterior.
"""
mu, sigma = super(NormalApproximatePosterior, self).condition(tensor_list)
return mu + prior_mu, sigma
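# Illustrative note (not part of the library): under the res_q
# parameterization the approximate posterior predicts a residual that is added
# to the prior mean, so a network output of zero recovers the prior mean
# exactly. A minimal sketch with assumed shapes:
def _example_res_q():
  posterior_fn = NormalApproximatePosterior(
      size=8, hidden_layer_sizes=[32], name="example_posterior")
  rnn_state = tf.zeros([4, 16])    # assumed conditioning input
  prior_mu = tf.zeros([4, 8])      # mean of the matching prior
  mu, sigma = posterior_fn.condition([rnn_state], prior_mu=prior_mu)
  return mu, sigma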
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A set of utils for dealing with nested lists and tuples of Tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.util import nest
def map_nested(map_fn, nested):
"""Executes map_fn on every element in a (potentially) nested structure.
Args:
map_fn: A callable to execute on each element in 'nested'.
nested: A potentially nested combination of sequence objects. Sequence
objects include tuples, lists, namedtuples, and all subclasses of
collections.Sequence except strings. See nest.is_sequence for details.
For example [1, ('hello', 4.3)] is a nested structure containing elements
1, 'hello', and 4.3.
Returns:
out_structure: A potentially nested combination of sequence objects with the
same structure as the 'nested' input argument. out_structure
contains the result of applying map_fn to each element in 'nested'. For
example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
"""
  out = list(map(map_fn, nest.flatten(nested)))
return nest.pack_sequence_as(nested, out)
def tile_tensors(tensors, multiples):
"""Tiles a set of Tensors.
Args:
tensors: A potentially nested tuple or list of Tensors with rank
greater than or equal to the length of 'multiples'. The Tensors do not
need to have the same rank, but their rank must not be dynamic.
multiples: A python list of ints indicating how to tile each Tensor
in 'tensors'. Similar to the 'multiples' argument to tf.tile.
Returns:
tiled_tensors: A potentially nested tuple or list of Tensors with the same
structure as the 'tensors' input argument. Contains the result of
applying tf.tile to each Tensor in 'tensors'. When the rank of a Tensor
in 'tensors' is greater than the length of multiples, multiples is padded
at the end with 1s. For example when tiling a 4-dimensional Tensor with
multiples [3, 4], multiples would be padded to [3, 4, 1, 1] before tiling.
"""
def tile_fn(x):
return tf.tile(x, multiples + [1]*(x.shape.ndims - len(multiples)))
return map_nested(tile_fn, tensors)
def gather_tensors(tensors, indices):
"""Performs a tf.gather operation on a set of Tensors.
Args:
tensors: A potentially nested tuple or list of Tensors.
indices: The indices to use for the gather operation.
Returns:
gathered_tensors: A potentially nested tuple or list of Tensors with the
same structure as the 'tensors' input argument. Contains the result of
applying tf.gather(x, indices) on each element x in 'tensors'.
"""
return map_nested(lambda x: tf.gather(x, indices), tensors)
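# Illustrative sketch (not part of the library): tiling and then gathering a
# nested state, e.g. when duplicating a state across particles and resampling
# it. The shapes and the nested structure are assumptions for the example.
def _example_tile_and_gather():
  state = (tf.zeros([4, 3]), tf.zeros([4, 5]))
  # Tile the batch dimension 2x: shapes become [8, 3] and [8, 5].
  tiled = tile_tensors(state, [2])
  # Keep rows 0, 0, 1, 7 of every Tensor, preserving the nested structure.
  resampled = gather_tensors(tiled, tf.constant([0, 0, 1, 7]))
  return resampled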
def tas_for_tensors(tensors, length):
"""Unstacks a set of Tensors into TensorArrays.
Args:
tensors: A potentially nested tuple or list of Tensors with length in the
first dimension greater than or equal to the 'length' input argument.
length: The desired length of the TensorArrays.
Returns:
tensorarrays: A potentially nested tuple or list of TensorArrays with the
same structure as 'tensors'. Contains the result of unstacking each Tensor
in 'tensors'.
"""
def map_fn(x):
ta = tf.TensorArray(x.dtype, length, name=x.name.split(':')[0] + '_ta')
return ta.unstack(x[:length, :])
return map_nested(map_fn, tensors)
def read_tas(tas, index):
"""Performs a read operation on a set of TensorArrays.
Args:
tas: A potentially nested tuple or list of TensorArrays with length greater
than 'index'.
index: The location to read from.
Returns:
read_tensors: A potentially nested tuple or list of Tensors with the same
structure as the 'tas' input argument. Contains the result of
performing a read operation at 'index' on each TensorArray in 'tas'.
"""
return map_nested(lambda ta: ta.read(index), tas)
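# Illustrative sketch (not part of the library): unstacking a time-major
# Tensor into a TensorArray and reading back a single timestep, mirroring how
# the bound computations step through a sequence. The shapes are assumed.
def _example_tas():
  sequence = tf.zeros([10, 4, 3])          # [time, batch, features], assumed
  tas = tas_for_tensors((sequence,), 10)   # one TensorArray of length 10
  step_0 = read_tas(tas, 0)                # a tuple with one [4, 3] Tensor
  return step_0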
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""High-level code for creating and running FIVO-related Tensorflow graphs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import tensorflow as tf
import bounds
from data import datasets
from models import vrnn
def create_dataset_and_model(config, split, shuffle, repeat):
"""Creates the dataset and model for a given config.
Args:
config: A configuration object with config values accessible as properties.
Most likely a FLAGS object. This function expects the properties
batch_size, dataset_path, dataset_type, and latent_size to be defined.
split: The dataset split to load.
shuffle: If true, shuffle the dataset randomly.
repeat: If true, repeat the dataset endlessly.
Returns:
inputs: A batch of input sequences represented as a dense Tensor of shape
[time, batch_size, data_dimension].
targets: A batch of target sequences represented as a dense Tensor of
shape [time, batch_size, data_dimension].
lens: An int Tensor of shape [batch_size] representing the lengths of each
sequence in the batch.
model: A vrnn.VRNNCell model object.
"""
if config.dataset_type == "pianoroll":
inputs, targets, lengths, mean = datasets.create_pianoroll_dataset(
config.dataset_path, split, config.batch_size, shuffle=shuffle,
repeat=repeat)
# Convert the mean of the training set to logit space so it can be used to
# initialize the bias of the generative distribution.
generative_bias_init = -tf.log(
1. / tf.clip_by_value(mean, 0.0001, 0.9999) - 1)
generative_distribution_class = vrnn.ConditionalBernoulliDistribution
elif config.dataset_type == "speech":
inputs, targets, lengths = datasets.create_speech_dataset(
config.dataset_path, config.batch_size,
samples_per_timestep=config.data_dimension, prefetch_buffer_size=1,
shuffle=False, repeat=False)
generative_bias_init = None
generative_distribution_class = vrnn.ConditionalNormalDistribution
model = vrnn.create_vrnn(inputs.get_shape().as_list()[2],
config.latent_size,
generative_distribution_class,
generative_bias_init=generative_bias_init,
raw_sigma_bias=0.5)
return inputs, targets, lengths, model
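# Illustrative note (not part of the library): the pianoroll branch above
# initializes the generative bias at logit(mean), i.e. log(mean / (1 - mean)),
# so a zero network output already produces Bernoulli means equal to the
# training-set mean. A minimal numeric sketch of the same transform:
def _example_logit_bias(mean=0.1):
  clipped = np.clip(mean, 0.0001, 0.9999)
  return -np.log(1. / clipped - 1.)   # == log(mean / (1 - mean)) ~= -2.197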
def restore_checkpoint_if_exists(saver, sess, logdir):
"""Looks for a checkpoint and restores the session from it if found.
Args:
saver: A tf.train.Saver for restoring the session.
sess: A TensorFlow session.
logdir: The directory to look for checkpoints in.
Returns:
True if a checkpoint was found and restored, False otherwise.
"""
checkpoint = tf.train.get_checkpoint_state(logdir)
if checkpoint:
checkpoint_name = os.path.basename(checkpoint.model_checkpoint_path)
full_checkpoint_path = os.path.join(logdir, checkpoint_name)
saver.restore(sess, full_checkpoint_path)
return True
return False
def wait_for_checkpoint(saver, sess, logdir):
"""Loops until the session is restored from a checkpoint in logdir.
Args:
saver: A tf.train.Saver for restoring the session.
sess: A TensorFlow session.
logdir: The directory to look for checkpoints in.
"""
while True:
if restore_checkpoint_if_exists(saver, sess, logdir):
break
else:
tf.logging.info("Checkpoint not found in %s, sleeping for 60 seconds."
% logdir)
time.sleep(60)
def run_train(config):
"""Runs training for a sequential latent variable model.
Args:
config: A configuration object with config values accessible as properties.
Most likely a FLAGS object. For a list of expected properties and their
meaning see the flags defined in fivo.py.
"""
def create_logging_hook(step, bound_value):
"""Creates a logging hook that prints the bound value periodically."""
bound_label = config.bound + " bound"
if config.normalize_by_seq_len:
bound_label += " per timestep"
else:
bound_label += " per sequence"
def summary_formatter(log_dict):
return "Step %d, %s: %f" % (
log_dict["step"], bound_label, log_dict["bound_value"])
logging_hook = tf.train.LoggingTensorHook(
{"step": step, "bound_value": bound_value},
every_n_iter=config.summarize_every,
formatter=summary_formatter)
return logging_hook
def create_loss():
"""Creates the loss to be optimized.
Returns:
bound: A float Tensor containing the value of the bound that is
being optimized.
loss: A float Tensor that when differentiated yields the gradients
to apply to the model. Should be optimized via gradient descent.
"""
inputs, targets, lengths, model = create_dataset_and_model(
config, split="train", shuffle=True, repeat=True)
# Compute lower bounds on the log likelihood.
if config.bound == "elbo":
ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=1)
elif config.bound == "iwae":
ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=config.num_samples)
elif config.bound == "fivo":
ll_per_seq, _, _, _, _ = bounds.fivo(
model, (inputs, targets), lengths, num_samples=config.num_samples,
resampling_criterion=bounds.ess_criterion)
# Compute loss scaled by number of timesteps.
ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
ll_per_seq = tf.reduce_mean(ll_per_seq)
tf.summary.scalar("train_ll_per_seq", ll_per_seq)
tf.summary.scalar("train_ll_per_t", ll_per_t)
if config.normalize_by_seq_len:
return ll_per_t, -ll_per_t
else:
return ll_per_seq, -ll_per_seq
def create_graph():
"""Creates the training graph."""
global_step = tf.train.get_or_create_global_step()
bound, loss = create_loss()
opt = tf.train.AdamOptimizer(config.learning_rate)
grads = opt.compute_gradients(loss, var_list=tf.trainable_variables())
train_op = opt.apply_gradients(grads, global_step=global_step)
return bound, train_op, global_step
device = tf.train.replica_device_setter(ps_tasks=config.ps_tasks)
with tf.Graph().as_default():
if config.random_seed: tf.set_random_seed(config.random_seed)
with tf.device(device):
bound, train_op, global_step = create_graph()
log_hook = create_logging_hook(global_step, bound)
start_training = not config.stagger_workers
with tf.train.MonitoredTrainingSession(
master=config.master,
is_chief=config.task == 0,
hooks=[log_hook],
checkpoint_dir=config.logdir,
save_checkpoint_secs=120,
save_summaries_steps=config.summarize_every,
log_step_count_steps=config.summarize_every) as sess:
cur_step = -1
while True:
if sess.should_stop() or cur_step > config.max_steps: break
if config.task > 0 and not start_training:
cur_step = sess.run(global_step)
tf.logging.info("task %d not active yet, sleeping at step %d" %
(config.task, cur_step))
time.sleep(30)
if cur_step >= config.task * 1000:
start_training = True
else:
_, cur_step = sess.run([train_op, global_step])
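# Illustrative sketch (not part of the library): run_train expects a config
# object exposing the flags defined in fivo.py as attributes; in normal use
# the FLAGS object plays this role. The attribute values below are assumptions
# for a small single-machine pianoroll run, and the dataset path is a
# placeholder.
def _example_run_train():
  class _Config(object):
    logdir = "/tmp/fivo"
    dataset_type = "pianoroll"
    dataset_path = "/path/to/pianoroll.pkl"  # assumed, preprocessed pickle
    latent_size = 64
    batch_size = 4
    num_samples = 4
    bound = "fivo"
    learning_rate = 0.0001
    summarize_every = 100
    max_steps = 1000
    normalize_by_seq_len = True
    random_seed = None
    master = ""
    task = 0
    ps_tasks = 0
    stagger_workers = False
  run_train(_Config())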
def run_eval(config):
"""Runs evaluation for a sequential latent variable model.
This method runs only one evaluation over the dataset, writes summaries to
disk, and then terminates. It does not loop indefinitely.
Args:
config: A configuration object with config values accessible as properties.
Most likely a FLAGS object. For a list of expected properties and their
meaning see the flags defined in fivo.py.
"""
def create_graph():
"""Creates the evaluation graph.
Returns:
lower_bounds: A tuple of float Tensors containing the values of the 3
evidence lower bounds, summed across the batch.
total_batch_length: The total number of timesteps in the batch, summed
across batch examples.
batch_size: The batch size.
global_step: The global step the checkpoint was loaded from.
"""
global_step = tf.train.get_or_create_global_step()
inputs, targets, lengths, model = create_dataset_and_model(
config, split=config.split, shuffle=False, repeat=False)
# Compute lower bounds on the log likelihood.
elbo_ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=1)
iwae_ll_per_seq, _, _, _ = bounds.iwae(
model, (inputs, targets), lengths, num_samples=config.num_samples)
fivo_ll_per_seq, _, _, _, _ = bounds.fivo(
model, (inputs, targets), lengths, num_samples=config.num_samples,
resampling_criterion=bounds.ess_criterion)
elbo_ll = tf.reduce_sum(elbo_ll_per_seq)
iwae_ll = tf.reduce_sum(iwae_ll_per_seq)
fivo_ll = tf.reduce_sum(fivo_ll_per_seq)
batch_size = tf.shape(lengths)[0]
total_batch_length = tf.reduce_sum(lengths)
return ((elbo_ll, iwae_ll, fivo_ll), total_batch_length, batch_size,
global_step)
def average_bounds_over_dataset(lower_bounds, total_batch_length, batch_size,
sess):
"""Computes the values of the bounds, averaged over the datset.
Args:
lower_bounds: Tuple of float Tensors containing the values of the bounds
evaluated on a single batch.
total_batch_length: Integer Tensor that represents the total number of
timesteps in the current batch.
batch_size: Integer Tensor containing the batch size. This can vary if the
requested batch_size does not evenly divide the size of the dataset.
sess: A TensorFlow Session object.
Returns:
ll_per_t: A length 3 numpy array of floats containing each bound's average
        value, normalized by the total number of timesteps in the dataset. Can
be interpreted as a lower bound on the average log likelihood per
timestep in the dataset.
ll_per_seq: A length 3 numpy array of floats containing each bound's
average value, normalized by the number of sequences in the dataset.
Can be interpreted as a lower bound on the average log likelihood per
        sequence in the dataset.
"""
total_ll = np.zeros(3, dtype=np.float64)
total_n_elems = 0.0
total_length = 0.0
while True:
try:
outs = sess.run([lower_bounds, batch_size, total_batch_length])
except tf.errors.OutOfRangeError:
break
total_ll += outs[0]
total_n_elems += outs[1]
total_length += outs[2]
ll_per_t = total_ll / total_length
ll_per_seq = total_ll / total_n_elems
return ll_per_t, ll_per_seq
def summarize_lls(lls_per_t, lls_per_seq, summary_writer, step):
"""Creates log-likelihood lower bound summaries and writes them to disk.
Args:
lls_per_t: An array of 3 python floats, contains the values of the
evaluated bounds normalized by the number of timesteps.
lls_per_seq: An array of 3 python floats, contains the values of the
evaluated bounds normalized by the number of sequences.
summary_writer: A tf.SummaryWriter.
step: The current global step.
"""
def scalar_summary(name, value):
value = tf.Summary.Value(tag=name, simple_value=value)
return tf.Summary(value=[value])
for i, bound in enumerate(["elbo", "iwae", "fivo"]):
per_t_summary = scalar_summary("%s/%s_ll_per_t" % (config.split, bound),
lls_per_t[i])
per_seq_summary = scalar_summary("%s/%s_ll_per_seq" %
(config.split, bound),
lls_per_seq[i])
summary_writer.add_summary(per_t_summary, global_step=step)
summary_writer.add_summary(per_seq_summary, global_step=step)
summary_writer.flush()
with tf.Graph().as_default():
if config.random_seed: tf.set_random_seed(config.random_seed)
lower_bounds, total_batch_length, batch_size, global_step = create_graph()
summary_dir = config.logdir + "/" + config.split
summary_writer = tf.summary.FileWriter(
summary_dir, flush_secs=15, max_queue=100)
saver = tf.train.Saver()
with tf.train.SingularMonitoredSession() as sess:
wait_for_checkpoint(saver, sess, config.logdir)
step = sess.run(global_step)
tf.logging.info("Model restored from step %d, evaluating." % step)
ll_per_t, ll_per_seq = average_bounds_over_dataset(
lower_bounds, total_batch_length, batch_size, sess)
summarize_lls(ll_per_t, ll_per_seq, summary_writer, step)
tf.logging.info("%s elbo ll/t: %f, iwae ll/t: %f fivo ll/t: %f",
config.split, ll_per_t[0], ll_per_t[1], ll_per_t[2])
tf.logging.info("%s elbo ll/seq: %f, iwae ll/seq: %f fivo ll/seq: %f",
config.split, ll_per_seq[0], ll_per_seq[1], ll_per_seq[2])