@@ -8,28 +8,42 @@ Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohamma
...
@@ -8,28 +8,42 @@ Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohamma
This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
Additionally it contains an implementation of the variational recurrent neural network (VRNN), a sequential latent variable model that can be trained using these three objectives. This repo provides code for training a VRNN to do sequence modeling of pianoroll and speech data.
Additionally it contains several sequential latent variable model implementations:
* Variational recurrent neural network (VRNN)
* Stochastic recurrent neural network (SRNN)
* Gaussian hidden Markov model with linear conditionals (GHMM)
The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.
#### Directory Structure
#### Directory Structure
The important parts of the code are organized as follows.
The important parts of the code are organized as follows.
```
```
fivo.py # main script, contains flag definitions
run_fivo.py # main script, contains flag definitions
runners.py # graph construction code for training and evaluation
fivo
bounds.py # code for computing each bound
├─smc.py # a sequential Monte Carlo implementation
data
├─bounds.py # code for computing each bound, uses smc.py
├── datasets.py # readers for pianoroll and speech datasets
├─runners.py # code for VRNN and SRNN training and evaluation
├── calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
├─ghmm_runners.py # code for GHMM training and evaluation
└── create_timit_dataset.py # preprocesses the TIMIT dataset
├─data
models
| ├─datasets.py # readers for pianoroll and speech datasets
└── vrnn.py # variational RNN implementation
| ├─calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
| └─create_timit_dataset.py # preprocesses the TIMIT dataset
└─models
├─base.py # base classes used in other models
├─vrnn.py # VRNN implementation
├─srnn.py # SRNN implementation
└─ghmm.py # Gaussian hidden Markov model (GHMM) implementation
bin
bin
├── run_train.sh # an example script that runs training
├─run_train.sh # an example script that runs training
├── run_eval.sh # an example script that runs evaluation
├─run_eval.sh # an example script that runs evaluation
└── download_pianorolls.sh # a script that downloads the pianoroll files
├─run_sample.sh # an example script that runs sampling
├─run_tests.sh # a script that runs all tests
└─download_pianorolls.sh # a script that downloads pianoroll files
Now we can train a model. Here is a standard training run, taken from `bin/run_train.sh`:
Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:
```
```
python fivo.py \
python run_fivo.py \
--mode=train \
--mode=train \
--logdir=/tmp/fivo \
--logdir=/tmp/fivo \
--model=vrnn \
--model=vrnn \
...
@@ -75,26 +89,24 @@ python fivo.py \
...
@@ -75,26 +89,24 @@ python fivo.py \
--dataset_type="pianoroll"
--dataset_type="pianoroll"
```
```
You should see output that looks something like this (with a lot of extra logging cruft):
You should see output that looks something like this (with extra logging cruft):
```
```
Step 1, fivo bound per timestep: -11.801050
Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
global_step/sec: 9.89825
Step 1, fivo bound per timestep: -11.322491
Step 101, fivo bound per timestep: -11.198309
global_step/sec: 7.49971
global_step/sec: 9.55475
Step 101, fivo bound per timestep: -11.399275
Step 201, fivo bound per timestep: -11.287262
global_step/sec: 8.04498
global_step/sec: 9.68146
Step 201, fivo bound per timestep: -11.174991
step 301, fivo bound per timestep: -11.316490
global_step/sec: 8.03989
global_step/sec: 9.94295
Step 301, fivo bound per timestep: -11.073008
Step 401, fivo bound per timestep: -11.151743
```
```
You will also see lines saying `Out of range: exceptions.StopIteration: Iteration finished`. This is not an error and is fine.
#### Evaluation
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
```
```
python fivo.py \
python run_fivo.py \
--mode=eval \
--mode=eval \
--split=test \
--split=test \
--alsologtostderr \
--alsologtostderr \
...
@@ -108,12 +120,52 @@ python fivo.py \
...
@@ -108,12 +120,52 @@ python fivo.py \
You should see output like this:
You should see output like this:
```
```
Model restored from step 1, evaluating.
Restoring parameters from /tmp/fivo/model.ckpt-0
test elbo ll/t: -12.299635, iwae ll/t: -12.128336 fivo ll/t: -11.656939
Model restored from step 0, evaluating.
test elbo ll/seq: -754.750312, iwae ll/seq: -744.238773 fivo ll/seq: -715.3121490
test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
```
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
#### Sampling
You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:
```
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
```
Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.
You should see very little output.
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Running local_init_op.
Done running local_init_op.
```
Loading the samples with `np.load` confirms that we conditioned the model on 4
prefixes of length 25 and sampled 4 sequences of length 50 for each prefix.
```
>>> import numpy as np
>>> x = np.load("/tmp/fivo/samples.npz")
>>> x[()]['prefixes'].shape
(25, 4, 88)
>>> x[()]['samples'].shape
(50, 4, 4, 88)
```
### Training on TIMIT
### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
This is very similar to training on pianoroll datasets, with just a few flags switched.
This is very similar to training on pianoroll datasets, with just a few flags switched.
```
```
python fivo.py \
python run_fivo.py \
--mode=train \
--mode=train \
--logdir=/tmp/fivo \
--logdir=/tmp/fivo \
--model=vrnn \
--model=vrnn \
...
@@ -149,6 +201,10 @@ python fivo.py \
...
@@ -149,6 +201,10 @@ python fivo.py \
--dataset_path="$TIMIT_DIR/train" \
--dataset_path="$TIMIT_DIR/train" \
--dataset_type="speech"
--dataset_type="speech"
```
```
Evaluation and sampling are similar.
### Tests
This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful to look at for examples of how to use the code.