Commit 46dea625 authored by Dieterich Lawson's avatar Dieterich Lawson

Updating fivo codebase

parent 5856878d
*.pkl binary
*.tfrecord binary
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
.static_storage/
.media/
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
...@@ -8,28 +8,42 @@ Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohamma
This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
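For intuition about how these bounds relate, both the ELBO and IWAE estimators can be computed from the same per-sample log importance weights. A minimal pure-Python sketch (the log weights below are made-up numbers, not model output):

```python
import math

def elbo(log_weights):
    # ELBO estimate: the plain average of the log weights.
    return sum(log_weights) / len(log_weights)

def iwae(log_weights):
    # IWAE estimate: log-mean-exp of the log weights, computed stably
    # by subtracting the max before exponentiating.
    m = max(log_weights)
    return m + math.log(
        sum(math.exp(lw - m) for lw in log_weights) / len(log_weights))

log_w = [-11.2, -10.8, -12.1, -11.5]  # hypothetical log weights
# By Jensen's inequality the IWAE estimate is at least the ELBO estimate.
assert elbo(log_w) <= iwae(log_w)
```

FIVO extends the same log-mean-exp idea to sequences by interleaving resampling steps, which is what the sequential Monte Carlo machinery in this codebase implements.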
Additionally it contains several sequential latent variable model implementations:
* Variational recurrent neural network (VRNN)
* Stochastic recurrent neural network (SRNN)
* Gaussian hidden Markov model with linear conditionals (GHMM)
The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.
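To see why the GHMM admits exact evaluation: for a linear-Gaussian state space model the marginal log-likelihood decomposes as log p(x_{1:T}) = sum_t log p(x_t | x_{1:t-1}), and each predictive term is Gaussian with moments given by the Kalman filter. A minimal 1-D sketch, not the codebase's GHMM implementation (the transition scale `a` and variances `q`, `r` are hypothetical):

```python
import math

def ghmm_loglik(xs, a=0.9, q=1.0, r=0.5):
    """Exact log-likelihood of a 1-D linear-Gaussian HMM via the Kalman
    filter: z_1 ~ N(0, q), z_t = a*z_{t-1} + N(0, q), x_t = z_t + N(0, r)."""
    m, P = 0.0, q  # moments of the first latent state
    ll = 0.0
    for t, x in enumerate(xs):
        if t > 0:  # predict step: moments of z_t given x_{1:t-1}
            m, P = a * m, a * a * P + q
        S = P + r  # predictive variance of x_t given x_{1:t-1}
        ll += -0.5 * (math.log(2 * math.pi * S) + (x - m) ** 2 / S)
        K = P / S  # Kalman gain; update step conditions on x_t
        m, P = m + K * (x - m), (1.0 - K) * P
    return ll
```

For short sequences the result can be checked directly against the log-density of the joint Gaussian over the observations, which is what makes an analytically tractable model useful as a ground-truth example.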
#### Directory Structure
The important parts of the code are organized as follows.
```
run_fivo.py # main script, contains flag definitions
fivo
├─smc.py # a sequential Monte Carlo implementation
├─bounds.py # code for computing each bound, uses smc.py
├─runners.py # code for VRNN and SRNN training and evaluation
├─ghmm_runners.py # code for GHMM training and evaluation
├─data
| ├─datasets.py # readers for pianoroll and speech datasets
| ├─calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
| └─create_timit_dataset.py # preprocesses the TIMIT dataset
└─models
  ├─base.py # base classes used in other models
  ├─vrnn.py # VRNN implementation
  ├─srnn.py # SRNN implementation
  └─ghmm.py # Gaussian hidden Markov model (GHMM) implementation
bin
├─run_train.sh # an example script that runs training
├─run_eval.sh # an example script that runs evaluation
├─run_sample.sh # an example script that runs sampling
├─run_tests.sh # a script that runs all tests
└─download_pianorolls.sh # a script that downloads pianoroll files
```
### Pianorolls
Requirements before we start:
...@@ -60,9 +74,9 @@ python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
#### Training
Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:
```
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
...
--dataset_type="pianoroll"
```
You should see output that looks something like this (with extra logging cruft):
```
Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
Step 1, fivo bound per timestep: -11.322491
global_step/sec: 7.49971
Step 101, fivo bound per timestep: -11.399275
global_step/sec: 8.04498
Step 201, fivo bound per timestep: -11.174991
global_step/sec: 8.03989
Step 301, fivo bound per timestep: -11.073008
```
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example, here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
```
python run_fivo.py \
--mode=eval \
--split=test \
--alsologtostderr \
...
```
You should see output like this:
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Model restored from step 0, evaluating.
test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
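The two scales differ only by the average sequence length: dividing ll/seq by ll/t should give the same mean length for every bound, which makes a quick sanity check on the printed numbers (using the example output verbatim):

```python
ll_t = {"elbo": -12.198834, "iwae": -11.981187, "fivo": -11.579776}
ll_seq = {"elbo": -748.564789, "iwae": -735.209206, "fivo": -710.577141}

# Each ratio is the mean test-sequence length in timesteps, so all
# three bounds should agree on it.
mean_lengths = {k: ll_seq[k] / ll_t[k] for k in ll_t}
assert max(mean_lengths.values()) - min(mean_lengths.values()) < 0.01
```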
#### Sampling
You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:
```
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
```
Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.
You should see very little output:
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Running local_init_op.
Done running local_init_op.
```
Loading the samples with `np.load` confirms that we conditioned the model on 4
prefixes of length 25 and sampled 4 sequences of length 50 for each prefix.
```
>>> import numpy as np
>>> x = np.load("/tmp/fivo/samples.npz")
>>> x[()]['prefixes'].shape
(25, 4, 88)
>>> x[()]['samples'].shape
(50, 4, 4, 88)
```
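Those shapes follow directly from the sampling flags. A small sketch of the shape arithmetic (assuming the 88-dimensional pianoroll representation; this mirrors the flags above rather than the codebase's writer code):

```python
batch_size, num_samples = 4, 4
prefix_length, sample_length = 25, 50
data_dim = 88  # number of pitches in the pianoroll representation

# prefixes: one prefix per batch element, laid out time-major.
prefixes_shape = (prefix_length, batch_size, data_dim)
# samples: num_samples trajectories per prefix, also time-major.
samples_shape = (sample_length, num_samples, batch_size, data_dim)

assert prefixes_shape == (25, 4, 88)
assert samples_shape == (50, 4, 4, 88)
```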
### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
...@@ -137,7 +189,7 @@ train mean: 0.006060 train std: 548.136169
#### Training on TIMIT
This is very similar to training on pianoroll datasets, with just a few flags switched.
```
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
...
--dataset_path="$TIMIT_DIR/train" \
--dataset_type="speech"
```
Evaluation and sampling are similar.
### Tests
This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful as examples of how to use the code.
### Contact
......
...@@ -18,7 +18,7 @@
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=eval \
--logdir=/tmp/fivo \
--model=vrnn \
......
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of sampling from the model.
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
python -m fivo.smc_test && \
python -m fivo.bounds_test && \
python -m fivo.nested_utils_test && \
python -m fivo.data.datasets_test && \
python -m fivo.models.ghmm_test && \
python -m fivo.models.vrnn_test && \
python -m fivo.models.srnn_test && \
python -m fivo.ghmm_runners_test && \
python -m fivo.runners_test
...@@ -18,7 +18,7 @@
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
......
An experimental codebase for running simple examples.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import summary_utils as summ
Loss = namedtuple("Loss", "name loss vars")
Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,)
def iwae(model, observation, num_timesteps, num_samples=1,
         summarize=False):
  """Compute the IWAE evidence lower bound.

  Args:
    model: A callable that computes one timestep of the model.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    num_samples: The number of samples to use to compute the IWAE bound.
    summarize: If True, write a summary of the weight entropy at each timestep.

  Returns:
    log_p_hat: The IWAE estimator of the lower bound on the log marginal,
      per timestep.
    losses: A list containing a single Loss tuple whose loss term can be
      minimized to optimize the bound.
    maintain_ema_op: A no-op included for compatibility with FIVO.
    states: The sequence of states sampled, with the initial state clipped off.
    log_weights: A list of the accumulated log weights at each timestep.
  """
# Initialization
num_instances = tf.shape(observation)[0]
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
log_weights = []
log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model(
states[-1], observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
if summarize:
weight_dist = tf.contrib.distributions.Categorical(
logits=tf.transpose(log_weight_acc, perm=[1, 0]),
allow_nan_stats=False)
weight_entropy = weight_dist.entropy()
weight_entropy = tf.reduce_mean(weight_entropy)
tf.summary.scalar("weight_entropy/%d" % t, weight_entropy)
log_weights.append(log_weight_acc)
# Compute the lower bound on the log evidence.
log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
loss = -tf.reduce_mean(log_p_hat)
losses = [Loss("log_p_hat", loss)]
# we clip off the initial state before returning.
# there are no emas for iwae, so we return a noop for that
return log_p_hat, losses, tf.no_op(), states[1:], log_weights
def multinomial_resampling(log_weights, states, n, b):
  """Resample states with multinomial resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distributions.
    states: A list of (b*n x d) Tensors that will be resampled from the groups
      of every n-th row.
    n: The number of samples per batch element.
    b: The batch size.

  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via multinomial
      sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry
      decisions.
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b) Tensor of integral indices representing the ancestry
      decisions.
    resampling_dist: The distribution object for resampling.
  """
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
resampling_parameters = tf.transpose(log_weights, perm=[1,0])
resampling_dist = tf.contrib.distributions.Categorical(logits=resampling_parameters)
ancestors = tf.stop_gradient(
resampling_dist.sample(sample_shape=n))
log_probs = resampling_dist.log_prob(ancestors)
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def stratified_resampling(log_weights, states, n, b):
  """Resample states with stratified resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distributions.
    states: A list of (b*n x d) Tensors that will be resampled from the groups
      of every n-th row.
    n: The number of samples per batch element.
    b: The batch size.

  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via stratified
      sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry
      decisions.
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b) Tensor of integral indices representing the ancestry
      decisions.
    resampling_dist: The distribution object for resampling.
  """
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
  resampling_dist = tf.contrib.distributions.Categorical(
      probs=resampling_parameters,
      allow_nan_stats=False)
ancestors = tf.stop_gradient(
resampling_dist.sample())
log_probs = resampling_dist.log_prob(ancestors)
ancestors = tf.transpose(ancestors, perm=[1,0])
log_probs = tf.transpose(log_probs, perm=[1,0])
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def systematic_resampling(log_weights, states, n, b):
  """Resample states with systematic resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distributions.
    states: A list of (b*n x d) Tensors that will be resampled from the groups
      of every n-th row.
    n: The number of samples per batch element.
    b: The batch size.

  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via systematic
      sampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry
      decisions.
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b) Tensor of integral indices representing the ancestry
      decisions.
    resampling_dist: The distribution object for resampling.
  """
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
resampling_dist = tf.contrib.distributions.Categorical(
probs=resampling_parameters,
allow_nan_stats=True)
  # Draw one shared uniform per batch element and pick, for each stratum, the
  # cdf bin it falls into (systematic sampling).
  U = tf.random_uniform((b, 1, 1), dtype=probs.dtype)
  ancestors = tf.stop_gradient(
      tf.reduce_sum(tf.to_float(U > strat_cdfs[:, :, 1:]), axis=-1))
  log_probs = resampling_dist.log_prob(ancestors)
  ancestors = tf.transpose(ancestors, perm=[1, 0])
  log_probs = tf.transpose(log_probs, perm=[1, 0])
  offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
  # tf.gather requires integer indices, so cast the float ancestor indices.
  ancestor_inds = tf.cast(tf.reshape(ancestors * b + offset, [-1]), tf.int32)
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def log_blend(inputs, weights):
"""Blends state in the log space.
Args:
inputs: A set of scalar states, one for each particle in each particle filter.
Should be [num_samples, batch_size].
weights: A set of weights used to blend the state. Each set of weights
should be of dimension [num_samples] (one weight for each previous particle).
There should be one set of weights for each new particle in each particle filter.
Thus the shape should be [num_samples, batch_size, num_samples] where
the first axis indexes new particle and the last axis indexes old particles.
Returns:
blended: The blended states, a tensor of shape [num_samples, batch_size].
"""
raw_max = tf.reduce_max(inputs, axis=0, keepdims=True)
my_max = tf.stop_gradient(
tf.where(tf.is_finite(raw_max), raw_max, tf.zeros_like(raw_max))
)
  # Compute log(sum_k weights[i, j, k] * exp(inputs[k, j])) stably by
  # subtracting the per-column max before exponentiating, then adding it back.
  blended = tf.log(tf.einsum("ijk,kj->ij", weights, tf.exp(inputs - raw_max))) + my_max
return blended
def relaxed_resampling(log_weights, states, num_samples, batch_size,
log_r_x=None, blend_type="log", temperature=0.5,
straight_through=False):
  """Resample states with relaxed (Gumbel-softmax) resampling.

  Args:
    log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
      Categorical distributions.
    states: A list of (b*n x d) Tensors that will be resampled from the groups
      of every n-th row.
    num_samples: The number of samples (particles) per batch element.
    batch_size: The batch size.
    log_r_x: An optional set of auxiliary values to blend along with the
      states.
    blend_type: Either "log" or "linear", the space in which to blend log_r_x.
    temperature: The positive temperature of the relaxed distribution.
    straight_through: If True, use discrete ancestors on the forward pass and
      soft ancestors on the backwards pass.

  Returns:
    resampled_states: A list of (b*n x d) Tensors resampled via relaxed
      resampling.
    log_probs: A (n x b) Tensor of the log probabilities of the ancestry
      decisions.
    log_r_x: The blended auxiliary values, or None if log_r_x was not
      supplied.
    resampling_parameters: The Tensor of parameters of the resampling
      distribution.
    ancestors: An (n x b x n) Tensor of relaxed one-hot representations of the
      ancestry decisions.
    resampling_dist: The distribution object for resampling.
  """
assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'."
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
state_dim = states[0].get_shape().as_list()[-1]
# weights are num_samples by batch_size, so we transpose to get a
# set of batch_size distributions over [0,num_samples).
resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
temperature,
logits=resampling_parameters)
# sample num_samples samples from the distribution, resulting in a
# [num_samples, batch_size, num_samples] Tensor that represents a set of
# [num_samples, batch_size] blending weights. The dimensions represent
# [sample index, batch index, blending weight index]
ancestors = resampling_dist.sample(sample_shape=num_samples)
if straight_through:
# Forward pass discrete choices, backwards pass soft choices
hard_ancestor_indices = tf.argmax(ancestors, axis=-1)
hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples,
dtype=ancestors.dtype)
ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors
log_probs = resampling_dist.log_prob(ancestors)
if log_r_x is not None and blend_type == "log":
log_r_x = tf.reshape(log_r_x, [num_samples, batch_size])
log_r_x = log_blend(log_r_x, ancestors)
log_r_x = tf.reshape(log_r_x, [num_samples*batch_size])
elif log_r_x is not None and blend_type == "linear":
# If blend type is linear just add log_r to the states that will be blended
# linearly.
states.append(log_r_x)
# transpose the 'indices' to be [batch_index, blending weight index, sample index]
ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0])
resampled_states = []
for state in states:
# state is currently [num_samples * batch_size, state_dim] so we reshape
# to [num_samples, batch_size, state_dim] and then transpose to
# [batch_size, state_size, num_samples]
state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]), perm=[1, 2, 0])
# state is now (batch_size, state_size, num_samples)
# and ancestor is (batch index, blending weight index, sample index)
# multiplying these gives a matrix of size [batch_size, state_size, num_samples]
next_state = tf.matmul(state, ancestor_inds)
# transpose the state to be [num_samples, batch_size, state_size]
# and then reshape it to match the state format.
next_state = tf.reshape(tf.transpose(next_state, perm=[2,0,1]), [num_samples*batch_size, state_dim])
resampled_states.append(next_state)
new_dist = tf.contrib.distributions.Categorical(
logits=resampling_parameters)
if log_r_x is not None and blend_type == "linear":
# If blend type is linear pop off log_r that we added to the states.
log_r_x = tf.squeeze(resampled_states[-1])
resampled_states = resampled_states[:-1]
return resampled_states, log_probs, log_r_x, resampling_parameters, ancestors, new_dist
def fivo(model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
use_resampling_grads=True,
resampling_type="multinomial",
resampling_temperature=0.5,
aux=True,
summarize=False):
  """Compute the FIVO evidence lower bound.

  Args:
    model: A callable that computes one timestep of the model.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    resampling_schedule: A list of booleans of length num_timesteps, containing
      True on each timestep where resampling should occur.
    num_samples: The number of particles to use to compute the FIVO bound.
    use_resampling_grads: Whether or not to include the resampling gradients
      in the loss.
    resampling_type: The type of resampling, one of "multinomial", "stratified",
      "relaxed-logblend", "relaxed-linearblend", "relaxed-stateblend",
      "relaxed-stateblend-st", or "systematic".
    resampling_temperature: A positive temperature, only used for relaxed
      resampling.
    aux: If True, compute the FIVO-AUX bound.
    summarize: If True, write TensorBoard summaries of the learning signal.

  Returns:
    log_p_hat: The FIVO estimator of the lower bound on the log marginal,
      per timestep.
    losses: A list of Loss tuples to minimize; includes a resampling-gradient
      term when use_resampling_grads is True.
    maintain_baseline_op: An op to update the baseline ema used for the
      resampling gradients.
    states: The sequence of states sampled, with the initial state clipped off.
    log_weights_all: The accumulated log weights at each timestep.
  """
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype)
log_weights = []
log_weights_all = []
log_p_hats = []
resampling_log_probs = []
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model(
prev_state, observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
if aux:
if t == num_timesteps - 1:
log_weight -= prev_log_r_zt
else:
log_weight += log_r_zt - prev_log_r_zt
prev_log_r_zt = log_r_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
log_weights_all.append(log_weight_acc)
if resampling_schedule[t]:
# These objects will be resampled
to_resample = [states[-1]]
if aux and "relaxed" not in resampling_type:
to_resample.append(prev_log_r_zt)
# do the resampling
if resampling_type == "multinomial":
(resampled,
resampling_log_prob,
_, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "stratified":
(resampled,
resampling_log_prob,
_, _, _) = stratified_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "systematic":
(resampled,
resampling_log_prob,
_, _, _) = systematic_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif "relaxed" in resampling_type:
if aux:
if resampling_type == "relaxed-logblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="log")
elif resampling_type == "relaxed-linearblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="linear")
elif resampling_type == "relaxed-stateblend":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
elif resampling_type == "relaxed-stateblend-st":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
straight_through=True)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
else:
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
#if summarize:
# resampling_entropy = resampling_dist.entropy()
# resampling_entropy = tf.reduce_mean(resampling_entropy)
# tf.summary.scalar("weight_entropy/%d" % t, resampling_entropy)
resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0))
prev_state = resampled[0]
if aux and "relaxed" not in resampling_type:
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt should always be [num_instances]
prev_log_r_zt = tf.squeeze(resampled[1])
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = states[-1]
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
if use_resampling_grads and any(resampling_schedule):
# compute the rewards
# cumsum([a, b, c]) => [a, a+b, a+b+c]
# learning signal at timestep t is
# [sum from i=t+1 to T of log_p_hat_i for t=1:T]
# so we will compute (sum from i=1 to T of log_p_hat_i)
# and at timestep t will subtract off (sum from i=1 to t of log_p_hat_i)
# rewards is a [num_resampling_events, batch_size] Tensor
rewards = tf.stop_gradient(
tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0))
batch_avg_rewards = tf.reduce_mean(rewards, axis=1)
# compute ema baseline.
# centered_rewards is [num_resampling_events, batch_size]
baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94)
maintain_baseline_op = baseline_ema.apply([batch_avg_rewards])
baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1)
centered_rewards = rewards - baseline
if summarize:
summ.summarize_learning_signal(rewards, "rewards")
summ.summarize_learning_signal(centered_rewards, "centered_rewards")
# compute the loss tensor.
resampling_grads = tf.reduce_sum(
tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0)
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps),
Loss("resampling_grads", -tf.reduce_mean(resampling_grads)/num_timesteps)]
else:
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps)]
maintain_baseline_op = tf.no_op()
log_p_hat /= num_timesteps
# we clip off the initial state before returning.
return log_p_hat, losses, maintain_baseline_op, states[1:], log_weights_all
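The comments above describe the resampling-gradient learning signal: the reward for resampling event t is the sum of the log p-hat terms that come after t, computed as the total minus a running cumulative sum. A pure-Python sketch of that arithmetic (the helper name is illustrative, not from this codebase):

```python
def resampling_rewards(log_p_hats, final_update):
    """Reward for resampling event t = (sum of all terms) - cumsum up to t,
    mirroring rewards = log_p_hat - cumsum(log_p_hats) in the code above."""
    log_p_hat = sum(log_p_hats) + final_update
    rewards, running = [], 0.0
    for term in log_p_hats:
        running += term
        # The reward for event t is the sum of all terms strictly after t.
        rewards.append(log_p_hat - running)
    return rewards

print(resampling_rewards([1.0, 2.0, 3.0], 0.5))  # [5.5, 3.5, 0.5]
```

The last event's reward is just the final weight update, so resampling at the very last timestep contributes only through the terminal term.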
def fivo_aux_td(
model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
summarize=False):
"""Compute the FIVO_AUX evidence lower bound."""
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r = tf.zeros([num_instances], dtype=observation.dtype)
# Must contain pre-resampling values.
log_rs = []
# Must contain post-resampling values.
r_tilde_params = [model.r_tilde.r_zt(states[0], observation, 0)]
log_r_tildes = []
log_p_xs = []
# The accumulated log weight before resampling, appended only on resampling
# timesteps.
log_weights = []
# The accumulated log weight before resampling, appended at every timestep.
log_weights_all = []
log_p_hats = []
for t in xrange(num_timesteps):
# Run the model for one timestep.
# zt is the state, shape [num_instances, state_dim].
# log_q_zt and log_p_x_given_z are [num_instances].
# r_tilde_mu and r_tilde_sigma are [num_instances, state_dim].
# p_ztplus1 is a normal distribution on [num_instances, state_dim].
(zt, log_q_zt, log_p_zt, log_p_x_given_z,
r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t)
# Compute the log weight without log r.
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
# Compute log r.
if t == num_timesteps - 1:
log_r = tf.zeros_like(prev_log_r)
else:
p_mu = p_ztplus1.mean()
p_sigma_sq = p_ztplus1.variance()
log_r = (tf.log(r_tilde_sigma_sq) -
tf.log(r_tilde_sigma_sq + p_sigma_sq) -
tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
#log_weight += tf.stop_gradient(log_r - prev_log_r)
log_weight += log_r - prev_log_r
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
# Update accumulators
states.append(zt)
log_weights_all.append(log_weight_acc)
log_p_xs.append(log_p_x_given_z)
log_rs.append(log_r)
# Compute log_r_tilde as [num_instances] Tensor.
prev_r_tilde_mu, prev_r_tilde_sigma_sq = r_tilde_params[-1]
prev_log_r_tilde = -0.5*tf.reduce_sum(
tf.square(zt - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
log_r_tildes.append(prev_log_r_tilde)
# optionally resample
if resampling_schedule[t]:
# These objects will be resampled
if t < num_timesteps - 1:
to_resample = [zt, log_r, r_tilde_mu, r_tilde_sigma_sq]
else:
to_resample = [zt, log_r]
(resampled,
_, _, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
prev_state = resampled[0]
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt and log_r_tilde should always be [num_instances]
prev_log_r = tf.squeeze(resampled[1])
if t < num_timesteps - 1:
r_tilde_params.append((resampled[2], resampled[3]))
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = zt
prev_log_r = log_r
if t < num_timesteps - 1:
r_tilde_params.append((r_tilde_mu, r_tilde_sigma_sq))
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
# Compute the Bellman loss.
# The first timestep will be removed because it is not used.
# log p(x_t|z_t) is in row t-1.
log_p_x = tf.reshape(tf.stack(log_p_xs),
[num_timesteps, num_samples, batch_size])
# log r_t is contained in row t-1.
# The last row is zeros because r is 1 at timestep T (num_timesteps).
log_r = tf.reshape(tf.stack(log_rs),
[num_timesteps, num_samples, batch_size])
# [num_timesteps, num_instances]. log r_tilde_t is in row t-1.
log_r_tilde = tf.reshape(tf.stack(log_r_tildes),
[num_timesteps, num_samples, batch_size])
log_lambda = tf.reduce_mean(log_r_tilde - log_p_x - log_r, axis=1,
keepdims=True)
bellman_sos = tf.reduce_mean(tf.square(
log_r_tilde - tf.stop_gradient(log_lambda + log_p_x + log_r)), axis=[0, 1])
bellman_loss = tf.reduce_mean(bellman_sos)/num_timesteps
tf.summary.scalar("bellman_loss", bellman_loss)
if len(tf.get_collection("LOG_P_HAT_VARS")) == 0:
log_p_hat_collection = list(set(tf.trainable_variables()) -
set(tf.get_collection("R_TILDE_VARS")))
for v in log_p_hat_collection:
tf.add_to_collection("LOG_P_HAT_VARS", v)
log_p_hat /= num_timesteps
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat), "LOG_P_HAT_VARS"),
Loss("bellman_loss", bellman_loss, "R_TILDE_VARS")]
return log_p_hat, losses, tf.no_op(), states[1:], log_weights_all
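Both bounds in this file update log p-hat with a log-mean-exp over the sample dimension: `tf.reduce_logsumexp(log_weight_acc, axis=0) - log(num_samples)`. A minimal pure-Python sketch of that update for a single batch element (the helper name is illustrative, not from this codebase):

```python
import math

def log_p_hat_update(log_weights):
    """logsumexp(log_weights) - log(num_samples): the log of the mean
    (unnormalized) particle weight, for one batch element."""
    # Subtract the max before exponentiating for numerical stability,
    # as reduce_logsumexp does internally.
    m = max(log_weights)
    log_sum = m + math.log(sum(math.exp(w - m) for w in log_weights))
    return log_sum - math.log(len(log_weights))

# With equal particle weights the estimate is just that weight.
print(log_p_hat_update([0.3, 0.3, 0.3]))  # approximately 0.3
```

Summing these per-resampling-event terms, plus the final update, reconstructs the full log p-hat estimate returned above.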
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import models
def make_long_chain_dataset(
state_size=1,
num_obs=5,
steps_per_obs=3,
variance=1.,
observation_variance=1.,
batch_size=4,
num_samples=1,
observation_type=models.STANDARD_OBSERVATION,
transition_type=models.STANDARD_TRANSITION,
fixed_observation=None,
dtype="float32"):
"""Creates a long chain data generating process.
Creates a tf.data.Dataset that provides batches of data from a long
chain.
Args:
state_size: The dimension of the state space of the process.
num_obs: The number of observations in the chain.
steps_per_obs: The number of steps between each observation.
variance: The variance of the normal distributions used at each timestep.
observation_variance: The variance of the observation distribution.
batch_size: The number of trajectories to include in each batch.
num_samples: The number of replicas of each trajectory to include in each
batch.
observation_type: The observation function, one of models.OBSERVATION_TYPES.
transition_type: The transition function, one of models.TRANSITION_TYPES.
fixed_observation: If not None, a constant used for all observations.
dtype: The datatype of the states and observations.
Returns:
dataset: A tf.data.Dataset that can be iterated over.
"""
num_timesteps = num_obs * steps_per_obs
def data_generator():
"""An infinite generator of latents and observations from the model."""
while True:
states = []
observations = []
# z0 ~ Normal(0, sqrt(variance)).
states.append(
np.random.normal(size=[state_size],
scale=np.sqrt(variance)).astype(dtype))
# start at 1 because we've already generated z0
# go to num_timesteps+1 because we want to include the num_timesteps-th step
for t in xrange(1, num_timesteps+1):
if transition_type == models.ROUND_TRANSITION:
loc = np.round(states[-1])
elif transition_type == models.STANDARD_TRANSITION:
loc = states[-1]
new_state = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance))
states.append(new_state.astype(dtype))
if t % steps_per_obs == 0:
if fixed_observation is None:
if observation_type == models.SQUARED_OBSERVATION:
loc = np.square(states[-1])
elif observation_type == models.ABS_OBSERVATION:
loc = np.abs(states[-1])
elif observation_type == models.STANDARD_OBSERVATION:
loc = states[-1]
new_obs = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(observation_variance)).astype(dtype)
else:
new_obs = (np.ones([state_size]) * fixed_observation).astype(dtype)
observations.append(new_obs)
yield states, observations
dataset = tf.data.Dataset.from_generator(
data_generator,
output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
output_shapes=([num_timesteps+1, state_size], [num_obs, state_size]))
dataset = dataset.repeat().batch(batch_size)
def tile_batch(state, observation):
state = tf.tile(state, [num_samples, 1, 1])
observation = tf.tile(observation, [num_samples, 1, 1])
return state, observation
dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
return dataset
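Stripped of the TensorFlow plumbing and the optional observation/transition types, the generator above is a Gaussian random walk observed every `steps_per_obs` steps. A pure-Python sketch under those assumptions (STANDARD transition and observation, no `fixed_observation`; the function name is illustrative):

```python
import random

def sample_long_chain(num_obs=5, steps_per_obs=3, variance=1.0,
                      observation_variance=1.0):
    """One draw from the long-chain process: a Gaussian random walk with
    a noisy observation of the state every steps_per_obs steps."""
    num_timesteps = num_obs * steps_per_obs
    # z0 ~ Normal(0, sqrt(variance)).
    states = [random.gauss(0.0, variance ** 0.5)]
    observations = []
    for t in range(1, num_timesteps + 1):
        states.append(random.gauss(states[-1], variance ** 0.5))
        if t % steps_per_obs == 0:
            observations.append(random.gauss(states[-1],
                                             observation_variance ** 0.5))
    return states, observations

states, observations = sample_long_chain()
print(len(states), len(observations))  # 16 5
```

The shapes match the dataset's `output_shapes`: num_timesteps + 1 latent states and num_obs observations per trajectory.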
def make_dataset(bs=None,
state_size=1,
num_timesteps=10,
variance=1.,
prior_type="unimodal",
bimodal_prior_weight=0.5,
bimodal_prior_mean=1,
transition_type=models.STANDARD_TRANSITION,
fixed_observation=None,
batch_size=4,
num_samples=1,
dtype='float32'):
"""Creates a data generating process.
Creates a tf.data.Dataset that provides batches of data.
Args:
bs: The parameters of the data generating process. If None, new bs are
randomly generated.
state_size: The dimension of the state space of the process.
num_timesteps: The length of the state sequences in the process.
variance: The variance of the normal distributions used at each timestep.
prior_type: The prior on z_0, one of "unimodal", "bimodal", or "nonlinear".
bimodal_prior_weight: The mixture weight used to select between the two
modes of a bimodal prior.
bimodal_prior_mean: The (symmetric) mode mean of a bimodal prior.
transition_type: The transition function, one of models.TRANSITION_TYPES.
fixed_observation: If not None, a constant used for the observation.
batch_size: The number of trajectories to include in each batch.
num_samples: The number of replicas of each trajectory to include in each
batch.
dtype: The datatype of the states and observations.
Returns:
bs: The true bs used to generate the data.
dataset: A tf.data.Dataset that can be iterated over.
"""
if bs is None:
bs = [np.random.uniform(size=[state_size]).astype(dtype) for _ in xrange(num_timesteps)]
tf.logging.info("data generating processs bs: %s",
np.array(bs).reshape(num_timesteps))
def data_generator():
"""An infinite generator of latents and observations from the model."""
while True:
states = []
if prior_type == "unimodal" or prior_type == "nonlinear":
# Prior is Normal(0, sqrt(variance)).
states.append(np.random.normal(size=[state_size], scale=np.sqrt(variance)).astype(dtype))
elif prior_type == "bimodal":
if np.random.uniform() > bimodal_prior_weight:
loc = bimodal_prior_mean
else:
loc = -bimodal_prior_mean
states.append(np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance)
).astype(dtype))
for t in xrange(num_timesteps):
if transition_type == models.ROUND_TRANSITION:
loc = np.round(states[-1])
elif transition_type == models.STANDARD_TRANSITION:
loc = states[-1]
loc += bs[t]
new_state = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance)).astype(dtype)
states.append(new_state)
if fixed_observation is None:
observation = states[-1]
else:
observation = np.ones_like(states[-1]) * fixed_observation
yield np.array(states[:-1]), observation
dataset = tf.data.Dataset.from_generator(
data_generator,
output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
output_shapes=([num_timesteps, state_size], [state_size]))
dataset = dataset.repeat().batch(batch_size)
def tile_batch(state, observation):
state = tf.tile(state, [num_samples, 1, 1])
observation = tf.tile(observation, [num_samples, 1])
return state, observation
dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
return np.array(bs), dataset
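For the default configuration (unimodal prior, STANDARD transition, no `fixed_observation`), the generator above reduces to a drifted Gaussian random walk whose observation is the final state. A pure-Python sketch of one draw under those assumptions (the function name is illustrative):

```python
import random

def sample_chain(bs, variance=1.0):
    """One draw from the default process: z_0 ~ N(0, variance),
    z_t ~ N(z_{t-1} + b_t, variance), observation = final state."""
    states = [random.gauss(0.0, variance ** 0.5)]
    for b in bs:
        states.append(random.gauss(states[-1] + b, variance ** 0.5))
    # As in the generator above, the latents exclude the final state,
    # which serves as the observation.
    return states[:-1], states[-1]

latents, observation = sample_chain([0.1] * 10)
print(len(latents))  # 10
```

This matches the dataset's `output_shapes`: num_timesteps latent states plus a single state-sized observation per trajectory.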
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sonnet as snt
import tensorflow as tf
import numpy as np
import math
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
class Q(object):
def __init__(self,
state_size,
num_timesteps,
sigma_min=1e-5,
dtype=tf.float32,
random_seed=None,
init_mu0_to_zero=False,
graph_collection_name="Q_VARS"):
self.sigma_min = sigma_min
self.dtype = dtype
self.graph_collection_name = graph_collection_name
initializers = []
for t in xrange(num_timesteps):
if t == 0 and init_mu0_to_zero:
initializers.append(
{"w": tf.zeros_initializer, "b": tf.zeros_initializer})
else:
initializers.append(
{"w": tf.random_uniform_initializer(seed=random_seed),
"b": tf.zeros_initializer})
def custom_getter(getter, *args, **kwargs):
out = getter(*args, **kwargs)
ref = tf.get_collection_ref(self.graph_collection_name)
if out not in ref:
ref.append(out)
return out
self.mus = [
snt.Linear(output_size=state_size,
initializers=initializers[t],
name="q_mu_%d" % t,
custom_getter=custom_getter
)
for t in xrange(num_timesteps)
]
self.sigmas = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="q_sigma_%d" % (t + 1),
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
initializer=tf.random_uniform_initializer(seed=random_seed))
for t in xrange(num_timesteps)
]
def q_zt(self, observation, prev_state, t):
batch_size = tf.shape(prev_state)[0]
q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
return q_zt
def summarize_weights(self):
for t, sigma in enumerate(self.sigmas):
tf.summary.scalar("q_sigma/%d" % t, sigma[0])
for t, f in enumerate(self.mus):
tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0])
if t != 0:
tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1,0])
class PreviousStateQ(Q):
def q_zt(self, unused_observation, prev_state, t):
batch_size = tf.shape(prev_state)[0]
q_mu = self.mus[t](prev_state)
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
return q_zt
def summarize_weights(self):
for t, sigma in enumerate(self.sigmas):
tf.summary.scalar("q_sigma/%d" % t, sigma[0])
for t, f in enumerate(self.mus):
tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[0,0])
class ObservationQ(Q):
def q_zt(self, observation, prev_state, t):
batch_size = tf.shape(prev_state)[0]
q_mu = self.mus[t](observation)
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
return q_zt
def summarize_weights(self):
for t, sigma in enumerate(self.sigmas):
tf.summary.scalar("q_sigma/%d" % t, sigma[0])
for t, f in enumerate(self.mus):
tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0])
class SimpleMeanQ(object):
def __init__(self,
state_size,
num_timesteps,
sigma_min=1e-5,
dtype=tf.float32,
random_seed=None,
init_mu0_to_zero=False,
graph_collection_name="Q_VARS"):
self.sigma_min = sigma_min
self.dtype = dtype
self.graph_collection_name = graph_collection_name
initializers = []
for t in xrange(num_timesteps):
if t == 0 and init_mu0_to_zero:
initializers.append(tf.zeros_initializer)
else:
initializers.append(tf.random_uniform_initializer(seed=random_seed))
self.mus = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="q_mu_%d" % (t + 1),
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
initializer=initializers[t])
for t in xrange(num_timesteps)
]
self.sigmas = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="q_sigma_%d" % (t + 1),
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
initializer=tf.random_uniform_initializer(seed=random_seed))
for t in xrange(num_timesteps)
]
def q_zt(self, unused_observation, prev_state, t):
batch_size = tf.shape(prev_state)[0]
q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
return q_zt
def summarize_weights(self):
for t, sigma in enumerate(self.sigmas):
tf.summary.scalar("q_sigma/%d" % t, sigma[0])
for t, f in enumerate(self.mus):
tf.summary.scalar("q_mu/%d" % t, f[0])
class R(object):
def __init__(self,
state_size,
num_timesteps,
sigma_min=1e-5,
dtype=tf.float32,
sigma_init=1.,
random_seed=None,
graph_collection_name="R_VARS"):
self.dtype = dtype
self.sigma_min = sigma_min
initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
"b": tf.zeros_initializer}
self.graph_collection_name = graph_collection_name
def custom_getter(getter, *args, **kwargs):
out = getter(*args, **kwargs)
ref = tf.get_collection_ref(self.graph_collection_name)
if out not in ref:
ref.append(out)
return out
self.mus = [
snt.Linear(output_size=state_size,
initializers=initializers,
name="r_mu_%d" % t,
custom_getter=custom_getter)
for t in xrange(num_timesteps)
]
self.sigmas = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="r_sigma_%d" % (t + 1),
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
initializer=tf.constant_initializer(sigma_init))
for t in xrange(num_timesteps)
]
def r_xn(self, z_t, t):
batch_size = tf.shape(z_t)[0]
r_mu = self.mus[t](z_t)
r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
return tf.contrib.distributions.Normal(
loc=r_mu, scale=tf.sqrt(r_sigma))
def summarize_weights(self):
for t in range(len(self.mus) - 1):
tf.summary.scalar("r_mu/%d" % t, self.mus[t][0])
tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])
class P(object):
def __init__(self,
state_size,
num_timesteps,
sigma_min=1e-5,
variance=1.0,
dtype=tf.float32,
random_seed=None,
trainable=True,
init_bs_to_zero=False,
graph_collection_name="P_VARS"):
self.state_size = state_size
self.num_timesteps = num_timesteps
self.sigma_min = sigma_min
self.dtype = dtype
self.variance = variance
self.graph_collection_name = graph_collection_name
if init_bs_to_zero:
initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
else:
initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
self.bs = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="p_b_%d" % (t + 1),
initializer=initializers[t],
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
trainable=trainable) for t in xrange(num_timesteps)
]
self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
def posterior(self, observation, prev_state, t):
"""Computes the true posterior p(z_t|z_{t-1}, z_n)."""
# bs[0] is really b_1
# Bs[i] is sum from k=i+1^n b_k
mu = observation - self.Bs[t]
if t > 0:
mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
mu /= float(self.num_timesteps - t + 1)
sigma = tf.ones_like(mu) * self.variance * (
float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
def lookahead(self, state, t):
"""Computes the true lookahead distribution p(z_n|z_t)."""
mu = state + self.Bs[t]
sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
def likelihood(self, observation):
batch_size = tf.shape(observation)[0]
mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
# Average over the batch and take the sum over the state size
return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))
def p_zt(self, prev_state, t):
"""Computes the model p(z_t| z_{t-1})."""
batch_size = tf.shape(prev_state)[0]
if t > 0:
z_mu_p = prev_state + self.bs[t - 1]
else: # p(z_0) is Normal(0, variance)
z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
p_zt = tf.contrib.distributions.Normal(
loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
return p_zt
def generative(self, unused_observation, z_nm1):
"""Computes the model's generative distribution p(z_n| z_{n-1})."""
generative_p_mu = z_nm1 + self.bs[-1]
return tf.contrib.distributions.Normal(
loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
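`P.posterior` above implements the closed-form conditional of a Gaussian random walk given its endpoint. In the drift-free case (all bs zero) it reduces to the familiar Gaussian-bridge interpolation. A hedged pure-Python check of that simplification (the helper is illustrative, not the class method itself):

```python
def bridge_posterior_params(observation, prev_state, t, num_timesteps,
                            variance=1.0):
    """p(z_t | z_{t-1}, z_n) for a drift-free Gaussian random walk (t > 0):
    the mean interpolates toward the endpoint and the variance shrinks
    as t approaches num_timesteps."""
    remaining = num_timesteps - t
    mu = (observation + prev_state * remaining) / (remaining + 1)
    sigma_sq = variance * remaining / (remaining + 1)
    return mu, sigma_sq

# One step before the end, the mean is halfway between z_{t-1} and z_n.
print(bridge_posterior_params(4.0, 0.0, 9, 10))  # (2.0, 0.5)
```

With nonzero bs, `P.posterior` applies the same interpolation after shifting by the cumulative drift terms `Bs[t]` and `bs[t-1]`.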
class ShortChainNonlinearP(object):
def __init__(self,
state_size,
num_timesteps,
sigma_min=1e-5,
variance=1.0,
observation_variance=1.0,
transition_type=STANDARD_TRANSITION,
transition_dist=tf.contrib.distributions.Normal,
dtype=tf.float32,
random_seed=None):
self.state_size = state_size
self.num_timesteps = num_timesteps
self.sigma_min = sigma_min
self.dtype = dtype
self.variance = variance
self.observation_variance = observation_variance
self.transition_type = transition_type
self.transition_dist = transition_dist
def p_zt(self, prev_state, t):
"""Computes the model p(z_t| z_{t-1})."""
batch_size = tf.shape(prev_state)[0]
if t > 0:
if self.transition_type == ROUND_TRANSITION:
loc = tf.round(prev_state)
tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
elif self.transition_type == STANDARD_TRANSITION:
loc = prev_state
tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
else: # p(z_0) is Normal(0, variance)
loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
p_zt = self.transition_dist(
loc=loc,
scale=tf.sqrt(tf.ones_like(loc) * self.variance))
return p_zt
def generative(self, unused_obs, z_ni):
"""Computes the model's generative distribution p(x_i| z_{ni})."""
if self.transition_type == ROUND_TRANSITION:
loc = tf.round(z_ni)
elif self.transition_type == STANDARD_TRANSITION:
loc = z_ni
generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
return self.transition_dist(
loc=loc, scale=tf.sqrt(generative_sigma_sq))
class BimodalPriorP(object):
def __init__(self,
state_size,
num_timesteps,
mixing_coeff=0.5,
prior_mode_mean=1,
sigma_min=1e-5,
variance=1.0,
dtype=tf.float32,
random_seed=None,
trainable=True,
init_bs_to_zero=False,
graph_collection_name="P_VARS"):
self.state_size = state_size
self.num_timesteps = num_timesteps
self.sigma_min = sigma_min
self.dtype = dtype
self.variance = variance
self.mixing_coeff = mixing_coeff
self.prior_mode_mean = prior_mode_mean
if init_bs_to_zero:
initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
else:
initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)]
self.bs = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="b_%d" % (t + 1),
initializer=initializers[t],
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
trainable=trainable) for t in xrange(num_timesteps)
]
self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
def posterior(self, observation, prev_state, t):
"""Computes the true posterior p(z_t|z_{t-1}, z_n)."""
# NOTE: This is currently wrong, but fixing it would require refactoring
# summarize_q because the KL is not defined for a mixture.
# bs[0] is really b_1
# Bs[i] is sum from k=i+1^n b_k
mu = observation - self.Bs[t]
if t > 0:
mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
mu /= float(self.num_timesteps - t + 1)
sigma = tf.ones_like(mu) * self.variance * (
float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
def lookahead(self, state, t):
"""Computes the true lookahead distribution p(z_n|z_t)."""
mu = state + self.Bs[t]
sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
def likelihood(self, observation):
batch_size = tf.shape(observation)[0]
sum_of_bs = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1)
mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean) + sum_of_bs
mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean) + sum_of_bs
zn_pos = tf.contrib.distributions.Normal(
loc=mu_pos,
scale=tf.sqrt(sigma))
zn_neg = tf.contrib.distributions.Normal(
loc=mu_neg,
scale=tf.sqrt(sigma))
mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1 - self.mixing_coeff], dtype=self.dtype)
mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
zn_dist = tf.contrib.distributions.Mixture(
cat=mode_selection_dist,
components=[zn_pos, zn_neg],
validate_args=True)
# Average over the batch and take the sum over the state size
return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1))
def p_zt(self, prev_state, t):
"""Computes the model p(z_t| z_{t-1})."""
batch_size = tf.shape(prev_state)[0]
if t > 0:
z_mu_p = prev_state + self.bs[t - 1]
p_zt = tf.contrib.distributions.Normal(
loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
return p_zt
else: # p(z_0) is mixture of two Normals
mu_pos = tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean
mu_neg = tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean
z0_pos = tf.contrib.distributions.Normal(
loc=mu_pos,
scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
z0_neg = tf.contrib.distributions.Normal(
loc=mu_neg,
scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1 - self.mixing_coeff], dtype=self.dtype)
mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
z0_dist = tf.contrib.distributions.Mixture(
cat=mode_selection_dist,
components=[z0_pos, z0_neg],
validate_args=False)
return z0_dist
def generative(self, unused_observation, z_nm1):
"""Computes the model's generative distribution p(z_n| z_{n-1})."""
generative_p_mu = z_nm1 + self.bs[-1]
return tf.contrib.distributions.Normal(
loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class Model(object):
def __init__(self,
p,
q,
r,
state_size,
num_timesteps,
dtype=tf.float32):
self.p = p
self.q = q
self.r = r
self.state_size = state_size
self.num_timesteps = num_timesteps
self.dtype = dtype
def zero_state(self, batch_size):
return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
def __call__(self, prev_state, observation, t):
# Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
q_zt = self.q.q_zt(observation, prev_state, t)
# Compute the p distribution over z, p(z_t|z_{t-1}).
p_zt = self.p.p_zt(prev_state, t)
# sample from q
zt = q_zt.sample()
r_xn = self.r.r_xn(zt, t)
# Calculate the logprobs and sum over the state size.
log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
# If we're at the last timestep, also compute the log prob of the observation.
if t == self.num_timesteps - 1:
generative_dist = self.p.generative(observation, zt)
log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
else:
log_p_x_given_z = tf.zeros_like(log_q_zt)
return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)
@staticmethod
def create(state_size,
num_timesteps,
sigma_min=1e-5,
r_sigma_init=1,
variance=1.0,
mixing_coeff=0.5,
prior_mode_mean=1.0,
dtype=tf.float32,
random_seed=None,
train_p=True,
p_type="unimodal",
q_type="normal",
observation_variance=1.0,
transition_type=STANDARD_TRANSITION,
use_bs=True):
if p_type == "unimodal":
p = P(state_size,
num_timesteps,
sigma_min=sigma_min,
variance=variance,
dtype=dtype,
random_seed=random_seed,
trainable=train_p,
init_bs_to_zero=not use_bs)
elif p_type == "bimodal":
p = BimodalPriorP(
state_size,
num_timesteps,
mixing_coeff=mixing_coeff,
prior_mode_mean=prior_mode_mean,
sigma_min=sigma_min,
variance=variance,
dtype=dtype,
random_seed=random_seed,
trainable=train_p,
init_bs_to_zero=not use_bs)
elif "nonlinear" in p_type:
if "cauchy" in p_type:
trans_dist = tf.contrib.distributions.Cauchy
else:
trans_dist = tf.contrib.distributions.Normal
p = ShortChainNonlinearP(
state_size,
num_timesteps,
sigma_min=sigma_min,
variance=variance,
observation_variance=observation_variance,
transition_type=transition_type,
transition_dist=trans_dist,
dtype=dtype,
random_seed=random_seed
)
if q_type == "normal":
q_class = Q
elif q_type == "simple_mean":
q_class = SimpleMeanQ
elif q_type == "prev_state":
q_class = PreviousStateQ
elif q_type == "observation":
q_class = ObservationQ
q = q_class(state_size,
num_timesteps,
sigma_min=sigma_min,
dtype=dtype,
random_seed=random_seed,
init_mu0_to_zero=not use_bs)
r = R(state_size,
num_timesteps,
sigma_min=sigma_min,
sigma_init=r_sigma_init,
dtype=dtype,
random_seed=random_seed)
model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
return model
class BackwardsModel(object):
def __init__(self,
state_size,
num_timesteps,
sigma_min=1e-5,
dtype=tf.float32):
self.state_size = state_size
self.num_timesteps = num_timesteps
self.sigma_min = sigma_min
self.dtype = dtype
self.bs = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="b_%d" % (t + 1),
initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
]
self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
self.q_mus = [
snt.Linear(output_size=state_size) for _ in xrange(num_timesteps)
]
self.q_sigmas = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="q_sigma_%d" % (t + 1),
initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
]
self.r_mus = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="r_mu_%d" % (t + 1),
initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
]
self.r_sigmas = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="r_sigma_%d" % (t + 1),
initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
]
def zero_state(self, batch_size):
return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
def posterior(self, unused_observation, prev_state, unused_t):
# TODO(dieterichl): Correct this.
return tf.contrib.distributions.Normal(
loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))
def lookahead(self, state, unused_t):
# TODO(dieterichl): Correct this.
return tf.contrib.distributions.Normal(
loc=tf.zeros_like(state), scale=tf.zeros_like(state))
def q_zt(self, observation, next_state, t):
"""Computes the variational posterior q(z_{t}|z_{t+1}, z_n)."""
t_backwards = self.num_timesteps - t - 1
batch_size = tf.shape(next_state)[0]
q_mu = self.q_mus[t_backwards](tf.concat([observation, next_state], axis=1))
q_sigma = tf.maximum(
tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
return q_zt
def p_zt(self, prev_state, t):
"""Computes the model p(z_{t+1}| z_{t})."""
t_backwards = self.num_timesteps - t - 1
z_mu_p = prev_state + self.bs[t_backwards]
p_zt = tf.contrib.distributions.Normal(
loc=z_mu_p, scale=tf.ones_like(z_mu_p))
return p_zt
def generative(self, unused_observation, z_nm1):
"""Computes the model's generative distribution p(z_n| z_{n-1})."""
generative_p_mu = z_nm1 + self.bs[-1]
return tf.contrib.distributions.Normal(
loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))
def r(self, z_t, t):
t_backwards = self.num_timesteps - t - 1
batch_size = tf.shape(z_t)[0]
r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
r_sigma = tf.maximum(
tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))
def likelihood(self, observation):
batch_size = tf.shape(observation)[0]
mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
# Average over the batch and take the sum over the state size
return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))
def __call__(self, next_state, observation, t):
# next state = z_{t+1}
# Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}).
q_zt = self.q_zt(observation, next_state, t)
# sample from q
zt = q_zt.sample()
# Compute the p distribution over z, p(z_{t+1}|z_{t}).
p_zt = self.p_zt(zt, t)
# Compute log p(z_{t+1} | z_t)
if t == 0:
log_p_zt = p_zt.log_prob(observation)
else:
log_p_zt = p_zt.log_prob(next_state)
# Compute r prior over zt
r_zt = self.r(zt, t)
log_r_zt = r_zt.log_prob(zt)
# Compute proposal density at zt
log_q_zt = q_zt.log_prob(zt)
# If we're at the last timestep, also calc the logprob of the observation.
if t == self.num_timesteps - 1:
p_z0_dist = tf.contrib.distributions.Normal(
loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
z0_log_prob = p_z0_dist.log_prob(zt)
else:
z0_log_prob = tf.zeros_like(log_q_zt)
return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
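As a sanity check on the closed-form `likelihood` above, a minimal numpy simulation (assuming the Gaussian random-walk process the class implies: z_0 ~ N(0, 1), z_{t+1} = z_t + b_t + unit-variance noise) recovers the same endpoint mean and variance, x ~ N(sum_t b_t, T + 1). The drift values below are hypothetical:

```python
import numpy as np

# Monte Carlo check of the endpoint marginal x ~ N(sum_t b_t, T + 1)
# implied by z_0 ~ N(0, 1) and z_{t+1} = z_t + b_t + eps_t, eps_t ~ N(0, 1).
rng = np.random.default_rng(0)
bs = np.array([0.5, -0.2, 0.1, 0.3, 0.0])  # hypothetical drift terms b_t
T = len(bs)

z = rng.normal(size=200_000)  # z_0 ~ N(0, 1)
for b in bs:
    z = z + b + rng.normal(size=z.shape)  # one random-walk transition

print(z.mean())  # close to sum(bs) = 0.7
print(z.var())   # close to T + 1 = 6
```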
class LongChainP(object):
def __init__(self,
state_size,
num_obs,
steps_per_obs,
sigma_min=1e-5,
variance=1.0,
observation_variance=1.0,
observation_type=STANDARD_OBSERVATION,
transition_type=STANDARD_TRANSITION,
dtype=tf.float32,
random_seed=None):
self.state_size = state_size
self.steps_per_obs = steps_per_obs
self.num_obs = num_obs
self.num_timesteps = steps_per_obs*num_obs + 1
self.sigma_min = sigma_min
self.dtype = dtype
self.variance = variance
self.observation_variance = observation_variance
self.observation_type = observation_type
self.transition_type = transition_type
def likelihood(self, observations):
"""Computes the model's true likelihood of the observations.
Args:
observations: A [batch_size, m, state_size] Tensor representing each of
the m observations.
Raises:
ValueError: Always; the likelihood is not defined for long-chain models.
"""
raise ValueError("Likelihood is not defined for long-chain models")
# batch_size = tf.shape(observations)[0]
# mu = tf.zeros([batch_size, self.state_size, self.num_obs], dtype=self.dtype)
# sigma = np.fromfunction(
# lambda i, j: 1 + self.steps_per_obs*np.minimum(i+1, j+1),
# [self.num_obs, self.num_obs])
# sigma += np.eye(self.num_obs)
# sigma = tf.convert_to_tensor(sigma * self.variance, dtype=self.dtype)
# sigma = tf.tile(sigma[tf.newaxis, tf.newaxis, ...],
# [batch_size, self.state_size, 1, 1])
# dist = tf.contrib.distributions.MultivariateNormalFullCovariance(
# loc=mu,
# covariance_matrix=sigma)
# Average over the batch and take the sum over the state size
#return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observations), axis=1))
def p_zt(self, prev_state, t):
"""Computes the model p(z_t| z_{t-1})."""
batch_size = tf.shape(prev_state)[0]
if t > 0:
if self.transition_type == ROUND_TRANSITION:
loc = tf.round(prev_state)
tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance))
elif self.transition_type == STANDARD_TRANSITION:
loc = prev_state
tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))
else: # p(z_0) is Normal(0,1)
loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
p_zt = tf.contrib.distributions.Normal(
loc=loc,
scale=tf.sqrt(tf.ones_like(loc) * self.variance))
return p_zt
def generative(self, z_ni, t):
"""Computes the model's generative distribution p(x_i| z_{ni})."""
if self.observation_type == SQUARED_OBSERVATION:
generative_mu = tf.square(z_ni)
tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" % (t, t, t, self.variance))
elif self.observation_type == ABS_OBSERVATION:
generative_mu = tf.abs(z_ni)
tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" % (t, t, t, self.variance))
elif self.observation_type == STANDARD_OBSERVATION:
generative_mu = z_ni
tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t, t, self.variance))
generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance
return tf.contrib.distributions.Normal(
loc=generative_mu, scale=tf.sqrt(generative_sigma_sq))
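A plain-numpy sketch of the observation-mean dispatch above (the `observation_mean` helper and its string tags are illustrative stand-ins for the `*_OBSERVATION` constants, not part of the module):

```python
import numpy as np

# Mean of the generative distribution under the three observation types:
# squared -> z^2, abs -> |z|, standard -> identity.
def observation_mean(z, observation_type):
    if observation_type == "squared":
        return np.square(z)
    if observation_type == "abs":
        return np.abs(z)
    return z  # standard: identity mean

z = np.array([-2.0, 0.5])
print(observation_mean(z, "squared"))  # [4.0, 0.25]
print(observation_mean(z, "abs"))      # [2.0, 0.5]
```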
class LongChainQ(object):
def __init__(self,
state_size,
num_obs,
steps_per_obs,
sigma_min=1e-5,
dtype=tf.float32,
random_seed=None):
self.state_size = state_size
self.sigma_min = sigma_min
self.dtype = dtype
self.steps_per_obs = steps_per_obs
self.num_obs = num_obs
self.num_timesteps = num_obs*steps_per_obs + 1
initializers = {
"w": tf.random_uniform_initializer(seed=random_seed),
"b": tf.zeros_initializer
}
self.mus = [
snt.Linear(output_size=state_size, initializers=initializers)
for t in xrange(self.num_timesteps)
]
self.sigmas = [
tf.get_variable(
shape=[state_size],
dtype=self.dtype,
name="q_sigma_%d" % (t + 1),
initializer=tf.random_uniform_initializer(seed=random_seed))
for t in xrange(self.num_timesteps)
]
def first_relevant_obs_index(self, t):
return int(max((t-1)/self.steps_per_obs, 0))
def q_zt(self, observations, prev_state, t):
"""Computes a distribution over z_t.
Args:
observations: a [batch_size, num_observations, state_size] Tensor.
prev_state: a [batch_size, state_size] Tensor.
t: The current timestep, an int Tensor.
"""
# filter out unneeded past obs
first_relevant_obs_index = int(math.floor(max(t-1, 0) / self.steps_per_obs))
num_relevant_observations = self.num_obs - first_relevant_obs_index
observations = observations[:,first_relevant_obs_index:,:]
batch_size = tf.shape(prev_state)[0]
# Concatenate the prev state and observations along the second axis (the
# axis that is not the batch or state size axis), then flatten to
# [batch_size, (num_relevant_observations + 1) * state_size] to feed it into
# the linear layer.
q_input = tf.concat([observations, prev_state[:,tf.newaxis, :]], axis=1)
q_input = tf.reshape(q_input,
[batch_size, (num_relevant_observations + 1) * self.state_size])
q_mu = self.mus[t](q_input)
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
tf.logging.info(
"q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
**{"t": t,
"tm1": t-1,
"obsf": (first_relevant_obs_index+1)*self.steps_per_obs,
"obst":self.steps_per_obs*self.num_obs}))
return q_zt
def summarize_weights(self):
pass
class LongChainR(object):
def __init__(self,
state_size,
num_obs,
steps_per_obs,
sigma_min=1e-5,
dtype=tf.float32,
random_seed=None):
self.state_size = state_size
self.dtype = dtype
self.sigma_min = sigma_min
self.steps_per_obs = steps_per_obs
self.num_obs = num_obs
self.num_timesteps = num_obs*steps_per_obs + 1
self.sigmas = [
tf.get_variable(
shape=[self.num_future_obs(t)],
dtype=self.dtype,
name="r_sigma_%d" % (t + 1),
#initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
initializer=tf.constant_initializer(1.0))
for t in range(self.num_timesteps)
]
def first_future_obs_index(self, t):
return int(math.floor(t / self.steps_per_obs))
def num_future_obs(self, t):
return int(self.num_obs - self.first_future_obs_index(t))
def r_xn(self, z_t, t):
"""Computes a distribution over the future observations given current latent
state.
The indexing in these log messages is 1-indexed and inclusive, consistent
with the LaTeX documents.
Args:
z_t: [batch_size, state_size] Tensor
t: Current timestep
"""
tf.logging.info(
"r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
**{"t": t,
"start": (self.first_future_obs_index(t)+1)*self.steps_per_obs,
"end": self.num_timesteps-1}))
batch_size = tf.shape(z_t)[0]
# the mean for all future observations is the same.
# this tiling results in a [batch_size, num_future_obs, state_size] Tensor
r_mu = tf.tile(z_t[:,tf.newaxis,:], [1, self.num_future_obs(t), 1])
# compute the variance
r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
# The variance is the same across all state dimensions, so we only have to
# tile sigma to be [batch_size, num_future_obs].
r_sigma = tf.tile(r_sigma[tf.newaxis,:, tf.newaxis], [batch_size, 1, self.state_size])
return tf.contrib.distributions.Normal(
loc=r_mu, scale=tf.sqrt(r_sigma))
def summarize_weights(self):
pass
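The time-to-observation bookkeeping used by the long-chain classes can be sketched in pure Python with hypothetical sizes; `first_future_obs_index` and `num_future_obs` mirror `LongChainR`, and `next_obs_ind` mirrors `LongChainModel` below:

```python
import math

# Chain layout assumption: num_timesteps = steps_per_obs * num_obs + 1,
# with observations emitted at t = steps_per_obs, 2*steps_per_obs, ...
steps_per_obs, num_obs = 3, 2  # hypothetical sizes

def first_future_obs_index(t):
    return int(math.floor(t / steps_per_obs))

def num_future_obs(t):
    return num_obs - first_future_obs_index(t)

def next_obs_ind(t):
    return int(math.floor(max(t - 1, 0) / steps_per_obs))

print([first_future_obs_index(t) for t in range(7)])  # [0, 0, 0, 1, 1, 1, 2]
print([num_future_obs(t) for t in range(7)])          # [2, 2, 2, 1, 1, 1, 0]
print([next_obs_ind(t) for t in range(7)])            # [0, 0, 0, 0, 1, 1, 1]
```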
class LongChainModel(object):
def __init__(self,
p,
q,
r,
state_size,
num_obs,
steps_per_obs,
dtype=tf.float32,
disable_r=False):
self.p = p
self.q = q
self.r = r
self.disable_r = disable_r
self.state_size = state_size
self.num_obs = num_obs
self.steps_per_obs = steps_per_obs
self.num_timesteps = steps_per_obs*num_obs + 1
self.dtype = dtype
def zero_state(self, batch_size):
return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
def next_obs_ind(self, t):
return int(math.floor(max(t-1,0)/self.steps_per_obs))
def __call__(self, prev_state, observations, t):
"""Computes the importance weight for the model system.
Args:
prev_state: [batch_size, state_size] Tensor
observations: [batch_size, num_observations, state_size] Tensor
"""
# Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
q_zt = self.q.q_zt(observations, prev_state, t)
# Compute the p distribution over z, p(z_t|z_{t-1}).
p_zt = self.p.p_zt(prev_state, t)
# sample from q and evaluate the logprobs, summing over the state size
zt = q_zt.sample()
log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
if not self.disable_r and t < self.num_timesteps-1:
# score the remaining observations using r
r_xn = self.r.r_xn(zt, t)
log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t+1):, :])
# sum over state size and observation, leaving the batch index
log_r_xn = tf.reduce_sum(log_r_xn, axis=[1,2])
else:
log_r_xn = tf.zeros_like(log_p_zt)
if t != 0 and t % self.steps_per_obs == 0:
generative_dist = self.p.generative(zt, t)
log_p_x_given_z = generative_dist.log_prob(observations[:,self.next_obs_ind(t),:])
log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1)
else:
log_p_x_given_z = tf.zeros_like(log_q_zt)
return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)
@staticmethod
def create(state_size,
num_obs,
steps_per_obs,
sigma_min=1e-5,
variance=1.0,
observation_variance=1.0,
observation_type=STANDARD_OBSERVATION,
transition_type=STANDARD_TRANSITION,
dtype=tf.float32,
random_seed=None,
disable_r=False):
p = LongChainP(
state_size,
num_obs,
steps_per_obs,
sigma_min=sigma_min,
variance=variance,
observation_variance=observation_variance,
observation_type=observation_type,
transition_type=transition_type,
dtype=dtype,
random_seed=random_seed)
q = LongChainQ(
state_size,
num_obs,
steps_per_obs,
sigma_min=sigma_min,
dtype=dtype,
random_seed=random_seed)
r = LongChainR(
state_size,
num_obs,
steps_per_obs,
sigma_min=sigma_min,
dtype=dtype,
random_seed=random_seed)
model = LongChainModel(
p, q, r, state_size, num_obs, steps_per_obs,
dtype=dtype,
disable_r=disable_r)
return model
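For intuition, the log-density terms returned by `LongChainModel.__call__` combine into a per-particle incremental importance weight roughly as below. This is a simplified sketch: the full fivo-aux weight also divides out the previous step's `r` term, which is handled by the bound code, not here.

```python
import numpy as np

# Combine one step's log terms into normalized particle weights:
#   log w = log p(z_t|z_{t-1}) + log p(x|z_t) + log r(x_future|z_t) - log q(z_t|.)
# A max-subtraction (logsumexp trick) keeps the exponentiation stable.
def normalized_weights(log_p_zt, log_p_x_given_z, log_r_xn, log_q_zt):
    log_w = log_p_zt + log_p_x_given_z + log_r_xn - log_q_zt
    log_w = log_w - np.max(log_w)
    w = np.exp(log_w)
    return w / w.sum()

# Illustrative values for two particles.
w = normalized_weights(np.array([-1.0, -2.0]), np.zeros(2),
                       np.zeros(2), np.array([-1.5, -1.5]))
print(w)  # sums to 1; first particle has the larger weight
```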
class RTilde(object):
def __init__(self,
state_size,
num_timesteps,
sigma_min=1e-5,
dtype=tf.float32,
random_seed=None,
graph_collection_name="R_TILDE_VARS"):
self.dtype = dtype
self.sigma_min = sigma_min
initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
"b": tf.zeros_initializer}
self.graph_collection_name=graph_collection_name
def custom_getter(getter, *args, **kwargs):
out = getter(*args, **kwargs)
ref = tf.get_collection_ref(self.graph_collection_name)
if out not in ref:
ref.append(out)
return out
self.fns = [
snt.Linear(output_size=2*state_size,
initializers=initializers,
name="r_tilde_%d" % t,
custom_getter=custom_getter)
for t in xrange(num_timesteps)
]
def r_zt(self, z_t, observation, t):
#out = self.fns[t](tf.stop_gradient(tf.concat([z_t, observation], axis=1)))
out = self.fns[t](tf.concat([z_t, observation], axis=1))
mu, raw_sigma_sq = tf.split(out, 2, axis=1)
sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
return mu, sigma_sq
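The variance parameterization used throughout these classes (a softplus of an unconstrained raw variable, floored at `sigma_min`) can be sketched in numpy; `softplus_variance` is an illustrative name, not the module's:

```python
import numpy as np

# softplus keeps the raw parameter unconstrained while staying positive;
# the sigma_min floor keeps the variance bounded away from zero.
def softplus_variance(raw, sigma_min=1e-5):
    return np.maximum(np.log1p(np.exp(raw)), sigma_min)

print(softplus_variance(np.array([-20.0])))  # floored at 1e-5
print(softplus_variance(np.array([0.0])))    # softplus(0) = log(2)
```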
class TDModel(object):
def __init__(self,
p,
q,
r_tilde,
state_size,
num_timesteps,
dtype=tf.float32,
disable_r=False):
self.p = p
self.q = q
self.r_tilde = r_tilde
self.disable_r = disable_r
self.state_size = state_size
self.num_timesteps = num_timesteps
self.dtype = dtype
def zero_state(self, batch_size):
return tf.zeros([batch_size, self.state_size], dtype=self.dtype)
def __call__(self, prev_state, observation, t):
"""Computes the importance weight for the model system.
Args:
prev_state: [batch_size, state_size] Tensor
observation: [batch_size, state_size] Tensor
"""
# Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
q_zt = self.q.q_zt(observation, prev_state, t)
# Compute the p distribution over z, p(z_t|z_{t-1}).
p_zt = self.p.p_zt(prev_state, t)
# sample from q and evaluate the logprobs, summing over the state size
zt = q_zt.sample()
# If it isn't the last timestep, compute the distribution over the next z.
if t < self.num_timesteps - 1:
p_ztplus1 = self.p.p_zt(zt, t+1)
else:
p_ztplus1 = None
log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
if not self.disable_r and t < self.num_timesteps-1:
# score the remaining observations using r
r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t+1)
else:
r_tilde_mu = None
r_tilde_sigma_sq = None
if t == self.num_timesteps - 1:
generative_dist = self.p.generative(observation, zt)
log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1)
else:
log_p_x_given_z = tf.zeros_like(log_q_zt)
return (zt, log_q_zt, log_p_zt, log_p_x_given_z,
r_tilde_mu, r_tilde_sigma_sq, p_ztplus1)
@staticmethod
def create(state_size,
num_timesteps,
sigma_min=1e-5,
variance=1.0,
dtype=tf.float32,
random_seed=None,
train_p=True,
p_type="unimodal",
q_type="normal",
mixing_coeff=0.5,
prior_mode_mean=1.0,
observation_variance=1.0,
transition_type=STANDARD_TRANSITION,
use_bs=True):
if p_type == "unimodal":
p = P(state_size,
num_timesteps,
sigma_min=sigma_min,
variance=variance,
dtype=dtype,
random_seed=random_seed,
trainable=train_p,
init_bs_to_zero=not use_bs)
elif p_type == "bimodal":
p = BimodalPriorP(
state_size,
num_timesteps,
mixing_coeff=mixing_coeff,
prior_mode_mean=prior_mode_mean,
sigma_min=sigma_min,
variance=variance,
dtype=dtype,
random_seed=random_seed,
trainable=train_p,
init_bs_to_zero=not use_bs)
elif "nonlinear" in p_type:
if "cauchy" in p_type:
trans_dist = tf.contrib.distributions.Cauchy
else:
trans_dist = tf.contrib.distributions.Normal
p = ShortChainNonlinearP(
state_size,
num_timesteps,
sigma_min=sigma_min,
variance=variance,
observation_variance=observation_variance,
transition_type=transition_type,
transition_dist=trans_dist,
dtype=dtype,
random_seed=random_seed
)
else:
raise ValueError("Unknown p_type: %s" % p_type)
if q_type == "normal":
q_class = Q
elif q_type == "simple_mean":
q_class = SimpleMeanQ
elif q_type == "prev_state":
q_class = PreviousStateQ
elif q_type == "observation":
q_class = ObservationQ
else:
raise ValueError("Unknown q_type: %s" % q_type)
q = q_class(state_size,
num_timesteps,
sigma_min=sigma_min,
dtype=dtype,
random_seed=random_seed,
init_mu0_to_zero=not use_bs)
r_tilde = RTilde(
state_size,
num_timesteps,
sigma_min=sigma_min,
dtype=dtype,
random_seed=random_seed)
model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
return model
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
model="forward"
T=5
num_obs=1
var=0.1
n=4
lr=0.0001
bound="fivo-aux"
q_type="normal"
resampling_method="multinomial"
rgrad="true"
p_type="unimodal"
use_bs=false
LOGDIR=/tmp/fivo/model-$model-$bound-$resampling_method-resampling-rgrad-$rgrad-T-$T-var-$var-n-$n-lr-$lr-q-$q_type-p-$p_type
python train.py \
--logdir=$LOGDIR \
--model=$model \
--bound=$bound \
--q_type=$q_type \
--p_type=$p_type \
--variance=$var \
--use_resampling_grads=$rgrad \
--resampling=always \
--resampling_method=$resampling_method \
--batch_size=4 \
--num_samples=$n \
--num_timesteps=$T \
--num_eval_samples=256 \
--summarize_every=100 \
--learning_rate=$lr \
--decay_steps=1000000 \
--max_steps=1000000000 \
--random_seed=1234 \
--train_p=false \
--use_bs=$use_bs \
--alsologtostderr
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for plotting and summarizing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf
import models
def summarize_ess(weights, only_last_timestep=False):
"""Plots the effective sample size.
Args:
weights: List of length num_timesteps Tensors of shape
[num_samples, batch_size]
"""
num_timesteps = len(weights)
batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64)
for i in range(num_timesteps):
if only_last_timestep and i < num_timesteps-1: continue
w = tf.nn.softmax(weights[i], dim=0)
centered_weights = w - tf.reduce_mean(w, axis=0, keepdims=True)
variance = tf.reduce_sum(tf.square(centered_weights))/(batch_size-1)
ess = 1./tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0))
tf.summary.scalar("ess/%d" % i, ess)
tf.summary.scalar("ese/%d" % i, ess / batch_size)
tf.summary.scalar("weight_variance/%d" % i, variance)
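The effective-sample-size computation above, ESS = 1 / sum_i w_i^2 for normalized weights w, in a standalone numpy sketch: it equals the number of samples for uniform weights and approaches 1 when one particle carries all the mass.

```python
import numpy as np

# ESS from unnormalized log-weights, with a max-subtraction for stability.
def ess(log_weights):
    w = np.exp(log_weights - np.max(log_weights))
    w = w / w.sum()
    return 1.0 / np.sum(np.square(w))

print(ess(np.zeros(8)))                      # 8.0 (uniform weights)
print(ess(np.array([0.0, -100.0, -100.0])))  # ~1.0 (degenerate weights)
```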
def summarize_particles(states, weights, observation, model):
"""Plots particle locations and weights.
Args:
states: List of length num_timesteps Tensors of shape
[batch_size*num_particles, state_size].
weights: List of length num_timesteps Tensors of shape [num_samples,
batch_size]
observation: Tensor of shape [batch_size*num_samples, state_size]
"""
num_timesteps = len(weights)
num_samples, batch_size = weights[0].get_shape().as_list()
# get q0 information for plotting
q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
q0_loc = q0_dist.loc[0:batch_size, 0]
q0_scale = q0_dist.scale[0:batch_size, 0]
# get posterior information for plotting
post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
tf.reduce_sum(model.p.bs), model.p.num_timesteps)
# Reshape states and weights to be [time, num_samples, batch_size]
states = tf.stack(states)
weights = tf.stack(weights)
# normalize the weights over the sample dimension
weights = tf.nn.softmax(weights, dim=1)
states = tf.reshape(states, tf.shape(weights))
ess = 1./tf.reduce_sum(tf.square(weights), axis=1)
def _plot_states(states_batch, weights_batch, observation_batch, ess_batch, q0, post):
"""
states: [time, num_samples, batch_size]
weights [time, num_samples, batch_size]
observation: [batch_size, 1]
q0: ([batch_size], [batch_size])
post: ...
"""
num_timesteps, _, batch_size = states_batch.shape
plots = []
for i in range(batch_size):
states = states_batch[:,:,i]
weights = weights_batch[:,:,i]
observation = observation_batch[i]
ess = ess_batch[:,i]
q0_loc = q0[0][i]
q0_scale = q0[1][i]
fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
# Each timestep gets two plots -- a bar plot and a histogram of state locs.
# The bar plot will be bar_rows rows tall.
# The histogram will be 1 row tall.
# There is also 1 extra plot at the top showing the posterior and q.
bar_rows = 8
num_rows = (num_timesteps + 1) * (bar_rows + 1)
gs = gridspec.GridSpec(num_rows, 1)
# Figure out how wide to make the plot
prior_lims = (post[1] * -2, post[1] * 2)
q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
state_width = states.max() - states.min()
state_lims = (states.min() - state_width * 0.15,
states.max() + state_width * 0.15)
lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
max(prior_lims[1], q_lims[1], state_lims[1]))
# plot the posterior
z0 = np.arange(lims[0], lims[1], 0.1)
alpha, pos_mu, sigma_sq, B, T = post
neg_mu = -pos_mu
scale = np.sqrt((T + 1) * sigma_sq)
p_zn = (
alpha * scipy.stats.norm.pdf(
observation, loc=pos_mu + B, scale=scale) + (1 - alpha) *
scipy.stats.norm.pdf(observation, loc=neg_mu + B, scale=scale))
p_z0 = (
alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq))
+ (1 - alpha) * scipy.stats.norm.pdf(
z0, loc=neg_mu, scale=np.sqrt(sigma_sq)))
p_zn_given_z0 = scipy.stats.norm.pdf(
observation, loc=z0 + B, scale=np.sqrt(T * sigma_sq))
post_z0 = (p_z0 * p_zn_given_z0) / p_zn
# plot q
q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
ax = plt.subplot(gs[0:bar_rows, :])
ax.plot(z0, q_z0, color="blue")
ax.plot(z0, post_z0, color="green")
ax.plot(z0, p_z0, color="red")
ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
ax.set_xticks([])
ax.set_xlim(*lims)
# plot the states
for t in range(num_timesteps):
start = (t + 1) * (bar_rows + 1)
ax1 = plt.subplot(gs[start:start + bar_rows, :])
ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
# plot the states barplot
# ax1.hist(
# states[t, :],
# weights=weights[t, :],
# bins=50,
# edgecolor="none",
# alpha=0.2)
ax1.bar(states[t,:], weights[t,:], width=0.02, alpha=0.2, edgecolor = "none")
ax1.set_ylabel("t=%d" % t)
ax1.set_xticks([])
ax1.grid(True, which="both")
ax1.set_xlim(*lims)
# plot the observation
ax1.axvline(x=observation, color="red", linestyle="dashed")
# add the ESS
ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t],
ha="center", va="center", transform=ax1.transAxes)
# plot the state location histogram
ax2.hist2d(
states[t, :], np.zeros_like(states[t, :]), bins=[50, 1], cmap="Greys")
ax2.grid(False)
ax2.set_yticks([])
ax2.set_xlim(*lims)
if t != num_timesteps - 1:
ax2.set_xticks([])
fig.canvas.draw()
p = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,)))
plt.close(fig)
return np.stack(plots)
plots = tf.py_func(_plot_states,
[states, weights, observation, ess, (q0_loc, q0_scale), post],
[tf.uint8])[0]
tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
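The closed-form posterior computed inside `_plot_states` (a bimodal Gaussian prior on z_0 combined with a Gaussian random-walk likelihood for the endpoint observation) can be checked in isolation. All numeric values below are hypothetical:

```python
import numpy as np

# Manual normal pdf to keep the sketch dependency-free.
def npdf(x, loc, scale):
    return np.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * np.sqrt(2 * np.pi))

alpha, mu, sigma_sq, B, T = 0.5, 1.0, 1.0, 0.0, 5  # hypothetical model params
x = 0.7                                            # hypothetical observation
z0 = np.linspace(-8.0, 8.0, 4001)

# Two-mode prior on z_0 and random-walk likelihood p(x | z_0).
p_z0 = (alpha * npdf(z0, mu, np.sqrt(sigma_sq))
        + (1 - alpha) * npdf(z0, -mu, np.sqrt(sigma_sq)))
p_x_given_z0 = npdf(x, z0 + B, np.sqrt(T * sigma_sq))

# Bayes' rule on the grid: posterior = prior * likelihood / evidence.
p_x = np.trapz(p_z0 * p_x_given_z0, z0)
post_z0 = p_z0 * p_x_given_z0 / p_x
print(np.trapz(post_z0, z0))  # 1.0: the posterior normalizes on the grid
```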
def plot_weights(weights, resampled=None):
"""Plots the weights and effective sample size from an SMC rollout.
Args:
weights: [num_timesteps, num_samples, batch_size] importance weights
resampled: [num_timesteps] 0/1 indicating if resampling occurred
"""
weights = tf.convert_to_tensor(weights)
def _make_plots(weights, resampled):
num_timesteps, num_samples, batch_size = weights.shape
plots = []
for i in range(batch_size):
fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i]))
axes.set_title("Weights")
axes.set_xlabel("Steps")
axes.set_ylim([0, 1])
axes.set_xlim([0, num_timesteps - 1])
for j in np.where(resampled > 0)[0]:
axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0)
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plots.append(data)
plt.close(fig)
return np.stack(plots, axis=0)
if resampled is None:
num_timesteps, _, batch_size = weights.get_shape().as_list()
resampled = tf.zeros([num_timesteps], dtype=tf.float32)
plots = tf.py_func(_make_plots,
[tf.nn.softmax(weights, dim=1),
tf.to_float(resampled)], [tf.uint8])[0]
batch_size = weights.get_shape().as_list()[-1]
tf.summary.image(
"weights", plots, batch_size, collections=["infrequent_summaries"])
def summarize_weights(weights, num_timesteps, num_samples):
# weights is [num_timesteps, num_samples, batch_size]
weights = tf.convert_to_tensor(weights)
mean = tf.reduce_mean(weights, axis=1, keepdims=True)
squared_diff = tf.square(weights - mean)
variances = tf.reduce_sum(squared_diff, axis=1) / (num_samples - 1)
# average the variance over the batch
variances = tf.reduce_mean(variances, axis=1)
avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
for t in xrange(num_timesteps):
tf.summary.scalar("weights/variance_%d" % t, variances[t])
tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
tf.summary.histogram("weights/step_%d" % t, weights[t])
def summarize_learning_signal(rewards, tag):
num_resampling_events, _ = rewards.get_shape().as_list()
mean = tf.reduce_mean(rewards, axis=1)
avg_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1)
reward_square = tf.reduce_mean(tf.square(rewards), axis=1)
for t in xrange(num_resampling_events):
tf.summary.scalar("%s/mean_%d" % (tag, t), mean[t])
tf.summary.scalar("%s/magnitude_%d" % (tag, t), avg_magnitude[t])
tf.summary.scalar("%s/squared_%d" % (tag, t), reward_square[t])
tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
def summarize_qs(model, observation, states):
model.q.summarize_weights()
if hasattr(model.p, "posterior") and callable(getattr(model.p, "posterior")):
states = [tf.zeros_like(states[0])] + states[:-1]
for t, prev_state in enumerate(states):
p = model.p.posterior(observation, prev_state, t)
q = model.q.q_zt(observation, prev_state, t)
kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(p, q))
tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl))
mean_diff = q.loc - p.loc
mean_abs_err = tf.abs(mean_diff)
mean_rel_err = tf.abs(mean_diff / p.loc)
tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
tf.reduce_mean(mean_abs_err))
tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
tf.reduce_mean(mean_rel_err))
sigma_diff = tf.square(q.scale) - tf.square(p.scale)
sigma_abs_err = tf.abs(sigma_diff)
sigma_rel_err = tf.abs(sigma_diff / tf.square(p.scale))
tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
tf.reduce_mean(sigma_abs_err))
tf.summary.scalar("q_variance_convergence/relative_error_%d" % t,
tf.reduce_mean(sigma_rel_err))
def summarize_rs(model, states):
model.r.summarize_weights()
for t, state in enumerate(states):
true_r = model.p.lookahead(state, t)
r = model.r.r_xn(state, t)
kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(true_r, r))
tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl))
mean_diff = true_r.loc - r.loc
mean_abs_err = tf.abs(mean_diff)
mean_rel_err = tf.abs(mean_diff / true_r.loc)
tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t,
tf.reduce_mean(mean_abs_err))
tf.summary.scalar("r_mean_convergence/relative_error_%d" % t,
tf.reduce_mean(mean_rel_err))
sigma_diff = tf.square(r.scale) - tf.square(true_r.scale)
sigma_abs_err = tf.abs(sigma_diff)
sigma_rel_err = tf.abs(sigma_diff / tf.square(true_r.scale))
tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t,
tf.reduce_mean(sigma_abs_err))
tf.summary.scalar("r_variance_convergence/relative_error_%d" % t,
tf.reduce_mean(sigma_rel_err))
def summarize_model(model, true_bs, observation, states, bound, summarize_r=True):
if hasattr(model.p, "bs"):
model_b = tf.reduce_sum(model.p.bs, axis=0)
true_b = tf.reduce_sum(true_bs, axis=0)
abs_err = tf.abs(model_b - true_b)
rel_err = abs_err / true_b
tf.summary.scalar("sum_of_bs/data_generating_process", tf.reduce_mean(true_b))
tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(model_b))
tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_err))
tf.summary.scalar("sum_of_bs/relative_error", tf.reduce_mean(rel_err))
#summarize_qs(model, observation, states)
#if bound == "fivo-aux" and summarize_r:
# summarize_rs(model, states)
def summarize_grads(grads, loss_name):
grad_ema = tf.train.ExponentialMovingAverage(decay=0.99)
vectorized_grads = tf.concat(
[tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0)
new_second_moments = tf.square(vectorized_grads)
new_first_moments = vectorized_grads
maintain_grad_ema_op = grad_ema.apply([new_first_moments, new_second_moments])
first_moments = grad_ema.average(new_first_moments)
second_moments = grad_ema.average(new_second_moments)
variances = second_moments - tf.square(first_moments)
tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(variances))
tf.summary.histogram("grad_variance/%s" % loss_name, variances)
tf.summary.histogram("grad_mean/%s" % loss_name, first_moments)
return maintain_grad_ema_op
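The EMA-based variance estimate above, Var[g] ≈ E[g^2] - E[g]^2 with both expectations tracked by exponential moving averages, in a small numpy sketch with synthetic gradient samples (the 0.99 decay matches the value used above):

```python
import numpy as np

# Track EMAs of g and g**2; their difference estimates the gradient variance.
decay = 0.99
first_moment, second_moment = 0.0, 0.0
rng = np.random.default_rng(1)
for _ in range(5000):
    g = rng.normal(loc=0.5, scale=2.0)  # fake scalar gradient sample
    first_moment = decay * first_moment + (1 - decay) * g
    second_moment = decay * second_moment + (1 - decay) * g ** 2

variance = second_moment - first_moment ** 2
print(variance)  # roughly 4.0 (the true Var[g] = 2.0**2)
```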
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script for running fivo"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import numpy as np
import tensorflow as tf
import bounds
import data
import models
import summary_utils as summ
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.flags.DEFINE_integer("random_seed", None,
"A random seed for the data generating process. Same seed "
"-> same data generating process and initialization.")
tf.app.flags.DEFINE_enum("bound", "fivo", ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
"The bound to optimize.")
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
"The model to use.")
tf.app.flags.DEFINE_enum("q_type", "normal",
["normal", "simple_mean", "prev_state", "observation"],
"The parameterization to use for q")
tf.app.flags.DEFINE_enum("p_type", "unimodal", ["unimodal", "bimodal", "nonlinear"],
"The type of prior.")
tf.app.flags.DEFINE_boolean("train_p", True,
"If false, do not train the model p.")
tf.app.flags.DEFINE_integer("state_size", 1,
"The dimensionality of the state space.")
tf.app.flags.DEFINE_float("variance", 1.0,
"The variance of the data generating process.")
tf.app.flags.DEFINE_boolean("use_bs", True,
"If False, initialize all bs to 0.")
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
"The weight assigned to the positive mode of the prior in "
"both the data generating process and p.")
tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
"If supplied, sets the mean of the 2 modes of the prior to "
"be 1 and -1 times the supplied value. This is for both the "
"data generating process and p.")
tf.app.flags.DEFINE_float("fixed_observation", None,
"If supplied, fix the observation to a constant value in the"
" data generating process only.")
tf.app.flags.DEFINE_float("r_sigma_init", 1.,
"Value to initialize variance of r to.")
tf.app.flags.DEFINE_enum("observation_type",
models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
"The type of observation for the long chain model.")
tf.app.flags.DEFINE_enum("transition_type",
models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
"The type of transition for the long chain model.")
tf.app.flags.DEFINE_float("observation_variance", None,
"The variance of the observation. Defaults to 'variance'")
tf.app.flags.DEFINE_integer("num_timesteps", 5,
"Number of timesteps in the sequence.")
tf.app.flags.DEFINE_integer("num_observations", 1,
"The number of observations.")
tf.app.flags.DEFINE_integer("steps_per_observation", 5,
"The number of timesteps between each observation.")
tf.app.flags.DEFINE_integer("batch_size", 4,
"The number of examples per batch.")
tf.app.flags.DEFINE_integer("num_samples", 4,
"The number particles to use.")
tf.app.flags.DEFINE_integer("num_eval_samples", 512,
"The batch size and # of particles to use for eval.")
tf.app.flags.DEFINE_string("resampling", "always",
"How to resample. Accepts 'always','never', or a "
"comma-separated list of booleans like 'true,true,false'.")
tf.app.flags.DEFINE_enum("resampling_method", "multinomial", ["multinomial",
"stratified",
"systematic",
"relaxed-logblend",
"relaxed-stateblend",
"relaxed-linearblend",
"relaxed-stateblend-st",],
"Type of resampling method to use.")
tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
"Whether or not to use resampling grads to optimize FIVO."
"Disabled automatically if resampling_method=relaxed.")
tf.app.flags.DEFINE_boolean("disable_r", False,
"If false, r is not used for fivo-aux and is set to zeros.")
tf.app.flags.DEFINE_float("learning_rate", 1e-4,
"The learning rate to use for ADAM or SGD.")
tf.app.flags.DEFINE_integer("decay_steps", 25000,
"The number of steps before the learning rate is halved.")
tf.app.flags.DEFINE_integer("max_steps", int(1e6),
"The number of steps to run training for.")
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
"Directory for summaries and checkpoints.")
tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
"The number of steps between each evaluation.")
FLAGS = tf.app.flags.FLAGS
def combine_grad_lists(grad_lists):
  # grad_lists is num_losses x num_variables, but each inner list may cover a
  # different set of variables. For each variable, sum its grads across all
  # losses.
grads_dict = defaultdict(list)
var_dict = {}
for grad_list in grad_lists:
for grad, var in grad_list:
if grad is not None:
grads_dict[var.name].append(grad)
var_dict[var.name] = var
final_grads = []
  for var_name, var in var_dict.items():
grads = grads_dict[var_name]
if len(grads) > 0:
tf.logging.info("Var %s has combined grads from %s." %
(var_name, [g.name for g in grads]))
grad = tf.reduce_sum(grads, axis=0)
else:
tf.logging.info("Var %s has no grads" % var_name)
grad = None
final_grads.append((grad, var))
return final_grads
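`combine_grad_lists` sums, per variable, the gradients each loss contributes. The bookkeeping can be sketched with plain arrays keyed on made-up variable names (`"w"` and `"b"` are illustrative, not from this codebase):

```python
from collections import defaultdict
import numpy as np

def combine_grads(grad_lists):
    """Sum gradients per variable across several (grad, var_name) lists,
    mirroring combine_grad_lists above but keyed on plain names."""
    grads_dict = defaultdict(list)
    for grad_list in grad_lists:
        for grad, var_name in grad_list:
            if grad is not None:  # skip vars a loss does not touch
                grads_dict[var_name].append(grad)
    return {name: np.sum(grads, axis=0) for name, grads in grads_dict.items()}

# Two losses: both touch "w", only the first touches "b".
loss1 = [(np.array([1.0, 2.0]), "w"), (np.array([0.5]), "b")]
loss2 = [(np.array([3.0, 4.0]), "w"), (None, "b")]
combined = combine_grads([loss1, loss2])
# combined["w"] → [4.0, 6.0]; combined["b"] → [0.5]
```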
def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
for l in losses:
assert isinstance(l, bounds.Loss)
lr = tf.train.exponential_decay(
learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
tf.summary.scalar("learning_rate", lr)
opt = tf.train.AdamOptimizer(lr)
ema_ops = []
grads = []
for loss_name, loss, loss_var_collection in losses:
tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
(loss_name, loss_var_collection))
g = opt.compute_gradients(loss,
var_list=tf.get_collection(loss_var_collection))
ema_ops.append(summ.summarize_grads(g, loss_name))
grads.append(g)
all_grads = combine_grad_lists(grads)
apply_grads_op = opt.apply_gradients(all_grads, global_step=global_step)
# Update the emas after applying the grads.
with tf.control_dependencies([apply_grads_op]):
train_op = tf.group(*ema_ops)
return train_op
def add_check_numerics_ops():
check_op = []
for op in tf.get_default_graph().get_operations():
bad = ["logits/Log", "sample/Reshape", "log_prob/mul",
"log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
"entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"]
if all([x not in op.name for x in bad]):
for output in op.outputs:
if output.dtype in [tf.float16, tf.float32, tf.float64]:
if op._get_control_flow_context() is not None: # pylint: disable=protected-access
          raise ValueError("`tf.add_check_numerics_ops()` is not compatible "
                           "with TensorFlow control flow operations such as "
                           "`tf.cond()` or `tf.while_loop()`.")
message = op.name + ":" + str(output.value_index)
with tf.control_dependencies(check_op):
check_op = [tf.check_numerics(output, message=message)]
return tf.group(*check_op)
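`add_check_numerics_ops` wraps every floating-point output in `tf.check_numerics`, which fails the step when a tensor contains NaN or Inf. The guard itself is easy to sketch in NumPy (this standalone `check_numerics` helper is illustrative, not the TF op):

```python
import numpy as np

def check_numerics(value, message):
    """Raise if value contains NaN or Inf, loosely mirroring tf.check_numerics."""
    arr = np.asarray(value, dtype=np.float64)
    if not np.all(np.isfinite(arr)):
        raise ValueError("%s: tensor had NaN or Inf values" % message)
    return arr

check_numerics([1.0, 2.0], "ok")  # finite values pass through unchanged
try:
    check_numerics([1.0, float("inf")], "bad_op:0")
except ValueError as e:
    print(e)  # bad_op:0: tensor had NaN or Inf values
```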
def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
batch_size, num_samples, num_eval_samples,
resampling_schedule, use_resampling_grads,
learning_rate, lr_decay_steps, dtype="float64"):
num_timesteps = num_obs * steps_per_obs + 1
# Make the dataset.
dataset = data.make_long_chain_dataset(
state_size=state_size,
num_obs=num_obs,
steps_per_obs=steps_per_obs,
batch_size=batch_size,
num_samples=num_samples,
variance=FLAGS.variance,
observation_variance=FLAGS.observation_variance,
dtype=dtype,
observation_type=FLAGS.observation_type,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation)
itr = dataset.make_one_shot_iterator()
_, observations = itr.get_next()
# Make the dataset for eval
eval_dataset = data.make_long_chain_dataset(
state_size=state_size,
num_obs=num_obs,
steps_per_obs=steps_per_obs,
batch_size=batch_size,
num_samples=num_eval_samples,
variance=FLAGS.variance,
observation_variance=FLAGS.observation_variance,
dtype=dtype,
observation_type=FLAGS.observation_type,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation)
eval_itr = eval_dataset.make_one_shot_iterator()
_, eval_observations = eval_itr.get_next()
# Make the model.
model = models.LongChainModel.create(
state_size,
num_obs,
steps_per_obs,
observation_type=FLAGS.observation_type,
transition_type=FLAGS.transition_type,
variance=FLAGS.variance,
observation_variance=FLAGS.observation_variance,
dtype=tf.as_dtype(dtype),
disable_r=FLAGS.disable_r)
# Compute the bound and loss
if bound == "iwae":
(_, losses, ema_op, _, _) = bounds.iwae(
model,
observations,
num_timesteps,
num_samples=num_samples)
(eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
model,
eval_observations,
num_timesteps,
num_samples=num_eval_samples,
summarize=False)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
elif bound == "fivo" or "fivo-aux":
(_, losses, ema_op, _, _) = bounds.fivo(
model,
observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=use_resampling_grads,
resampling_type=FLAGS.resampling_method,
aux=("aux" in bound),
num_samples=num_samples)
(eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
model,
eval_observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=False,
resampling_type="multinomial",
aux=("aux" in bound),
num_samples=num_eval_samples,
summarize=False)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
summ.summarize_ess(eval_log_weights, only_last_timestep=True)
tf.summary.scalar("log_p_hat", eval_log_p_hat)
# Compute and apply grads.
global_step = tf.train.get_or_create_global_step()
apply_grads = make_apply_grads_op(losses,
global_step,
learning_rate,
lr_decay_steps)
# Update the emas after applying the grads.
with tf.control_dependencies([apply_grads]):
train_op = tf.group(ema_op)
# We can't calculate the likelihood for most of these models
# so we just return zeros.
eval_likelihood = tf.zeros([], dtype=dtype)
return global_step, train_op, eval_log_p_hat, eval_likelihood
def create_graph(bound, state_size, num_timesteps, batch_size,
num_samples, num_eval_samples, resampling_schedule,
use_resampling_grads, learning_rate, lr_decay_steps,
train_p, dtype='float64'):
if FLAGS.use_bs:
true_bs = None
else:
    true_bs = [np.zeros([state_size]).astype(dtype)
               for _ in range(num_timesteps)]
# Make the dataset.
true_bs, dataset = data.make_dataset(
bs=true_bs,
state_size=state_size,
num_timesteps=num_timesteps,
batch_size=batch_size,
num_samples=num_samples,
variance=FLAGS.variance,
prior_type=FLAGS.p_type,
bimodal_prior_weight=FLAGS.bimodal_prior_weight,
bimodal_prior_mean=FLAGS.bimodal_prior_mean,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation,
dtype=dtype)
itr = dataset.make_one_shot_iterator()
_, observations = itr.get_next()
# Make the dataset for eval
_, eval_dataset = data.make_dataset(
bs=true_bs,
state_size=state_size,
num_timesteps=num_timesteps,
batch_size=num_eval_samples,
num_samples=num_eval_samples,
variance=FLAGS.variance,
prior_type=FLAGS.p_type,
bimodal_prior_weight=FLAGS.bimodal_prior_weight,
bimodal_prior_mean=FLAGS.bimodal_prior_mean,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation,
dtype=dtype)
eval_itr = eval_dataset.make_one_shot_iterator()
_, eval_observations = eval_itr.get_next()
# Make the model.
if bound == "fivo-aux-td":
model = models.TDModel.create(
state_size,
num_timesteps,
variance=FLAGS.variance,
train_p=train_p,
p_type=FLAGS.p_type,
q_type=FLAGS.q_type,
mixing_coeff=FLAGS.bimodal_prior_weight,
prior_mode_mean=FLAGS.bimodal_prior_mean,
observation_variance=FLAGS.observation_variance,
transition_type=FLAGS.transition_type,
use_bs=FLAGS.use_bs,
dtype=tf.as_dtype(dtype),
random_seed=FLAGS.random_seed)
else:
model = models.Model.create(
state_size,
num_timesteps,
variance=FLAGS.variance,
train_p=train_p,
p_type=FLAGS.p_type,
q_type=FLAGS.q_type,
mixing_coeff=FLAGS.bimodal_prior_weight,
prior_mode_mean=FLAGS.bimodal_prior_mean,
observation_variance=FLAGS.observation_variance,
transition_type=FLAGS.transition_type,
use_bs=FLAGS.use_bs,
r_sigma_init=FLAGS.r_sigma_init,
dtype=tf.as_dtype(dtype),
random_seed=FLAGS.random_seed)
# Compute the bound and loss
if bound == "iwae":
(_, losses, ema_op, _, _) = bounds.iwae(
model,
observations,
num_timesteps,
num_samples=num_samples)
(eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
model,
eval_observations,
num_timesteps,
num_samples=num_eval_samples,
summarize=True)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
elif "fivo" in bound:
if bound == "fivo-aux-td":
(_, losses, ema_op, _, _) = bounds.fivo_aux_td(
model,
observations,
num_timesteps,
resampling_schedule=resampling_schedule,
num_samples=num_samples)
(eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
model,
eval_observations,
num_timesteps,
resampling_schedule=resampling_schedule,
num_samples=num_eval_samples,
summarize=True)
else:
(_, losses, ema_op, _, _) = bounds.fivo(
model,
observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=use_resampling_grads,
resampling_type=FLAGS.resampling_method,
aux=("aux" in bound),
num_samples=num_samples)
(eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
model,
eval_observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=False,
resampling_type="multinomial",
aux=("aux" in bound),
num_samples=num_eval_samples,
summarize=True)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
summ.summarize_ess(eval_log_weights, only_last_timestep=True)
# if FLAGS.p_type == "bimodal":
# # create the observations that showcase the model.
# mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
# dtype=tf.float64)
# mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
# k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
# explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
# explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
# # run the model on the explainable observations
# if bound == "iwae":
# (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
# model,
# explain_obs,
# num_timesteps,
# num_samples=num_eval_samples)
# elif bound == "fivo" or "fivo-aux":
# (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
# model,
# explain_obs,
# num_timesteps,
# resampling_schedule=resampling_schedule,
# use_resampling_grads=False,
# resampling_type="multinomial",
# aux=("aux" in bound),
# num_samples=num_eval_samples)
# summ.summarize_particles(explain_states,
# explain_log_weights,
# explain_obs,
# model)
# Calculate the true likelihood.
if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
    eval_likelihood = model.p.likelihood(eval_observations) / FLAGS.num_timesteps
else:
eval_likelihood = tf.zeros_like(eval_log_p_hat)
tf.summary.scalar("log_p_hat", eval_log_p_hat)
tf.summary.scalar("likelihood", eval_likelihood)
tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
summarize_r=not bound == "fivo-aux-td")
# Compute and apply grads.
global_step = tf.train.get_or_create_global_step()
apply_grads = make_apply_grads_op(losses,
global_step,
learning_rate,
lr_decay_steps)
# Update the emas after applying the grads.
with tf.control_dependencies([apply_grads]):
train_op = tf.group(ema_op)
#train_op = tf.group(ema_op, add_check_numerics_ops())
return global_step, train_op, eval_log_p_hat, eval_likelihood
def parse_resampling_schedule(schedule, num_timesteps):
schedule = schedule.strip().lower()
if schedule == "always":
return [True] * (num_timesteps - 1) + [False]
elif schedule == "never":
return [False] * num_timesteps
elif "every" in schedule:
n = int(schedule.split("_")[1])
    return [(i + 1) % n == 0 for i in range(num_timesteps)]
else:
sched = [x.strip() == "true" for x in schedule.split(",")]
assert len(
sched
) == num_timesteps, "Wrong number of timesteps in resampling schedule."
return sched
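`parse_resampling_schedule` above turns the `--resampling` flag into a per-timestep boolean list. A Python 3 copy with worked examples (note that 'always' still skips resampling on the final timestep):

```python
def parse_resampling_schedule(schedule, num_timesteps):
    """Map a schedule string to a per-timestep list of resampling decisions."""
    schedule = schedule.strip().lower()
    if schedule == "always":
        # Resample at every step except the last.
        return [True] * (num_timesteps - 1) + [False]
    elif schedule == "never":
        return [False] * num_timesteps
    elif "every" in schedule:
        # e.g. "every_2" resamples on timesteps 2, 4, ...
        n = int(schedule.split("_")[1])
        return [(i + 1) % n == 0 for i in range(num_timesteps)]
    else:
        sched = [x.strip() == "true" for x in schedule.split(",")]
        assert len(sched) == num_timesteps, \
            "Wrong number of timesteps in resampling schedule."
        return sched

# parse_resampling_schedule("always", 3)      → [True, True, False]
# parse_resampling_schedule("every_2", 4)     → [False, True, False, True]
# parse_resampling_schedule("true,false", 2)  → [True, False]
```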
def create_log_hook(step, eval_log_p_hat, eval_likelihood):
def summ_formatter(d):
return ("Step {step}, log p_hat: {log_p_hat:.5f} likelihood: {likelihood:.5f}".format(**d))
hook = tf.train.LoggingTensorHook(
{
"step": step,
"log_p_hat": eval_log_p_hat,
"likelihood": eval_likelihood,
},
every_n_iter=FLAGS.summarize_every,
formatter=summ_formatter)
return hook
def create_infrequent_summary_hook():
infrequent_summary_hook = tf.train.SummarySaverHook(
save_steps=10000,
output_dir=FLAGS.logdir,
summary_op=tf.summary.merge_all(key="infrequent_summaries")
)
return infrequent_summary_hook
def main(unused_argv):
if FLAGS.model == "long_chain":
resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
FLAGS.num_timesteps + 1)
else:
resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
FLAGS.num_timesteps)
if FLAGS.random_seed is None:
seed = np.random.randint(0, high=10000)
else:
seed = FLAGS.random_seed
tf.logging.info("Using random seed %d", seed)
if FLAGS.model == "long_chain":
assert FLAGS.q_type == "normal", "Q type %s not supported for long chain models" % FLAGS.q_type
assert FLAGS.p_type == "unimodal", "Bimodal priors are not supported for long chain models"
assert not FLAGS.use_bs, "Bs are not supported with long chain models"
assert FLAGS.num_timesteps == FLAGS.num_observations * FLAGS.steps_per_observation, "Num timesteps does not match."
assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with long chain models."
if FLAGS.model == "forward":
if "nonlinear" not in FLAGS.p_type:
assert FLAGS.transition_type == models.STANDARD_TRANSITION, "Non-standard transitions not supported by the forward model."
assert FLAGS.observation_type == models.STANDARD_OBSERVATION, "Non-standard observations not supported by the forward model."
assert FLAGS.observation_variance is None, "Forward model does not support observation variance."
assert FLAGS.num_observations == 1, "Forward model only supports 1 observation."
if "relaxed" in FLAGS.resampling_method:
FLAGS.use_resampling_grads = False
assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with relaxed resampling."
if FLAGS.observation_variance is None:
FLAGS.observation_variance = FLAGS.variance
if FLAGS.p_type == "bimodal":
assert FLAGS.bimodal_prior_mean is not None, "Must specify prior mean if using bimodal p."
if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
assert not FLAGS.use_bs, "Using bs is not compatible with the nonlinear model."
g = tf.Graph()
with g.as_default():
# Set the seeds.
tf.set_random_seed(seed)
np.random.seed(seed)
if FLAGS.model == "long_chain":
(global_step, train_op, eval_log_p_hat,
eval_likelihood) = create_long_chain_graph(
FLAGS.bound,
FLAGS.state_size,
FLAGS.num_observations,
FLAGS.steps_per_observation,
FLAGS.batch_size,
FLAGS.num_samples,
FLAGS.num_eval_samples,
resampling_schedule,
FLAGS.use_resampling_grads,
FLAGS.learning_rate,
FLAGS.decay_steps)
else:
(global_step, train_op,
eval_log_p_hat, eval_likelihood) = create_graph(
FLAGS.bound,
FLAGS.state_size,
FLAGS.num_timesteps,
FLAGS.batch_size,
FLAGS.num_samples,
FLAGS.num_eval_samples,
resampling_schedule,
FLAGS.use_resampling_grads,
FLAGS.learning_rate,
FLAGS.decay_steps,
FLAGS.train_p)
log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
if len(tf.get_collection("infrequent_summaries")) > 0:
log_hooks.append(create_infrequent_summary_hook())
tf.logging.info("trainable variables:")
tf.logging.info([v.name for v in tf.trainable_variables()])
tf.logging.info("p vars:")
tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
tf.logging.info("q vars:")
tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
tf.logging.info("r vars:")
tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
tf.logging.info("r tilde vars:")
tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])
with tf.train.MonitoredTrainingSession(
master="",
is_chief=True,
hooks=log_hooks,
checkpoint_dir=FLAGS.logdir,
save_checkpoint_secs=120,
save_summaries_steps=FLAGS.summarize_every,
log_step_count_steps=FLAGS.summarize_every) as sess:
cur_step = -1
while True:
if sess.should_stop() or cur_step > FLAGS.max_steps:
break
# run a step
_, cur_step = sess.run([train_op, global_step])
if __name__ == "__main__":
tf.app.run(main)
# Copyright 2017 The TensorFlow Authors All Rights Reserved. # Copyright 2018 The TensorFlow Authors All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -23,13 +23,15 @@ from __future__ import absolute_import ...@@ -23,13 +23,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
import tensorflow as tf import tensorflow as tf
import nested_utils as nested from fivo import nested_utils as nested
from fivo import smc
def iwae(cell, def iwae(model,
inputs, observations,
seq_lengths, seq_lengths,
num_samples=1, num_samples=1,
parallel_iterations=30, parallel_iterations=30,
...@@ -45,13 +47,13 @@ def iwae(cell, ...@@ -45,13 +47,13 @@ def iwae(cell,
When num_samples = 1, this bound becomes the evidence lower bound (ELBO). When num_samples = 1, this bound becomes the evidence lower bound (ELBO).
Args: Args:
cell: A callable that implements one timestep of the model. See model: A subclass of ELBOTrainableSequenceModel that implements one
models/vrnn.py for an example. timestep of the model. See models/vrnn.py for an example.
inputs: The inputs to the model. A potentially nested list or tuple of observations: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively. At each dimensions, which represent time and the batch respectively. The model
timestep 'cell' will be called with a slice of the Tensors in inputs. will be provided with the observations before computing the bound.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length). sequence in the batch (sequences can be padded to a common length).
num_samples: The number of samples to use. num_samples: The number of samples to use.
...@@ -63,98 +65,28 @@ def iwae(cell, ...@@ -63,98 +65,28 @@ def iwae(cell,
Returns: Returns:
log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of the log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of the
log marginal probability of the observations. log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the kl divergence
from q(z|x) to p(z), averaged over samples.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples] log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep. Will not be valid for containing the log weights at each timestep. Will not be valid for
timesteps past the end of a sequence. timesteps past the end of a sequence.
log_ess: A Tensor of shape [max_seq_len, batch_size] containing the log
effective sample size at each timestep. Will not be valid for timesteps
past the end of a sequence.
""" """
batch_size = tf.shape(seq_lengths)[0] log_p_hat, log_weights, _, final_state = fivo(
max_seq_len = tf.reduce_max(seq_lengths) model,
seq_mask = tf.transpose( observations,
tf.sequence_mask(seq_lengths, maxlen=max_seq_len, dtype=tf.float32), seq_lengths,
perm=[1, 0]) num_samples=num_samples,
if num_samples > 1: resampling_criterion=smc.never_resample_criterion,
inputs, seq_mask = nested.tile_tensors([inputs, seq_mask], [1, num_samples])
inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask], max_seq_len)
t0 = tf.constant(0, tf.int32)
init_states = cell.zero_state(batch_size * num_samples, tf.float32)
ta_names = ['log_weights', 'log_ess']
tas = [tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % n)
for n in ta_names]
log_weights_acc = tf.zeros([num_samples, batch_size], dtype=tf.float32)
kl_acc = tf.zeros([num_samples * batch_size], dtype=tf.float32)
accs = (log_weights_acc, kl_acc)
def while_predicate(t, *unused_args):
return t < max_seq_len
def while_step(t, rnn_state, tas, accs):
"""Implements one timestep of IWAE computation."""
log_weights_acc, kl_acc = accs
cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
# Run the cell for one step.
log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
cur_inputs,
rnn_state,
cur_mask,
)
# Compute the incremental weight and use it to update the current
# accumulated weight.
kl_acc += kl * cur_mask
log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
log_weights_acc += log_alpha
# Calculate the effective sample size.
ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
log_ess = ess_num - ess_denom
# Update the Tensorarrays and accumulators.
ta_updates = [log_weights_acc, log_ess]
new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
new_accs = (log_weights_acc, kl_acc)
return t + 1, new_state, new_tas, new_accs
_, _, tas, accs = tf.while_loop(
while_predicate,
while_step,
loop_vars=(t0, init_states, tas, accs),
parallel_iterations=parallel_iterations, parallel_iterations=parallel_iterations,
swap_memory=swap_memory) swap_memory=swap_memory)
return log_p_hat, log_weights, final_state
log_weights, log_ess = [x.stack() for x in tas]
final_log_weights, kl = accs
log_p_hat = (tf.reduce_logsumexp(final_log_weights, axis=0) -
tf.log(tf.to_float(num_samples)))
kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
return log_p_hat, kl, log_weights, log_ess
def ess_criterion(num_samples, log_ess, unused_t):
"""A criterion that resamples based on effective sample size."""
return log_ess <= tf.log(num_samples / 2.0)
def fivo(model,
def never_resample_criterion(unused_num_samples, log_ess, unused_t): observations,
"""A criterion that never resamples."""
return tf.cast(tf.zeros_like(log_ess), tf.bool)
def always_resample_criterion(unused_num_samples, log_ess, unused_t):
"""A criterion resamples at every timestep."""
return tf.cast(tf.ones_like(log_ess), tf.bool)
def fivo(cell,
inputs,
seq_lengths, seq_lengths,
num_samples=1, num_samples=1,
resampling_criterion=ess_criterion, resampling_criterion=smc.ess_criterion,
resampling_type='multinomial',
relaxed_resampling_temperature=0.5,
parallel_iterations=30, parallel_iterations=30,
swap_memory=True, swap_memory=True,
random_seed=None): random_seed=None):
...@@ -170,21 +102,26 @@ def fivo(cell, ...@@ -170,21 +102,26 @@ def fivo(cell,
When the resampling criterion is "never resample", this bound becomes IWAE. When the resampling criterion is "never resample", this bound becomes IWAE.
Args: Args:
cell: A callable that implements one timestep of the model. See model: A subclass of ELBOTrainableSequenceModel that implements one
models/vrnn.py for an example. timestep of the model. See models/vrnn.py for an example.
inputs: The inputs to the model. A potentially nested list or tuple of observations: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively. At each dimensions, which represent time and the batch respectively. The model
timestep 'cell' will be called with a slice of the Tensors in inputs. will be provided with the observations before computing the bound.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length). sequence in the batch (sequences can be padded to a common length).
num_samples: The number of particles to use in each particle filter. num_samples: The number of particles to use in each particle filter.
resampling_criterion: The resampling criterion to use for this particle resampling_criterion: The resampling criterion to use for this particle
filter. Must accept the number of samples, the effective sample size, filter. Must accept the number of samples, the current log weights,
and the current timestep and return a boolean Tensor of shape [batch_size] and the current timestep and return a boolean Tensor of shape [batch_size]
indicating whether each particle filter should resample. See indicating whether each particle filter should resample. See
ess_criterion and related functions defined in this file for examples. ess_criterion and related functions for examples. When
resampling_criterion is never_resample_criterion, resampling_fn is ignored
and never called.
resampling_type: The type of resampling, one of "multinomial" or "relaxed".
relaxed_resampling_temperature: A positive temperature only used for relaxed
resampling.
parallel_iterations: The number of parallel iterations to use for the parallel_iterations: The number of parallel iterations to use for the
internal while loop. Note that values greater than 1 can introduce internal while loop. Note that values greater than 1 can introduce
non-determinism even when random_seed is provided. non-determinism even when random_seed is provided.
...@@ -196,28 +133,17 @@ def fivo(cell, ...@@ -196,28 +133,17 @@ def fivo(cell,
Returns: Returns:
log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
log marginal probability of the observations. log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the sum over time of the kl
divergence from q_t(z_t|x) to p_t(z_t), averaged over particles. Note that
this includes kl terms from trajectories that are culled during resampling
steps.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples] log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep of the particle filter. Note containing the log weights at each timestep of the particle filter. Note
      that on timesteps when a resampling operation is performed the log weights
      are reset to 0. Will not be valid for timesteps past the end of a
      sequence.
    resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
      particle filters resampled. Will be 1.0 on timesteps when resampling
      occurred and 0.0 on timesteps when it did not.
  """
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    if prev_state is None:
      return model.zero_state(batch_size * num_samples, tf.float32)
    return model.propose_and_weight(prev_state, t)

  log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  return log_p_hat, log_weights, resampled, final_state
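The flattened batch layout described in the comments above determines how resampling gathers ancestors: particle j of particle filter i sits at row j * batch_size + i, so per-filter ancestor indices must be scaled by batch_size and offset by the filter index before gathering from the flat batch. A small NumPy sketch of that index arithmetic (toy sizes and hand-picked indices, not the library code):

```python
import numpy as np

# Toy layout: num_samples particles per filter, batch_size filters,
# flattened so particle j of filter i lives at row j * batch_size + i.
batch_size, num_samples = 3, 2
rows = [(j, i) for j in range(num_samples) for i in range(batch_size)]

# Resampling draws, for each filter, an ancestor particle index in
# [0, num_samples). Scale by batch_size and add the filter's offset to
# turn those into indices into the flat batch.
ancestor_inds = np.array([[1, 0, 1],   # ancestor chosen per filter, slot j=0
                          [0, 0, 1]])  # ancestor chosen per filter, slot j=1
offset = np.arange(batch_size)
flat_inds = (ancestor_inds * batch_size + offset).reshape(-1)
# flat_inds -> [3, 1, 5, 0, 1, 5]

# Gathering with these indices never mixes particles across filters.
for flat, (j, i) in zip(flat_inds, rows):
    assert rows[flat][1] == i
```

The same trick appears wherever the code gathers model state after resampling: because the flat layout interleaves filters, a plain `tf.gather` with the scaled-and-offset indices keeps every filter's particles within that filter.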
def fivo_aux_td(
    model,
    observations,
    seq_lengths,
    num_samples=1,
    resampling_criterion=smc.ess_criterion,
    resampling_type='multinomial',
    relaxed_resampling_temperature=0.5,
    parallel_iterations=30,
    swap_memory=True,
    random_seed=None):
  """Experimental."""
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  max_seq_len = tf.reduce_max(seq_lengths)

  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    if prev_state is None:
      model_init_state = model.zero_state(batch_size * num_samples, tf.float32)
      return (tf.zeros([num_samples*batch_size], dtype=tf.float32),
              (tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32),
               tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32)),
              model_init_state)

    prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
    (new_model_state, zt, log_q_zt, log_p_zt,
     log_p_x_given_z, log_r_tilde, p_ztplus1) = model(prev_model_state, t)
    r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
    # Compute the weight without r.
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    # Compute log_r and log_r_tilde.
    p_mu = tf.stop_gradient(p_ztplus1.mean())
    p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
    log_r = (tf.log(r_tilde_sigma_sq) -
             tf.log(r_tilde_sigma_sq + p_sigma_sq) -
             tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
    # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
    # dimension to compute log r.
    log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
    # Compute prev log r tilde.
    prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
    prev_log_r_tilde = -0.5*tf.reduce_sum(
        tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu)/prev_r_tilde_sigma_sq,
        axis=-1)
    # If the sequence is on its last timestep, log_r and log_r_tilde are just
    # zeros.
    last_timestep = t >= (tiled_seq_lengths - 1)
    log_r = tf.where(last_timestep,
                     tf.zeros_like(log_r),
                     log_r)
    prev_log_r_tilde = tf.where(last_timestep,
                                tf.zeros_like(prev_log_r_tilde),
                                prev_log_r_tilde)
    log_weight += tf.stop_gradient(log_r - prev_log_r)
    new_state = (log_r, log_r_tilde, new_model_state)
    loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z, log_r - prev_log_r)
    return log_weight, new_state, loop_fn_args

  def loop_fn(loop_state, loop_args, unused_model_state, log_weights, resampled,
              mask, t):
    if loop_state is None:
      return (tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([num_samples, batch_size], dtype=tf.float32))
    log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
    log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
    # Compute the log_p_hat update.
    log_p_hat_update = tf.reduce_logsumexp(
        log_weights, axis=0) - tf.log(tf.to_float(num_samples))
    # If it is the last timestep, we always add the update.
    log_p_hat_acc += tf.cond(t >= max_seq_len-1,
                             lambda: log_p_hat_update,
                             lambda: log_p_hat_update * resampled)
    # Compute the Bellman update.
    log_r = tf.reshape(log_r, [num_samples, batch_size])
    prev_log_r_tilde = tf.reshape(prev_log_r_tilde, [num_samples, batch_size])
    log_p_x_given_z = tf.reshape(log_p_x_given_z, [num_samples, batch_size])
    mask = tf.reshape(mask, [num_samples, batch_size])
    # On the first timestep there is no bellman error because there is no
    # prev_log_r_tilde.
    mask = tf.cond(tf.equal(t, 0),
                   lambda: tf.zeros_like(mask),
                   lambda: mask)
    # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
    prev_log_r_tilde = tf.where(
        tf.is_inf(prev_log_r_tilde),
        tf.zeros_like(prev_log_r_tilde),
        prev_log_r_tilde)
    # log_lambda is [1, batch_size] because of the mean over the sample axis.
    log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
                                axis=0, keepdims=True)
    bellman_error = mask * tf.square(
        prev_log_r_tilde -
        tf.stop_gradient(log_lambda + log_p_x_given_z + log_r)
    )
    bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
    # Compute the log_r_diff update.
    log_r_diff_acc += mask * tf.reshape(log_r_diff, [num_samples, batch_size])
    return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)

  log_weights, resampled, accs = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      loop_fn=loop_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)
  log_p_hat, bellman_loss, log_r_diff = accs
  loss_per_seq = [- log_p_hat, bellman_loss]
  tf.summary.scalar("bellman_loss",
                    tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
  tf.summary.scalar("log_r_diff",
                    tf.reduce_mean(tf.reduce_mean(log_r_diff, axis=0) /
                                   tf.to_float(seq_lengths)))
  return loss_per_seq, log_p_hat, log_weights, resampled
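Both bounds default to `resampling_criterion=smc.ess_criterion`, which triggers a resampling step when a filter's effective sample size collapses. A minimal NumPy sketch of the standard log-space ESS identity, log ESS = 2 logsumexp(w) - logsumexp(2w), that such a criterion checks (toy weights; the `logsumexp` helper here stands in for `tf.reduce_logsumexp`):

```python
import numpy as np

def logsumexp(x, axis=0):
    # Numerically stable log-sum-exp over one axis.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.sum(np.exp(x - m), axis=axis))

# Uniform weights: ESS equals the number of particles.
log_weights = np.log(np.array([0.25, 0.25, 0.25, 0.25]))
log_ess = 2 * logsumexp(log_weights) - logsumexp(2 * log_weights)
assert np.isclose(np.exp(log_ess), 4.0)

# Degenerate weights: ESS collapses toward 1, so the filter should resample.
degenerate = np.log(np.array([0.97, 0.01, 0.01, 0.01]))
assert np.exp(2 * logsumexp(degenerate) - logsumexp(2 * degenerate)) < 2.0
```

Because the identity is computed entirely in log space, it works directly on the unnormalized accumulated log weights the filters maintain, without ever exponentiating them.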
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for fivo.bounds"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from fivo.test_utils import create_vrnn
from fivo import bounds
class BoundsTest(tf.test.TestCase):
def test_elbo(self):
"""A golden-value test for the ELBO (the IWAE bound with num_samples=1)."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=1,
parallel_iterations=1)
sess.run(tf.global_variables_initializer())
log_p_hat, _, _ = sess.run(outs)
self.assertAllClose([-21.615765, -13.614225], log_p_hat)
def test_iwae(self):
"""A golden-value test for the IWAE bound."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.iwae(model, (inputs, targets), lengths, num_samples=4,
parallel_iterations=1)
sess.run(tf.global_variables_initializer())
log_p_hat, weights, _ = sess.run(outs)
self.assertAllClose([-23.301426, -13.64028], log_p_hat)
weights_gt = np.array(
[[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
[-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
[[-6.2539978, -4.37615728, -7.43738699, -7.85044909],
[-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
[[-9.19093227, -8.01637268, -11.64603615, -10.51128292],
[-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
[[-12.20609856, -10.47217369, -13.66270638, -13.46115875],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
[[-16.14766312, -15.57472229, -17.47755432, -17.98189926],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
[[-20.07182884, -18.43191147, -20.1606636, -21.45263863],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]],
[[-24.10270691, -22.20865822, -24.14675522, -25.27248383],
[-17.17656708, -16.25190353, -15.28658581, -12.33067703]]])
self.assertAllClose(weights_gt, weights)
def test_fivo(self):
"""A golden-value test for the FIVO bound."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
random_seed=1234, parallel_iterations=1)
sess.run(tf.global_variables_initializer())
log_p_hat, weights, resampled, _ = sess.run(outs)
self.assertAllClose([-22.98902512, -14.21689224], log_p_hat)
weights_gt = np.array(
[[[-3.66708851, -2.07074022, -4.91751671, -5.03293562],
[-2.99690723, -3.17782736, -4.50084877, -3.48536515]],
[[-2.67100811, -2.30541706, -2.34178066, -2.81751347],
[-8.27518654, -6.71545124, -8.96198845, -7.05567837]],
[[-5.65190411, -5.94563246, -6.55041981, -5.4783473],
[-12.34527206, -11.54284477, -11.8667469, -9.69417381]],
[[-8.71947861, -8.40143299, -8.54593086, -8.42822266],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
[[-12.7003831, -13.5039815, -12.3569726, -12.9489622],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
[[-16.4520301, -16.3611698, -15.0314846, -16.4197006],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]],
[[-20.7010765, -20.1379165, -19.0020351, -20.2395458],
[-4.28782988, -4.50591278, -3.40847206, -2.63650274]]])
self.assertAllClose(weights_gt, weights)
resampled_gt = np.array(
[[1., 0.],
[0., 0.],
[0., 1.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]])
self.assertAllClose(resampled_gt, resampled)
def test_fivo_relaxed(self):
"""A golden-value test for the FIVO bound with relaxed sampling."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234)
outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
random_seed=1234, parallel_iterations=1,
resampling_type="relaxed")
sess.run(tf.global_variables_initializer())
log_p_hat, weights, resampled, _ = sess.run(outs)
self.assertAllClose([-22.942394, -14.273882], log_p_hat)
weights_gt = np.array(
[[[-3.66708851, -2.07074118, -4.91751575, -5.03293514],
[-2.99690628, -3.17782831, -4.50084877, -3.48536515]],
[[-2.84939098, -2.30087185, -2.35649204, -2.48417377],
[-8.27518654, -6.71545172, -8.96199131, -7.05567837]],
[[-5.92327023, -5.9433074, -6.5826683, -5.04259014],
[-12.34527206, -11.54284668, -11.86675072, -9.69417477]],
[[-8.95323944, -8.40061855, -8.52760506, -7.99130583],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
[[-12.87836456, -13.49628639, -12.31680107, -12.74228859],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
[[-16.78347397, -16.35150909, -14.98797417, -16.35162735],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]],
[[-20.81165886, -20.1307621, -18.92229652, -20.17458153],
[-4.58102798, -4.56017351, -3.46283388, -2.65550804]]])
self.assertAllClose(weights_gt, weights)
resampled_gt = np.array(
[[1., 0.],
[0., 0.],
[0., 1.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]])
self.assertAllClose(resampled_gt, resampled)
def test_fivo_aux_relaxed(self):
"""A golden-value test for the FIVO-AUX bound with relaxed sampling."""
tf.set_random_seed(1234)
with self.test_session() as sess:
model, inputs, targets, lengths = create_vrnn(random_seed=1234,
use_tilt=True)
outs = bounds.fivo(model, (inputs, targets), lengths, num_samples=4,
random_seed=1234, parallel_iterations=1,
resampling_type="relaxed")
sess.run(tf.global_variables_initializer())
log_p_hat, weights, resampled, _ = sess.run(outs)
self.assertAllClose([-23.1395, -14.271059], log_p_hat)
weights_gt = np.array(
[[[-5.19826221, -3.55476403, -5.98663855, -6.08058834],
[-6.31685925, -5.70243931, -7.07638931, -6.18138981]],
[[-3.97986865, -3.58831525, -3.85753584, -3.5010016],
[-11.38203049, -8.66213989, -11.23646641, -10.02024746]],
[[-6.62269831, -6.36680222, -6.78096485, -5.80072498],
[-3.55419445, -8.11326408, -3.48766923, -3.08593249]],
[[-10.56472301, -10.16084099, -9.96741676, -8.5270071],
[-6.04880285, -7.80853653, -4.72652149, -3.49711013]],
[[-13.36585426, -16.08720398, -13.33416367, -13.1017189],
[-0., -0., -0., -0.]],
[[-17.54233551, -17.35167503, -16.79163361, -16.51471138],
[0., -0., -0., -0.]],
[[-19.74024963, -18.69452858, -17.76246452, -18.76182365],
[0., -0., -0., -0.]]])
self.assertAllClose(weights_gt, weights)
resampled_gt = np.array([[1., 0.],
[0., 1.],
[0., 0.],
[0., 1.],
[0., 0.],
[0., 0.],
[0., 0.]])
self.assertAllClose(resampled_gt, resampled)
if __name__ == "__main__":
np.set_printoptions(threshold=np.nan) # Used to easily see the gold values.
# Use print(repr(numpy_array)) to print the values.
tf.test.main()
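The golden `log_p_hat` values in these tests come from the FIVO estimator: whenever a filter resamples (and at the final timestep), the accumulated log weights are logsumexp-averaged into log p̂ and reset to zero. A toy NumPy sketch of that accumulation for a single sequence (hand-picked incremental weights, not the VRNN model):

```python
import numpy as np

def fivo_log_p_hat(log_alpha, resampled):
    """log_alpha: [T, num_samples] incremental log weights for one sequence.
    resampled: length-T booleans; the final step always flushes the weights."""
    T, num_samples = log_alpha.shape
    log_w = np.zeros(num_samples)
    log_p_hat = 0.0
    for t in range(T):
        log_w = log_w + log_alpha[t]
        if resampled[t] or t == T - 1:
            # logsumexp-average the accumulated weights into log p_hat,
            # then reset them to zero, as in the docstring above.
            m = log_w.max()
            log_p_hat += m + np.log(np.exp(log_w - m).sum()) - np.log(num_samples)
            log_w = np.zeros(num_samples)
    return log_p_hat

# With equal weights the resampling schedule does not matter, and with no
# resampling at all the estimator reduces to the IWAE bound on the sequence.
log_alpha = np.log(np.full((3, 4), 0.5))
assert np.isclose(fivo_log_p_hat(log_alpha, [False, False, False]), 3 * np.log(0.5))
assert np.isclose(fivo_log_p_hat(log_alpha, [True, False, False]), 3 * np.log(0.5))
```

With unequal weights the two schedules generally differ, which is why the FIVO and IWAE golden values above diverge even though both start from the same incremental weights at t=0.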
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocesses TIMIT from raw wavfiles to create a set of TFRecords. """Preprocesses TIMIT from raw wavfiles to create a set of TFRecords.
""" """
......