Commit 46dea625 authored by Dieterich Lawson

Updating fivo codebase

parent 5856878d
*.pkl binary
*.tfrecord binary
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
.static_storage/
.media/
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
@@ -8,28 +8,42 @@ Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohamma
This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
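For intuition, here is a minimal NumPy sketch of how a FIVO estimate can be assembled from a particle filter's incremental log weights, with resampling triggered by the effective sample size. This is illustrative only and is not the implementation used here (see `fivo/smc.py` and `fivo/bounds.py`); the function name and the 0.5 ESS threshold are just placeholders.
```
import numpy as np
from scipy.special import logsumexp

def fivo_log_p_hat(incremental_log_weights, ess_threshold=0.5):
  """Sketches the FIVO estimate of log p(x) for one sequence.

  incremental_log_weights: array of shape [num_timesteps, num_particles],
    each particle's incremental importance log weight at each step.
  """
  num_timesteps, num_particles = incremental_log_weights.shape
  log_p_hat = 0.0
  # Unnormalized log weights accumulated since the last resampling event.
  accumulated = np.zeros(num_particles)
  for t in range(num_timesteps):
    accumulated += incremental_log_weights[t]
    log_norm = logsumexp(accumulated)
    normalized = np.exp(accumulated - log_norm)
    ess = 1.0 / np.sum(normalized ** 2)  # effective sample size
    if ess < ess_threshold * num_particles:
      # Fold the current average weight into the bound and reset the weights.
      # (A real particle filter would also resample the particles here.)
      log_p_hat += log_norm - np.log(num_particles)
      accumulated = np.zeros(num_particles)
  # Final term for the weights accumulated since the last resampling.
  return log_p_hat + logsumexp(accumulated) - np.log(num_particles)
```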
Additionally it contains several sequential latent variable model implementations:

* Variational recurrent neural network (VRNN)
* Stochastic recurrent neural network (SRNN)
* Gaussian hidden Markov model with linear conditionals (GHMM)

The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.
#### Directory Structure
The important parts of the code are organized as follows.
```
run_fivo.py       # main script, contains flag definitions
fivo
├─smc.py          # a sequential Monte Carlo implementation
├─bounds.py       # code for computing each bound, uses smc.py
├─runners.py      # code for VRNN and SRNN training and evaluation
├─ghmm_runners.py # code for GHMM training and evaluation
├─data
| ├─datasets.py   # readers for pianoroll and speech datasets
| ├─calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
| └─create_timit_dataset.py     # preprocesses the TIMIT dataset
└─models
  ├─base.py       # base classes used in other models
  ├─vrnn.py       # VRNN implementation
  ├─srnn.py       # SRNN implementation
  └─ghmm.py       # Gaussian hidden Markov model (GHMM) implementation
bin
├─run_train.sh    # an example script that runs training
├─run_eval.sh     # an example script that runs evaluation
├─run_sample.sh   # an example script that runs sampling
├─run_tests.sh    # a script that runs all tests
└─download_pianorolls.sh # a script that downloads pianoroll files
```
### Pianorolls
Requirements before we start:
@@ -60,9 +74,9 @@ python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
#### Training
Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:
```
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
@@ -75,26 +89,24 @@ python fivo.py \
--dataset_type="pianoroll"
```
You should see output that looks something like this (with extra logging cruft):
```
Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
Step 1, fivo bound per timestep: -11.322491
global_step/sec: 7.49971
Step 101, fivo bound per timestep: -11.399275
global_step/sec: 8.04498
Step 201, fivo bound per timestep: -11.174991
global_step/sec: 8.03989
Step 301, fivo bound per timestep: -11.073008
```
You will also see lines saying `Out of range: exceptions.StopIteration: Iteration finished`. This is expected and is not an error.
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
```
python run_fivo.py \
--mode=eval \
--split=test \
--alsologtostderr \
@@ -108,12 +120,52 @@ python fivo.py \
You should see output like this:
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Model restored from step 0, evaluating.
test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
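Since both numbers are averages over the same evaluation split, their ratio is roughly the split's average sequence length. As a quick illustration using the FIVO numbers from the example output above (this snippet is not produced by the code):
```
avg_len = -710.577141 / -11.579776  # fivo ll/seq divided by fivo ll/t: ~61 timesteps per JSB test sequence
```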
#### Sampling
You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:
```
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
```
Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.
You should see very little output.
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Running local_init_op.
Done running local_init_op.
```
Loading the samples with `np.load` confirms that we conditioned the model on 4
prefixes of length 25 and sampled 4 sequences of length 50 for each prefix.
```
>>> import numpy as np
>>> x = np.load("/tmp/fivo/samples.npz")
>>> x[()]['prefixes'].shape
(25, 4, 88)
>>> x[()]['samples'].shape
(50, 4, 4, 88)
```
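If you want to inspect the samples, one option (illustrative only, not part of the codebase) is to stitch a prefix together with one of its sampled continuations and render it as a pianoroll image with matplotlib. Which of the two middle axes of `samples` is the per-prefix sample index versus the batch index is an assumption here; both are 4 in this example, and indexing the first element of each sidesteps the question.
```
import numpy as np
import matplotlib.pyplot as plt

# Mirrors the loading pattern shown above.
x = np.load("/tmp/fivo/samples.npz")
prefixes = x[()]['prefixes']  # [prefix_length, batch_size, 88]
samples = x[()]['samples']    # [sample_length, 4, 4, 88]

# First prefix followed by its first sampled continuation.
roll = np.concatenate([prefixes[:, 0, :], samples[:, 0, 0, :]], axis=0)
plt.imshow(roll.T, aspect="auto", origin="lower", cmap="Greys")
plt.xlabel("timestep")
plt.ylabel("pitch index")
plt.show()
```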
### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
@@ -137,7 +189,7 @@ train mean: 0.006060 train std: 548.136169
#### Training on TIMIT
This is very similar to training on pianoroll datasets, with just a few flags switched.
```
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
@@ -149,6 +201,10 @@ python fivo.py \
--dataset_path="$TIMIT_DIR/train" \
--dataset_type="speech"
```
Evaluation and sampling are similar.
### Tests
This codebase comes with a number of tests that verify correctness, runnable via `bin/run_tests.sh`. The tests also serve as useful examples of how to use the code.
### Contact
......
@@ -18,7 +18,7 @@
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=eval \
--logdir=/tmp/fivo \
--model=vrnn \
......
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of sampling from the model.
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
python -m fivo.smc_test && \
python -m fivo.bounds_test && \
python -m fivo.nested_utils_test && \
python -m fivo.data.datasets_test && \
python -m fivo.models.ghmm_test && \
python -m fivo.models.vrnn_test && \
python -m fivo.models.srnn_test && \
python -m fivo.ghmm_runners_test && \
python -m fivo.runners_test
@@ -18,7 +18,7 @@
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
......
An experimental codebase for running simple examples.
This diff is collapsed.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
import models
def make_long_chain_dataset(
state_size=1,
num_obs=5,
steps_per_obs=3,
variance=1.,
observation_variance=1.,
batch_size=4,
num_samples=1,
observation_type=models.STANDARD_OBSERVATION,
transition_type=models.STANDARD_TRANSITION,
fixed_observation=None,
dtype="float32"):
"""Creates a long chain data generating process.
Creates a tf.data.Dataset that provides batches of data from a long
chain.
Args:
state_size: The dimension of the state space of the process.
num_obs: The number of observations in the chain.
steps_per_obs: The number of steps between each observation.
variance: The variance of the normal distributions used at each timestep.
observation_variance: The variance of the observation noise.
batch_size: The number of trajectories to include in each batch.
num_samples: The number of replicas of each trajectory to include in each
batch.
observation_type: The observation function, one of
models.STANDARD_OBSERVATION, models.SQUARED_OBSERVATION, or
models.ABS_OBSERVATION.
transition_type: The transition function, either models.STANDARD_TRANSITION
or models.ROUND_TRANSITION.
fixed_observation: If not None, every observation is fixed to this value
instead of being sampled.
dtype: The datatype of the states and observations.
Returns:
dataset: A tf.data.Dataset that can be iterated over.
"""
num_timesteps = num_obs * steps_per_obs
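# The latent chain has num_obs * steps_per_obs transitions; an observation is
# emitted every steps_per_obs steps, giving num_obs observations in total.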
def data_generator():
"""An infinite generator of latents and observations from the model."""
while True:
states = []
observations = []
# z0 ~ Normal(0, sqrt(variance)).
states.append(
np.random.normal(size=[state_size],
scale=np.sqrt(variance)).astype(dtype))
# start at 1 because we've already generated z0
# go to num_timesteps+1 because we want to include the num_timesteps-th step
for t in xrange(1, num_timesteps+1):
if transition_type == models.ROUND_TRANSITION:
loc = np.round(states[-1])
elif transition_type == models.STANDARD_TRANSITION:
loc = states[-1]
new_state = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance))
states.append(new_state.astype(dtype))
if t % steps_per_obs == 0:
if fixed_observation is None:
if observation_type == models.SQUARED_OBSERVATION:
loc = np.square(states[-1])
elif observation_type == models.ABS_OBSERVATION:
loc = np.abs(states[-1])
elif observation_type == models.STANDARD_OBSERVATION:
loc = states[-1]
new_obs = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(observation_variance)).astype(dtype)
else:
new_obs = (np.ones([state_size]) * fixed_observation).astype(dtype)
observations.append(new_obs)
yield states, observations
dataset = tf.data.Dataset.from_generator(
data_generator,
output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
output_shapes=([num_timesteps+1, state_size], [num_obs, state_size]))
dataset = dataset.repeat().batch(batch_size)
def tile_batch(state, observation):
state = tf.tile(state, [num_samples, 1, 1])
observation = tf.tile(observation, [num_samples, 1, 1])
return state, observation
dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
return dataset
def make_dataset(bs=None,
state_size=1,
num_timesteps=10,
variance=1.,
prior_type="unimodal",
bimodal_prior_weight=0.5,
bimodal_prior_mean=1,
transition_type=models.STANDARD_TRANSITION,
fixed_observation=None,
batch_size=4,
num_samples=1,
dtype='float32'):
"""Creates a data generating process.
Creates a tf.data.Dataset that provides batches of data.
Args:
bs: The parameters of the data generating process. If None, new bs are
randomly generated.
state_size: The dimension of the state space of the process.
num_timesteps: The length of the state sequences in the process.
variance: The variance of the normal distributions used at each timestep.
prior_type: The prior over the initial state, one of "unimodal", "bimodal",
or "nonlinear".
bimodal_prior_weight: The mixing weight between the two modes of the
bimodal prior.
bimodal_prior_mean: The means of the two prior modes are
+bimodal_prior_mean and -bimodal_prior_mean.
transition_type: The transition function, either models.STANDARD_TRANSITION
or models.ROUND_TRANSITION.
fixed_observation: If not None, every observation is fixed to this value
instead of being the final state.
batch_size: The number of trajectories to include in each batch.
num_samples: The number of replicas of each trajectory to include in each
batch.
dtype: The datatype of the states and observations.
Returns:
bs: The true bs used to generate the data.
dataset: A tf.data.Dataset that can be iterated over.
"""
if bs is None:
bs = [np.random.uniform(size=[state_size]).astype(dtype) for _ in xrange(num_timesteps)]
tf.logging.info("data generating processs bs: %s",
np.array(bs).reshape(num_timesteps))
def data_generator():
"""An infinite generator of latents and observations from the model."""
while True:
states = []
if prior_type == "unimodal" or prior_type == "nonlinear":
# Prior is Normal(0, sqrt(variance)).
states.append(np.random.normal(size=[state_size], scale=np.sqrt(variance)).astype(dtype))
elif prior_type == "bimodal":
if np.random.uniform() > bimodal_prior_weight:
loc = bimodal_prior_mean
else:
loc = - bimodal_prior_mean
states.append(np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance)
).astype(dtype))
for t in xrange(num_timesteps):
if transition_type == models.ROUND_TRANSITION:
loc = np.round(states[-1])
elif transition_type == models.STANDARD_TRANSITION:
loc = states[-1]
loc += bs[t]
new_state = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance)).astype(dtype)
states.append(new_state)
if fixed_observation is None:
observation = states[-1]
else:
observation = np.ones_like(states[-1]) * fixed_observation
yield np.array(states[:-1]), observation
dataset = tf.data.Dataset.from_generator(
data_generator,
output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
output_shapes=([num_timesteps, state_size], [state_size]))
dataset = dataset.repeat().batch(batch_size)
def tile_batch(state, observation):
state = tf.tile(state, [num_samples, 1, 1])
observation = tf.tile(observation, [num_samples, 1])
return state, observation
dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
return np.array(bs), dataset
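# Hypothetical usage sketch (not part of the original file): draws one batch
# from the synthetic data generating process defined above using the TF 1.x API.
if __name__ == "__main__":
  example_bs, example_dataset = make_dataset(
      batch_size=4, num_samples=2, num_timesteps=10)
  example_states, example_obs = (
      example_dataset.make_one_shot_iterator().get_next())
  with tf.Session() as sess:
    s, o = sess.run([example_states, example_obs])
    # s: [batch_size * num_samples, num_timesteps, state_size]
    # o: [batch_size * num_samples, state_size]
    print("states shape:", s.shape, "observations shape:", o.shape)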
This diff is collapsed.
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
model="forward"
T=5
num_obs=1
var=0.1
n=4
lr=0.0001
bound="fivo-aux"
q_type="normal"
resampling_method="multinomial"
rgrad="true"
p_type="unimodal"
use_bs=false
LOGDIR=/tmp/fivo/model-$model-$bound-$resampling_method-resampling-rgrad-$rgrad-T-$T-var-$var-n-$n-lr-$lr-q-$q_type-p-$p_type
python train.py \
--logdir=$LOGDIR \
--model=$model \
--bound=$bound \
--q_type=$q_type \
--p_type=$p_type \
--variance=$var \
--use_resampling_grads=$rgrad \
--resampling=always \
--resampling_method=$resampling_method \
--batch_size=4 \
--num_samples=$n \
--num_timesteps=$T \
--num_eval_samples=256 \
--summarize_every=100 \
--learning_rate=$lr \
--decay_steps=1000000 \
--max_steps=1000000000 \
--random_seed=1234 \
--train_p=false \
--use_bs=$use_bs \
--alsologtostderr
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for plotting and summarizing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
import models
def summarize_ess(weights, only_last_timestep=False):
"""Plots the effective sample size.
Args:
weights: List of length num_timesteps Tensors of shape
[num_samples, batch_size]
"""
num_timesteps = len(weights)
batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64)
for i in range(num_timesteps):
if only_last_timestep and i < num_timesteps-1: continue
w = tf.nn.softmax(weights[i], dim=0)
centered_weights = w - tf.reduce_mean(w, axis=0, keepdims=True)
variance = tf.reduce_sum(tf.square(centered_weights))/(batch_size-1)
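# Effective sample size: ESS = 1 / sum_i w_i^2 for normalized weights w_i; here
# the sum of squared weights is averaged over the batch before inverting.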
ess = 1./tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0))
tf.summary.scalar("ess/%d" % i, ess)
tf.summary.scalar("ese/%d" % i, ess / batch_size)
tf.summary.scalar("weight_variance/%d" % i, variance)
def summarize_particles(states, weights, observation, model):
"""Plots particle locations and weights.
Args:
states: List of length num_timesteps Tensors of shape
[batch_size*num_particles, state_size].
weights: List of length num_timesteps Tensors of shape [num_samples,
batch_size]
observation: Tensor of shape [batch_size*num_samples, state_size]
"""
num_timesteps = len(weights)
num_samples, batch_size = weights[0].get_shape().as_list()
# get q0 information for plotting
q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
q0_loc = q0_dist.loc[0:batch_size, 0]
q0_scale = q0_dist.scale[0:batch_size, 0]
# get posterior information for plotting
post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
tf.reduce_sum(model.p.bs), model.p.num_timesteps)
# Reshape states and weights to be [time, num_samples, batch_size]
states = tf.stack(states)
weights = tf.stack(weights)
# normalize the weights over the sample dimension
weights = tf.nn.softmax(weights, dim=1)
states = tf.reshape(states, tf.shape(weights))
ess = 1./tf.reduce_sum(tf.square(weights), axis=1)
def _plot_states(states_batch, weights_batch, observation_batch, ess_batch, q0, post):
"""
states: [time, num_samples, batch_size]
weights [time, num_samples, batch_size]
observation: [batch_size, 1]
q0: ([batch_size], [batch_size])
post: ...
"""
num_timesteps, _, batch_size = states_batch.shape
plots = []
for i in range(batch_size):
states = states_batch[:,:,i]
weights = weights_batch[:,:,i]
observation = observation_batch[i]
ess = ess_batch[:,i]
q0_loc = q0[0][i]
q0_scale = q0[1][i]
fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
# Each timestep gets two plots -- a bar plot and a histogram of state locs.
# The bar plot will be bar_rows rows tall.
# The histogram will be 1 row tall.
# There is also 1 extra plot at the top showing the posterior and q.
bar_rows = 8
num_rows = (num_timesteps + 1) * (bar_rows + 1)
gs = gridspec.GridSpec(num_rows, 1)
# Figure out how wide to make the plot
prior_lims = (post[1] * -2, post[1] * 2)
q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
state_width = states.max() - states.min()
state_lims = (states.min() - state_width * 0.15,
states.max() + state_width * 0.15)
lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
max(prior_lims[1], q_lims[1], state_lims[1]))
# plot the posterior
z0 = np.arange(lims[0], lims[1], 0.1)
alpha, pos_mu, sigma_sq, B, T = post
neg_mu = -pos_mu
scale = np.sqrt((T + 1) * sigma_sq)
p_zn = (
alpha * scipy.stats.norm.pdf(
observation, loc=pos_mu + B, scale=scale) + (1 - alpha) *
scipy.stats.norm.pdf(observation, loc=neg_mu + B, scale=scale))
p_z0 = (
alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq))
+ (1 - alpha) * scipy.stats.norm.pdf(
z0, loc=neg_mu, scale=np.sqrt(sigma_sq)))
p_zn_given_z0 = scipy.stats.norm.pdf(
observation, loc=z0 + B, scale=np.sqrt(T * sigma_sq))
post_z0 = (p_z0 * p_zn_given_z0) / p_zn
# plot q
q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
ax = plt.subplot(gs[0:bar_rows, :])
ax.plot(z0, q_z0, color="blue")
ax.plot(z0, post_z0, color="green")
ax.plot(z0, p_z0, color="red")
ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
ax.set_xticks([])
ax.set_xlim(*lims)
# plot the states
for t in range(num_timesteps):
start = (t + 1) * (bar_rows + 1)
ax1 = plt.subplot(gs[start:start + bar_rows, :])
ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
# plot the states barplot
# ax1.hist(
# states[t, :],
# weights=weights[t, :],
# bins=50,
# edgecolor="none",
# alpha=0.2)
ax1.bar(states[t,:], weights[t,:], width=0.02, alpha=0.2, edgecolor = "none")
ax1.set_ylabel("t=%d" % t)
ax1.set_xticks([])
ax1.grid(True, which="both")
ax1.set_xlim(*lims)
# plot the observation
ax1.axvline(x=observation, color="red", linestyle="dashed")
# add the ESS
ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t],
ha='center', va='center', transform=ax1.transAxes)
# plot the state location histogram
ax2.hist2d(
states[t, :], np.zeros_like(states[t, :]), bins=[50, 1], cmap="Greys")
ax2.grid(False)
ax2.set_yticks([])
ax2.set_xlim(*lims)
if t != num_timesteps - 1:
ax2.set_xticks([])
fig.canvas.draw()
p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,)))
plt.close(fig)
return np.stack(plots)
plots = tf.py_func(_plot_states,
[states, weights, observation, ess, (q0_loc, q0_scale), post],
[tf.uint8])[0]
tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
def plot_weights(weights, resampled=None):
"""Plots the weights and effective sample size from an SMC rollout.
Args:
weights: [num_timesteps, num_samples, batch_size] importance weights
resampled: [num_timesteps] 0/1 indicating if resampling occurred
"""
weights = tf.convert_to_tensor(weights)
def _make_plots(weights, resampled):
num_timesteps, num_samples, batch_size = weights.shape
plots = []
for i in range(batch_size):
fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i]))
axes.set_title("Weights")
axes.set_xlabel("Steps")
axes.set_ylim([0, 1])
axes.set_xlim([0, num_timesteps - 1])
for j in np.where(resampled > 0)[0]:
axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0)
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plots.append(data)
plt.close(fig)
return np.stack(plots, axis=0)
if resampled is None:
num_timesteps, _, batch_size = weights.get_shape().as_list()
resampled = tf.zeros([num_timesteps], dtype=tf.float32)
plots = tf.py_func(_make_plots,
[tf.nn.softmax(weights, dim=1),
tf.to_float(resampled)], [tf.uint8])[0]
batch_size = weights.get_shape().as_list()[-1]
tf.summary.image(
"weights", plots, batch_size, collections=["infrequent_summaries"])
def summarize_weights(weights, num_timesteps, num_samples):
# weights is [num_timesteps, num_samples, batch_size]
weights = tf.convert_to_tensor(weights)
mean = tf.reduce_mean(weights, axis=1, keepdims=True)
squared_diff = tf.square(weights - mean)
variances = tf.reduce_sum(squared_diff, axis=1) / (num_samples - 1)
# average the variance over the batch
variances = tf.reduce_mean(variances, axis=1)
avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
for t in xrange(num_timesteps):
tf.summary.scalar("weights/variance_%d" % t, variances[t])
tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
tf.summary.histogram("weights/step_%d" % t, weights[t])
def summarize_learning_signal(rewards, tag):
num_resampling_events, _ = rewards.get_shape().as_list()
mean = tf.reduce_mean(rewards, axis=1)
avg_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1)
reward_square = tf.reduce_mean(tf.square(rewards), axis=1)
for t in xrange(num_resampling_events):
tf.summary.scalar("%s/mean_%d" % (tag, t), mean[t])
tf.summary.scalar("%s/magnitude_%d" % (tag, t), avg_magnitude[t])
tf.summary.scalar("%s/squared_%d" % (tag, t), reward_square[t])
tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
def summarize_qs(model, observation, states):
model.q.summarize_weights()
if hasattr(model.p, "posterior") and callable(getattr(model.p, "posterior")):
states = [tf.zeros_like(states[0])] + states[:-1]
for t, prev_state in enumerate(states):
p = model.p.posterior(observation, prev_state, t)
q = model.q.q_zt(observation, prev_state, t)
kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(p, q))
tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl))
mean_diff = q.loc - p.loc
mean_abs_err = tf.abs(mean_diff)
mean_rel_err = tf.abs(mean_diff / p.loc)
tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
tf.reduce_mean(mean_abs_err))
tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
tf.reduce_mean(mean_rel_err))
sigma_diff = tf.square(q.scale) - tf.square(p.scale)
sigma_abs_err = tf.abs(sigma_diff)
sigma_rel_err = tf.abs(sigma_diff / tf.square(p.scale))
tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
tf.reduce_mean(sigma_abs_err))
tf.summary.scalar("q_variance_convergence/relative_error_%d" % t,
tf.reduce_mean(sigma_rel_err))
def summarize_rs(model, states):
model.r.summarize_weights()
for t, state in enumerate(states):
true_r = model.p.lookahead(state, t)
r = model.r.r_xn(state, t)
kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(true_r, r))
tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl))
mean_diff = true_r.loc - r.loc
mean_abs_err = tf.abs(mean_diff)
mean_rel_err = tf.abs(mean_diff / true_r.loc)
tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t,
tf.reduce_mean(mean_abs_err))
tf.summary.scalar("r_mean_convergence/relative_error_%d" % t,
tf.reduce_mean(mean_rel_err))
sigma_diff = tf.square(r.scale) - tf.square(true_r.scale)
sigma_abs_err = tf.abs(sigma_diff)
sigma_rel_err = tf.abs(sigma_diff / tf.square(true_r.scale))
tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t,
tf.reduce_mean(sigma_abs_err))
tf.summary.scalar("r_variance_convergence/relative_error_%d" % t,
tf.reduce_mean(sigma_rel_err))
def summarize_model(model, true_bs, observation, states, bound, summarize_r=True):
if hasattr(model.p, "bs"):
model_b = tf.reduce_sum(model.p.bs, axis=0)
true_b = tf.reduce_sum(true_bs, axis=0)
abs_err = tf.abs(model_b - true_b)
rel_err = abs_err / true_b
tf.summary.scalar("sum_of_bs/data_generating_process", tf.reduce_mean(true_b))
tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(model_b))
tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_err))
tf.summary.scalar("sum_of_bs/relative_error", tf.reduce_mean(rel_err))
#summarize_qs(model, observation, states)
#if bound == "fivo-aux" and summarize_r:
# summarize_rs(model, states)
def summarize_grads(grads, loss_name):
grad_ema = tf.train.ExponentialMovingAverage(decay=0.99)
vectorized_grads = tf.concat(
[tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0)
new_second_moments = tf.square(vectorized_grads)
new_first_moments = vectorized_grads
maintain_grad_ema_op = grad_ema.apply([new_first_moments, new_second_moments])
first_moments = grad_ema.average(new_first_moments)
second_moments = grad_ema.average(new_second_moments)
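# Var[g] = E[g^2] - (E[g])^2, computed from the EMA-tracked first and second moments.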
variances = second_moments - tf.square(first_moments)
tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(variances))
tf.summary.histogram("grad_variance/%s" % loss_name, variances)
tf.summary.histogram("grad_mean/%s" % loss_name, first_moments)
return maintain_grad_ema_op
This diff is collapsed.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.bounds"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from fivo.test_utils import create_vrnn
from fivo import bounds
class BoundsTest(tf.test.TestCase):
def test_elbo(self):
"""A golden-value test for the ELBO (the IWAE bound with num_samples=1)."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=1,
parallel_iterations=1)
sess.run(tf.global_variables_initializer())
log_p_hat, _, _ = sess.run(outs)
self.assertAllClose([-21.615765, -13.614225], log_p_hat)
def test_iwae(self):
"""A golden-value test for the IWAE bound."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=4,
parallel_iterations=1)
sess.run(tf.global_variables_initializer())
log_p_hat, weights, _ = sess.run(outs)
self.assertAllClose([-23.301426, -13.64028], log_p_hat)
weights_gt = np.array(
[[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
[-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
[[-6.2539978, -4.37615728, -7.43738699, -7.85044909],
[-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
[[-9.19093227, -8.01637268, -11.64603615, -10.51128292],
[-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
[[-12.20609856, -10.47217369, -13.66270638, -13.46115875],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
[[-16.14766312, -15.57472229, -17.47755432, -17.98189926],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
[[-20.07182884, -18.43191147, -20.1606636, -21.45263863],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
[[-24.10270691, -22.20865822, -24.14675522, -25.27248383],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]]])
self.assertAllClose(weights_gt, weights)
def test_fivo(self):
"""A golden-value test for the FIVO bound."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
random_seed=1234, parallel_iterations=1)
sess.run(tf.global_variables_initializer())
log_p_hat, weights, resampled, _ = sess.run(outs)
self.assertAllClose([-22.98902512, -14.21689224], log_p_hat)
weights_gt = np.array(
[[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
[-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
[[-2.67100811, -2.30541706, -2.34178066, -2.81751347],
[-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
[[-5.65190411, -5.94563246, -6.55041981, -5.4783473],
[-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
[[-8.71947861, -8.40143299, -8.54593086, -8.42822266],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
[[-12.7003831, -13.5039815, -12.3569726, -12.9489622],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
[[-16.4520301, -16.3611698, -15.0314846, -16.4197006],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
[[-20.7010765, -20.1379165, -19.0020351, -20.2395458],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]]])
self.assertAllClose(weights_gt, weights)
resampled_gt = np.array(
[[1., 0.],
[0., 0.],
[0., 1.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]])
self.assertAllClose(resampled_gt, resampled)
def test_fivo_relaxed(self):
"""A golden-value test for the FIVO bound with relaxed sampling."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
random_seed=1234, parallel_iterations=1,
resampling_type="relaxed")
sess.run(tf.global_variables_initializer())
log_p_hat, weights, resampled, _ = sess.run(outs)
self.assertAllClose([-22.942394, -14.273882], log_p_hat)
weights_gt = np.array(
[[[-3.66708851, -2.07074118, -4.91751575, -5.03293514],
[-2.99690628, -3.17782831, -4.50084877, -3.48536515]],
[[-2.84939098, -2.30087185, -2.35649204, -2.48417377],
[-8.27518654, -6.71545172, -8.96199131, -7.05567837]],
[[-5.92327023, -5.9433074, -6.5826683, -5.04259014],
[-12.34527206, -11.54284668, -11.86675072, -9.69417477]],
[[-8.95323944, -8.40061855, -8.52760506, -7.99130583],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
[[-12.87836456, -13.49628639, -12.31680107, -12.74228859],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
[[-16.78347397, -16.35150909, -14.98797417, -16.35162735],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
[[-20.81165886, -20.1307621, -18.92229652, -20.17458153],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]]])
self.assertAllClose(weights_gt, weights)
resampled_gt = np.array(
[[1., 0.],
[0., 0.],
[0., 1.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]])
self.assertAllClose(resampled_gt, resampled)
def test_fivo_aux_relaxed(self):
"""A golden-value test for the FIVO-AUX bound with relaxed sampling."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234,
use_tilt=True)
outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
random_seed=1234, parallel_iterations=1,
resampling_type="relaxed")
sess.run(tf.global_variables_initializer())
log_p_hat, weights, resampled, _ = sess.run(outs)
self.assertAllClose([-23.1395, -14.271059], log_p_hat)
weights_gt = np.array(
[[[-5.19826221, -3.55476403, -5.98663855, -6.08058834],
[-6.31685925, -5.70243931, -7.07638931, -6.18138981]],
[[-3.97986865, -3.58831525, -3.85753584, -3.5010016],
[-11.38203049, -8.66213989, -11.23646641, -10.02024746]],
[[-6.62269831, -6.36680222, -6.78096485, -5.80072498],
[-3.55419445, -8.11326408, -3.48766923, -3.08593249]],
[[-10.56472301, -10.16084099, -9.96741676, -8.5270071],
[-6.04880285, -7.80853653, -4.72652149, -3.49711013]],
[[-13.36585426, -16.08720398, -13.33416367, -13.1017189],
[-0., -0., -0., -0.]],
[[-17.54233551, -17.35167503, -16.79163361, -16.51471138],
[0., -0., -0., -0.]],
[[-19.74024963, -18.69452858, -17.76246452, -18.76182365],
[0., -0., -0., -0.]]])
self.assertAllClose(weights_gt, weights)
resampled_gt = np.array([[1., 0.],
[0., 1.],
[0., 0.],
[0., 1.],
[0., 0.],
[0., 0.],
[0., 0.]])
self.assertAllClose(resampled_gt, resampled)
if __name__ == "__main__":
np.set_printoptions(threshold=np.nan) # Used to easily see the gold values.
# Use print(repr(numpy_array)) to print the values.
tf.test.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords. """Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
""" """
......