Commit 3a658979 authored by Jasmine's avatar Jasmine
Browse files

adding lfads

parent 4c7d1708
...@@ -22,6 +22,7 @@ running TensorFlow 0.12 or earlier, please ...@@ -22,6 +22,7 @@ running TensorFlow 0.12 or earlier, please
- [im2txt](im2txt): image-to-text neural network for image captioning. - [im2txt](im2txt): image-to-text neural network for image captioning.
- [inception](inception): deep convolutional networks for computer vision. - [inception](inception): deep convolutional networks for computer vision.
- [learning_to_remember_rare_events](learning_to_remember_rare_events): a large-scale life-long memory module for use in deep learning. - [learning_to_remember_rare_events](learning_to_remember_rare_events): a large-scale life-long memory module for use in deep learning.
- [lfads](lfads): sequential variational autoencoder for analyzing neuroscience data.
- [lm_1b](lm_1b): language modeling on the one billion word benchmark. - [lm_1b](lm_1b): language modeling on the one billion word benchmark.
- [namignizer](namignizer): recognize and generate names. - [namignizer](namignizer): recognize and generate names.
- [neural_gpu](neural_gpu): highly parallel neural computer. - [neural_gpu](neural_gpu): highly parallel neural computer.
......
# LFADS - Latent Factor Analysis via Dynamical Systems
This code implements the model from the paper "[LFADS - Latent Factor Analysis via Dynamical Systems](http://biorxiv.org/content/early/2017/06/20/152884)". It is a sequential variational auto-encoder designed specifically for investigating neuroscience data, but can be applied widely to any time series data. In an unsupervised setting, LFADS is able to decompose time series data into various factors, such as an initial condition, a generative dynamical system, control inputs to that generator, and a low dimensional description of the observed data, called the factors. Additionally, the observation model is a loss on a probability distribution, so when LFADS processes a dataset, a denoised version of the dataset is also created. For example, if the dataset is raw spike counts, then under the negative log-likeihood loss under a Poisson distribution, the denoised data would be the inferred Poisson rates.
## Prerequisites
The code is written in Python 2.7.6. You will also need:
* **TensorFlow** version 1.0.1 or greater ([install](https://www.tensorflow.org/install/))
* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
* **h5py** ([install](https://pypi.python.org/pypi/h5py))
## Getting started
Before starting, run the following:
<pre>
$ export PYTHONPATH=$PYTHONPATH:/<b>path/to/your/directory</b>/lfads/
</pre>
where "path/to/your/directory" is replaced with the path to the LFADS repository (you can get this path by using the `pwd` command). This allows the nested directories to access modules from their parent directory.
## Generate synthetic data
In order to generate the synthetic datasets first, from the top-level lfads directory, run:
```sh
$ cd synth_data
$ ./run_generate_synth_data.sh
$ cd ..
```
These synthetic datasets are provided 1. to gain insight into how the LFADS algorithm operates, and 2. to give reasonable starting points for analyses you might be interested for your own data.
## Train an LFADS model
Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 8) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters.
```sh
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5)
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
--co_dim=0 \
--factors_dim=20 \
--ext_input_dim=0 \
--controller_input_lag=1 \
--output_dist=poisson \
--do_causal_controller=false \
--batch_size=128 \
--learning_rate_init=0.01 \
--learning_rate_stop=1e-05 \
--learning_rate_decay_factor=0.95 \
--learning_rate_n_to_compare=6 \
--do_reset_learning_rate=false \
--keep_prob=0.95 \
--con_dim=128 \
--gen_dim=200 \
--ci_enc_dim=128 \
--ic_dim=64 \
--ic_enc_dim=128 \
--ic_prior_var_min=0.1 \
--gen_cell_input_weight_scale=1.0 \
--cell_weight_scale=1.0 \
--do_feed_factors_to_controller=true \
--kl_start_step=0 \
--kl_increase_steps=2000 \
--kl_ic_weight=1.0 \
--l2_con_scale=0.0 \
--l2_gen_scale=2000.0 \
--l2_start_step=0 \
--l2_increase_steps=2000 \
--ic_prior_var_scale=0.1 \
--ic_post_var_min=0.0001 \
--kl_co_weight=1.0 \
--prior_ar_nvar=0.1 \
--cell_clip_value=5.0 \
--max_ckpt_to_keep_lve=5 \
--do_train_prior_ar_atau=true \
--co_prior_var_scale=0.1 \
--csv_log=fitlog \
--feedback_factors_or_rates=factors \
--do_train_prior_ar_nvar=true \
--max_grad_norm=200.0 \
--device=gpu:0 \
--num_steps_for_gen_ic=100000000 \
--ps_nexamples_to_process=100000000 \
--checkpoint_name=lfads_vae \
--temporal_spike_jitter_width=0 \
--checkpoint_pb_load_name=checkpoint \
--inject_ext_input_to_gen=false \
--co_mean_corr_scale=0.0 \
--gen_cell_rec_weight_scale=1.0 \
--max_ckpt_to_keep=5 \
--output_filename_stem="" \
--ic_prior_var_max=0.1 \
--prior_ar_atau=10.0 \
--do_train_io_only=false
# Run LFADS on chaotic rnn data with input pulses (g = 2.5)
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20
# Run LFADS on multi-session RNN data
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_multisession \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_multisession \
--factors_dim=10
# Run LFADS on integration to bound model data
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=itb_rnn \
--lfads_save_dir=/tmp/lfads_itb_rnn \
--co_dim=1 \
--factors_dim=20 \
--controller_input_lag=0
# Run LFADS on chaotic RNN data with labels
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnns_labeled \
--lfads_save_dir=/tmp/lfads_chaotic_rnns_labeled \
--co_dim=0 \
--factors_dim=20 \
--controller_input_lag=0 \
--ext_input_dim=1
```
**Tip**: If you are running LFADS on GPU and would like to run more than one model concurrently, set the `--allow_gpu_growth=True` flag on each job, otherwise one model will take up the entire GPU for performance purposes. Also, one needs to install the TensorFlow libraries with GPU support.
## Visualize a training model
To visualize training curves and various other metrics while training and LFADS model, run the following command on your model directory. To launch a tensorboard on the chaotic RNN data with input pulses, for example:
```sh
tensorboard --logdir=/tmp/lfads_chaotic_rnn_inputs_g2p5
```
## Evaluate a trained model
Once your model is finished training, there are multiple ways you can evaluate
it. Below are some sample commands to evaluate an LFADS model trained on the
chaotic rnn data with input pulses (g = 2.5). The key differences here are
setting the `--kind` flag to the appropriate mode, as well as the
`--checkpoint_pb_load_name` flag to `checkpoint_lve` and the `--batch_size` flag
(if you'd like to make it larger or smaller). All other flags should be the
same as used in training, so that the same model architecture is built.
```sh
# Take samples from posterior then average (denoising operation)
$ python run_lfads.py --kind=posterior_sample_and_average \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--batch_size=1024 \
--checkpoint_pb_load_name=checkpoint_lve
# Sample from prior (generation of completely new samples)
$ python run_lfads.py --kind=prior_sample \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--batch_size=50 \
--checkpoint_pb_load_name=checkpoint_lve
# Write down model parameters
$ python run_lfads.py --kind=write_model_params \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--checkpoint_pb_load_name=checkpoint_lve
```
## Contact
File any issues with the [issue tracker](https://github.com/tensorflow/models/issues). For any questions or problems, this code is maintained by [@sussillo](https://github.com/sussillo) and [@jazcollins](https://github.com/jazcollins).
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import tensorflow as tf
from utils import linear, log_sum_exp
class Poisson(object):
"""Poisson distributon
Computes the log probability under the model.
"""
def __init__(self, log_rates):
""" Create Poisson distributions with log_rates parameters.
Args:
log_rates: a tensor-like list of log rates underlying the Poisson dist.
"""
self.logr = log_rates
def logp(self, bin_counts):
"""Compute the log probability for the counts in the bin, under the model.
Args:
bin_counts: array-like integer counts
Returns:
The log-probability under the Poisson models for each element of
bin_counts.
"""
k = tf.to_float(bin_counts)
# log poisson(k, r) = log(r^k * e^(-r) / k!) = k log(r) - r - log k!
# log poisson(k, r=exp(x)) = k * x - exp(x) - lgamma(k + 1)
return k * self.logr - tf.exp(self.logr) - tf.lgamma(k + 1)
def diag_gaussian_log_likelihood(z, mu=0.0, logvar=0.0):
"""Log-likelihood under a Gaussian distribution with diagonal covariance.
Returns the log-likelihood for each dimension. One should sum the
results for the log-likelihood under the full multidimensional model.
Args:
z: The value to compute the log-likelihood.
mu: The mean of the Gaussian
logvar: The log variance of the Gaussian.
Returns:
The log-likelihood under the Gaussian model.
"""
return -0.5 * (logvar + np.log(2*np.pi) + \
tf.square((z-mu)/tf.exp(0.5*logvar)))
def gaussian_pos_log_likelihood(unused_mean, logvar, noise):
"""Gaussian log-likelihood function for a posterior in VAE
Note: This function is specialized for a posterior distribution, that has the
form of z = mean + sigma * noise.
Args:
unused_mean: ignore
logvar: The log variance of the distribution
noise: The noise used in the sampling of the posterior.
Returns:
The log-likelihood under the Gaussian model.
"""
# ln N(z; mean, sigma) = - ln(sigma) - 0.5 ln 2pi - noise^2 / 2
return - 0.5 * (logvar + np.log(2 * np.pi) + tf.square(noise))
class Gaussian(object):
"""Base class for Gaussian distribution classes."""
pass
class DiagonalGaussian(Gaussian):
"""Diagonal Gaussian with different constant mean and variances in each
dimension.
"""
def __init__(self, batch_size, z_size, mean, logvar):
"""Create a diagonal gaussian distribution.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
mean: The N-D mean of the distribution.
logvar: The N-D log variance of the diagonal distribution.
"""
size__xz = [None, z_size]
self.mean = mean # bxn already
self.logvar = logvar # bxn already
self.noise = noise = tf.random_normal(tf.shape(logvar))
self.sample = mean + tf.exp(0.5 * logvar) * noise
mean.set_shape(size__xz)
logvar.set_shape(size__xz)
self.sample.set_shape(size__xz)
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample:
return gaussian_pos_log_likelihood(self.mean, self.logvar, self.noise)
return diag_gaussian_log_likelihood(z, self.mean, self.logvar)
class LearnableDiagonalGaussian(Gaussian):
"""Diagonal Gaussian whose mean and variance are learned parameters."""
def __init__(self, batch_size, z_size, name, mean_init=0.0,
var_init=1.0, var_min=0.0, var_max=1000000.0):
"""Create a learnable diagonal gaussian distribution.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
name: prefix name for the mean and log TF variables.
mean_init (optional): The N-D mean initialization of the distribution.
var_init (optional): The N-D variance initialization of the diagonal
distribution.
var_min (optional): The minimum value the learned variance can take in any
dimension.
var_max (optional): The maximum value the learned variance can take in any
dimension.
"""
size_1xn = [1, z_size]
size__xn = [None, z_size]
size_bx1 = tf.stack([batch_size, 1])
assert var_init > 0.0, "Problems"
assert var_max >= var_min, "Problems"
assert var_init >= var_min, "Problems"
assert var_max >= var_init, "Problems"
z_mean_1xn = tf.get_variable(name=name+"/mean", shape=size_1xn,
initializer=tf.constant_initializer(mean_init))
self.mean_bxn = mean_bxn = tf.tile(z_mean_1xn, size_bx1)
mean_bxn.set_shape(size__xn) # tile loses shape
log_var_init = np.log(var_init)
if var_max > var_min:
var_is_trainable = True
else:
var_is_trainable = False
z_logvar_1xn = \
tf.get_variable(name=(name+"/logvar"), shape=size_1xn,
initializer=tf.constant_initializer(log_var_init),
trainable=var_is_trainable)
if var_is_trainable:
z_logit_var_1xn = tf.exp(z_logvar_1xn)
z_var_1xn = tf.nn.sigmoid(z_logit_var_1xn)*(var_max-var_min) + var_min
z_logvar_1xn = tf.log(z_var_1xn)
logvar_bxn = tf.tile(z_logvar_1xn, size_bx1)
self.logvar_bxn = logvar_bxn
self.noise_bxn = noise_bxn = tf.random_normal(tf.shape(logvar_bxn))
self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample_bxn:
return gaussian_pos_log_likelihood(self.mean_bxn, self.logvar_bxn,
self.noise_bxn)
return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
@property
def mean(self):
return self.mean_bxn
@property
def logvar(self):
return self.logvar_bxn
@property
def sample(self):
return self.sample_bxn
class DiagonalGaussianFromInput(Gaussian):
"""Diagonal Gaussian whose mean and variance are conditioned on other
variables.
Note: the parameters to convert from input to the learned mean and log
variance are held in this class.
"""
def __init__(self, x_bxu, z_size, name, var_min=0.0):
"""Create an input dependent diagonal Gaussian distribution.
Args:
x: The input tensor from which the mean and variance are computed,
via a linear transformation of x. I.e.
mu = Wx + b, log(var) = Mx + c
z_size: The size of the distribution.
name: The name to prefix to learned variables.
var_min (optional): Minimal variance allowed. This is an additional
way to control the amount of information getting through the stochastic
layer.
"""
size_bxn = tf.stack([tf.shape(x_bxu)[0], z_size])
self.mean_bxn = mean_bxn = linear(x_bxu, z_size, name=(name+"/mean"))
logvar_bxn = linear(x_bxu, z_size, name=(name+"/logvar"))
if var_min > 0.0:
logvar_bxn = tf.log(tf.exp(logvar_bxn) + var_min)
self.logvar_bxn = logvar_bxn
self.noise_bxn = noise_bxn = tf.random_normal(size_bxn)
self.noise_bxn.set_shape([None, z_size])
self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample_bxn:
return gaussian_pos_log_likelihood(self.mean_bxn,
self.logvar_bxn, self.noise_bxn)
return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
@property
def mean(self):
return self.mean_bxn
@property
def logvar(self):
return self.logvar_bxn
@property
def sample(self):
return self.sample_bxn
class GaussianProcess:
"""Base class for Gaussian processes."""
pass
class LearnableAutoRegressive1Prior(GaussianProcess):
"""AR(1) model where autocorrelation and process variance are learned
parameters. Assumed zero mean.
"""
def __init__(self, batch_size, z_size,
autocorrelation_taus, noise_variances,
do_train_prior_ar_atau, do_train_prior_ar_nvar,
num_steps, name):
"""Create a learnable autoregressive (1) process.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
autocorrelation_taus: The auto correlation time constant of the AR(1)
process.
A value of 0 is uncorrelated gaussian noise.
noise_variances: The variance of the additive noise, *not* the process
variance.
do_train_prior_ar_atau: Train or leave as constant, the autocorrelation?
do_train_prior_ar_nvar: Train or leave as constant, the noise variance?
num_steps: Number of steps to run the process.
name: The name to prefix to learned TF variables.
"""
# Note the use of the plural in all of these quantities. This is intended
# to mark that even though a sample z_t from the posterior is thought of a
# single sample of a multidimensional gaussian, the prior is actually
# thought of as U AR(1) processes, where U is the dimension of the inferred
# input.
size_bx1 = tf.stack([batch_size, 1])
size__xu = [None, z_size]
# process variance, the variance at time t over all instantiations of AR(1)
# with these parameters.
log_evar_inits_1xu = tf.expand_dims(tf.log(noise_variances), 0)
self.logevars_1xu = logevars_1xu = \
tf.Variable(log_evar_inits_1xu, name=name+"/logevars", dtype=tf.float32,
trainable=do_train_prior_ar_nvar)
self.logevars_bxu = logevars_bxu = tf.tile(logevars_1xu, size_bx1)
logevars_bxu.set_shape(size__xu) # tile loses shape
# \tau, which is the autocorrelation time constant of the AR(1) process
log_atau_inits_1xu = tf.expand_dims(tf.log(autocorrelation_taus), 0)
self.logataus_1xu = logataus_1xu = \
tf.Variable(log_atau_inits_1xu, name=name+"/logatau", dtype=tf.float32,
trainable=do_train_prior_ar_atau)
# phi in x_t = \mu + phi x_tm1 + \eps
# phi = exp(-1/tau)
# phi = exp(-1/exp(logtau))
# phi = exp(-exp(-logtau))
phis_1xu = tf.exp(-tf.exp(-logataus_1xu))
self.phis_bxu = phis_bxu = tf.tile(phis_1xu, size_bx1)
phis_bxu.set_shape(size__xu)
# process noise
# pvar = evar / (1- phi^2)
# logpvar = log ( exp(logevar) / (1 - phi^2) )
# logpvar = logevar - log(1-phi^2)
# logpvar = logevar - (log(1-phi) + log(1+phi))
self.logpvars_1xu = \
logevars_1xu - tf.log(1.0-phis_1xu) - tf.log(1.0+phis_1xu)
self.logpvars_bxu = logpvars_bxu = tf.tile(self.logpvars_1xu, size_bx1)
logpvars_bxu.set_shape(size__xu)
# process mean (zero but included in for completeness)
self.pmeans_bxu = pmeans_bxu = tf.zeros_like(phis_bxu)
# For sampling from the prior during de-novo generation.
self.means_t = means_t = [None] * num_steps
self.logvars_t = logvars_t = [None] * num_steps
self.samples_t = samples_t = [None] * num_steps
self.gaussians_t = gaussians_t = [None] * num_steps
sample_bxu = tf.zeros_like(phis_bxu)
for t in range(num_steps):
# process variance used here to make process completely stationary
if t == 0:
logvar_pt_bxu = self.logpvars_bxu
else:
logvar_pt_bxu = self.logevars_bxu
z_mean_pt_bxu = pmeans_bxu + phis_bxu * sample_bxu
gaussians_t[t] = DiagonalGaussian(batch_size, z_size,
mean=z_mean_pt_bxu,
logvar=logvar_pt_bxu)
sample_bxu = gaussians_t[t].sample
samples_t[t] = sample_bxu
logvars_t[t] = logvar_pt_bxu
means_t[t] = z_mean_pt_bxu
def logp_t(self, z_t_bxu, z_tm1_bxu=None):
"""Compute the log-likelihood under the distribution for a given time t,
not the whole sequence.
Args:
z_t_bxu: sample to compute likelihood for at time t.
z_tm1_bxu (optional): sample condition probability of z_t upon.
Returns:
The likelihood of p_t under the model at time t. i.e.
p(z_t|z_tm1) = N(z_tm1 * phis, eps^2)
"""
if z_tm1_bxu is None:
return diag_gaussian_log_likelihood(z_t_bxu, self.pmeans_bxu,
self.logpvars_bxu)
else:
means_t_bxu = self.pmeans_bxu + self.phis_bxu * z_tm1_bxu
logp_tgtm1_bxu = diag_gaussian_log_likelihood(z_t_bxu,
means_t_bxu,
self.logevars_bxu)
return logp_tgtm1_bxu
class KLCost_GaussianGaussian(object):
"""log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian prior. See
eqn 10 and Appendix B in VAE for latter term,
http://arxiv.org/abs/1312.6114
The log p(x|z) term is the reconstruction error under the model.
The KL term represents the penalty for passing information from the encoder
to the decoder.
To sample KL(q||p), we simply sample
ln q - ln p
by drawing samples from q and averaging.
"""
def __init__(self, zs, prior_zs):
"""Create a lower bound in three parts, normalized reconstruction
cost, normalized KL divergence cost, and their sum.
E_q[ln p(z_i | z_{i+1}) / q(z_i | x)
\int q(z) ln p(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_p^2) + \
sigma_q^2 / sigma_p^2 + (mean_p - mean_q)^2 / sigma_p^2)
\int q(z) ln q(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_q^2) + 1)
Args:
zs: posterior z ~ q(z|x)
prior_zs: prior zs
"""
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'KL cost' is postive KL divergence
kl_b = 0.0
for z, prior_z in zip(zs, prior_zs):
assert isinstance(z, Gaussian)
assert isinstance(prior_z, Gaussian)
# ln(2pi) terms cancel
kl_b += 0.5 * tf.reduce_sum(
prior_z.logvar - z.logvar
+ tf.exp(z.logvar - prior_z.logvar)
+ tf.square((z.mean - prior_z.mean) / tf.exp(0.5 * prior_z.logvar))
- 1.0, [1])
self.kl_cost_b = kl_b
self.kl_cost = tf.reduce_mean(kl_b)
class KLCost_GaussianGaussianProcessSampled(object):
""" log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian process
prior via sampling.
The log p(x|z) term is the reconstruction error under the model.
The KL term represents the penalty for passing information from the encoder
to the decoder.
To sample KL(q||p), we simply sample
ln q - ln p
by drawing samples from q and averaging.
"""
def __init__(self, post_zs, prior_z_process):
"""Create a lower bound in three parts, normalized reconstruction
cost, normalized KL divergence cost, and their sum.
Args:
post_zs: posterior z ~ q(z|x)
prior_z_process: prior AR(1) process
"""
assert len(post_zs) > 1, "GP is for time, need more than 1 time step."
assert isinstance(prior_z_process, GaussianProcess), "Must use GP."
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'KL cost' is postive KL divergence
z0_bxu = post_zs[0].sample
logq_bxu = post_zs[0].logp(z0_bxu)
logp_bxu = prior_z_process.logp_t(z0_bxu)
z_tm1_bxu = z0_bxu
for z_t in post_zs[1:]:
# posterior is independent in time, prior is not
z_t_bxu = z_t.sample
logq_bxu += z_t.logp(z_t_bxu)
logp_bxu += prior_z_process.logp_t(z_t_bxu, z_tm1_bxu)
z_tm1 = z_t_bxu
kl_bxu = logq_bxu - logp_bxu
kl_b = tf.reduce_sum(kl_bxu, [1])
self.kl_cost_b = kl_b
self.kl_cost = tf.reduce_mean(kl_b)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""
LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, control inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
The main data structure being passed around is a dataset. This is a dictionary
of data dictionaries.
DATASET: The top level dictionary is simply name (string -> dictionary).
The nested dictionary is the DATA DICTIONARY, which has the following keys:
'train_data' and 'valid_data', whose values are the corresponding training
and validation data with shape
ExTxD, E - # examples, T - # time steps, D - # dimensions in data.
The data dictionary also has a few more keys:
'train_ext_input' and 'valid_ext_input', if there are know external inputs
to the system being modeled, these take on dimensions:
ExTxI, E - # examples, T - # time steps, I = # dimensions in input.
'alignment_matrix_cxf' - If you are using multiple days data, it's possible
that one can align the channels (see manuscript). If so each dataset will
contain this matrix, which will be used for both the input adapter and the
output adapter for each dataset. These matrices, if provided, must be of
size [data_dim x factors] where data_dim is the number of neurons recorded
on that day, and factors is chosen and set through the '--factors' flag.
If one runs LFADS on data where the true rates are known for some trials,
(say simulated, testing data, as in the example shipped with the paper), then
one can add three more fields for plotting purposes. These are 'train_truth'
and 'valid_truth', and 'conversion_factor'. These have the same dimensions as
'train_data', and 'valid_data' but represent the underlying rates of the
observations. Finally, if one needs to convert scale for plotting the true
underlying firing rates, there is the 'conversion_factor' key.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import tensorflow as tf
from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput
from distributions import diag_gaussian_log_likelihood
from distributions import KLCost_GaussianGaussian, Poisson
from distributions import LearnableAutoRegressive1Prior
from distributions import KLCost_GaussianGaussianProcessSampled
from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data
from utils import log_sum_exp, flatten
from plot_lfads import plot_lfads
class GRU(object):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
"""
def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0,
clip_value=np.inf, collections=None):
"""Create a GRU object.
Args:
num_units: Number of units in the GRU
forget_bias (optional): Hack to help learning.
weight_scale (optional): weights are scaled by ws/sqrt(#inputs), with
ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value,
clip them.
collections (optional): List of additonal collections variables should
belong to.
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._weight_scale = weight_scale
self._clip_value = clip_value
self._collections = collections
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_multiplier(self):
return 1
def output_from_state(self, state):
"""Return the output portion of the state."""
return state
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) function.
Args:
inputs: A 2D batch x input_dim tensor of inputs.
state: The previous state from the last time step.
scope (optional): TF variable scope for defined GRU variables.
Returns:
A tuple (state, state), where state is the newly computed state at time t.
It is returned twice to respect an interface that works for LSTMs.
"""
x = inputs
h = state
if inputs is not None:
xh = tf.concat(axis=1, values=[x, h])
else:
xh = h
with tf.variable_scope(scope or type(self).__name__): # "GRU"
with tf.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh,
2 * self._num_units,
alpha=self._weight_scale,
name="xh_2_ru",
collections=self._collections))
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
with tf.variable_scope("Candidate"):
xrh = tf.concat(axis=1, values=[x, r * h])
c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c",
collections=self._collections))
new_h = u * h + (1 - u) * c
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
return new_h, new_h
class GenGRU(object):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
This version is specialized for the generator, but isn't as fast, so
we have two. Note this allows for l2 regularization on the recurrent
weights, but also implicitly rescales the inputs via the 1/sqrt(input)
scaling in the linear helper routine to be large magnitude, if there are
fewer inputs than recurrent state.
"""
def __init__(self, num_units, forget_bias=1.0,
input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf,
input_collections=None, recurrent_collections=None):
"""Create a GRU object.
Args:
num_units: Number of units in the GRU
forget_bias (optional): Hack to help learning.
input_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with
ws being the weight scale.
rec_weight_scale (optional): weights are scaled ws/sqrt(#inputs),
with ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value,
clip them.
input_collections (optional): List of additonal collections variables
that input->rec weights should belong to.
recurrent_collections (optional): List of additonal collections variables
that rec->rec weights should belong to.
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._input_weight_scale = input_weight_scale
self._rec_weight_scale = rec_weight_scale
self._clip_value = clip_value
self._input_collections = input_collections
self._rec_collections = recurrent_collections
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_multiplier(self):
return 1
def output_from_state(self, state):
"""Return the output portion of the state."""
return state
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) function.
Args:
inputs: A 2D batch x input_dim tensor of inputs.
state: The previous state from the last time step.
scope (optional): TF variable scope for defined GRU variables.
Returns:
A tuple (state, state), where state is the newly computed state at time t.
It is returned twice to respect an interface that works for LSTMs.
"""
x = inputs
h = state
with tf.variable_scope(scope or type(self).__name__): # "GRU"
with tf.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r_x = u_x = 0.0
if x is not None:
r_x, u_x = tf.split(axis=1, num_or_size_splits=2, value=linear(x,
2 * self._num_units,
alpha=self._input_weight_scale,
do_bias=False,
name="x_2_ru",
normalized=False,
collections=self._input_collections))
r_h, u_h = tf.split(axis=1, num_or_size_splits=2, value=linear(h,
2 * self._num_units,
do_bias=True,
alpha=self._rec_weight_scale,
name="h_2_ru",
collections=self._rec_collections))
r = r_x + r_h
u = u_x + u_h
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
with tf.variable_scope("Candidate"):
c_x = 0.0
if x is not None:
c_x = linear(x, self._num_units, name="x_2_c", do_bias=False,
alpha=self._input_weight_scale,
normalized=False,
collections=self._input_collections)
c_rh = linear(r*h, self._num_units, name="rh_2_c", do_bias=True,
alpha=self._rec_weight_scale,
collections=self._rec_collections)
c = tf.tanh(c_x + c_rh)
new_h = u * h + (1 - u) * c
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
return new_h, new_h
class LFADS(object):
"""LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, inferred inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additoinally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
"""
def __init__(self, hps, kind="train", datasets=None):
"""Create an LFADS model.
train - a model for training, sampling of posteriors is used
posterior_sample_and_average - sample from the posterior, this is used
for evaluating the expected value of the outputs of LFADS, given a
specific input, by averaging over multiple samples from the approx
posterior. Also used for the lower bound on the negative
log-likelihood using IWAE error (Importance Weighed Auto-encoder).
This is the denoising operation.
prior_sample - a model for generation - sampling from priors is used
Args:
hps: The dictionary of hyper parameters.
kind: the type of model to build (see above).
datasets: a dictionary of named data_dictionaries, see top of lfads.py
"""
print("Building graph...")
all_kinds = ['train', 'posterior_sample_and_average', 'prior_sample']
assert kind in all_kinds, 'Wrong kind'
if hps.feedback_factors_or_rates == "rates":
assert len(hps.dataset_names) == 1, \
"Multiple datasets not supported for rate feedback."
num_steps = hps.num_steps
ic_dim = hps.ic_dim
co_dim = hps.co_dim
ext_input_dim = hps.ext_input_dim
cell_class = GRU
gen_cell_class = GenGRU
def makelambda(v): # Used with tf.case
return lambda: v
# Define the data placeholder, and deal with all parts of the graph
# that are dataset dependent.
self.dataName = tf.placeholder(tf.string, shape=())
# The batch_size to be inferred from data, as normal.
# Additionally, the data_dim will be inferred as well, allowing for a
# single placeholder for all datasets, regardless of data dimension.
if hps.output_dist == 'poisson':
# Enforce correct dtype
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, int), \
"Data dtype must be int for poisson output distribution"
data_dtype = tf.int32
elif hps.output_dist == 'gaussian':
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
"Data dtype must be float for gaussian output dsitribution"
data_dtype = tf.float32
else:
assert False, "NIY"
self.dataset_ph = dataset_ph = tf.placeholder(data_dtype,
[None, num_steps, None],
name="data")
self.train_step = tf.get_variable("global_step", [], tf.int64,
tf.zeros_initializer(),
trainable=False)
self.hps = hps
ndatasets = hps.ndatasets
factors_dim = hps.factors_dim
self.preds = preds = [None] * ndatasets
self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets
self.fns_in_fatcor_bs = fns_in_fac_bs = [None] * ndatasets
self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets
self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets
self.datasetNames = dataset_names = hps.dataset_names
self.ext_inputs = ext_inputs = None
if len(dataset_names) == 1: # single session
if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys():
used_in_factors_dim = factors_dim
in_identity_if_poss = False
else:
used_in_factors_dim = hps.dataset_dims[dataset_names[0]]
in_identity_if_poss = True
else: # multisession
used_in_factors_dim = factors_dim
in_identity_if_poss = False
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
print("Using alignment matrix provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d
(data_dim x factors_dim), but currently has %d x %d."""%
(data_dim, factors_dim, in_mat_cxf.shape[0],
in_mat_cxf.shape[1]))
in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True,
mat_init_value=in_mat_cxf,
identity_if_possible=in_identity_if_poss,
normalized=False, name="x_2_infac_"+name,
collections=['IO_transformations'])
in_fac_W, in_fac_b = in_fac_lin
fns_in_fac_Ws[d] = makelambda(in_fac_W)
fns_in_fac_bs[d] = makelambda(in_fac_b)
with tf.variable_scope("glm"):
out_identity_if_poss = False
if len(dataset_names) == 1 and \
factors_dim == hps.dataset_dims[dataset_names[0]]:
out_identity_if_poss = True
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
out_mat_cxf = None
if in_mat_cxf is not None:
out_mat_cxf = in_mat_cxf.T
if hps.output_dist == 'poisson':
out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_cxf,
identity_if_possible=out_identity_if_poss,
normalized=False,
name="fac_2_logrates_"+name,
collections=['IO_transformations'])
out_fac_W, out_fac_b = out_fac_lin
elif hps.output_dist == 'gaussian':
out_fac_lin_mean = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_cxf,
normalized=False,
name="fac_2_means_"+name,
collections=['IO_transformations'])
out_fac_lin_logvar = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_cxf,
normalized=False,
name="fac_2_logvars_"+name,
collections=['IO_transformations'])
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar
out_fac_W = tf.concat(
axis=1, values=[out_fac_W_mean, out_fac_W_logvar])
out_fac_b = tf.concat(
axis=1, values=[out_fac_b_mean, out_fac_b_logvar])
else:
assert False, "NIY"
preds[d] = tf.equal(tf.constant(name), self.dataName)
data_dim = hps.dataset_dims[name]
fns_out_fac_Ws[d] = makelambda(out_fac_W)
fns_out_fac_bs[d] = makelambda(out_fac_b)
pf_pairs_in_fac_Ws = zip(preds, fns_in_fac_Ws)
pf_pairs_in_fac_bs = zip(preds, fns_in_fac_bs)
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
case_default = lambda: tf.constant([-8675309.0])
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, case_default, exclusive=True)
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, case_default, exclusive=True)
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, case_default, exclusive=True)
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, case_default, exclusive=True)
# External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0:
self.ext_input = tf.placeholder(tf.float32,
[None, num_steps, ext_input_dim],
name="ext_input")
else:
self.ext_input = None
ext_input_bxtxi = self.ext_input
self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob")
self.batch_size = batch_size = int(hps.batch_size)
self.learning_rate = tf.Variable(float(hps.learning_rate_init),
trainable=False, name="learning_rate")
self.learning_rate_decay_op = self.learning_rate.assign(
self.learning_rate * hps.learning_rate_decay_factor)
# Dropout the data.
dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob)
if hps.ext_input_dim > 0:
ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob)
else:
ext_input_do_bxtxi = None
# ENCODERS
def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse,
num_steps_to_encode):
"""Encode data for LFADS
Args:
dataset_bxtxd - the data to encode, as a 3 tensor, with dims
time x batch x data dims.
enc_cell: encoder cell
name: name of encoder
forward_or_reverse: string, encode in forward or reverse direction
num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode
Returns:
encoded data as a list with num_steps_to_encode items, in order
"""
if forward_or_reverse == "forward":
dstr = "_fwd"
time_fwd_or_rev = range(num_steps_to_encode)
else:
dstr = "_rev"
time_fwd_or_rev = reversed(range(num_steps_to_encode))
with tf.variable_scope(name+"_enc"+dstr, reuse=False):
enc_state = tf.tile(
tf.Variable(tf.zeros([1, enc_cell.state_size]),
name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1]))
enc_state.set_shape([None, enc_cell.state_size]) # tile loses shape
enc_outs = [None] * num_steps_to_encode
for i, t in enumerate(time_fwd_or_rev):
with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None):
dataset_t_bxd = dataset_bxtxd[:,t,:]
in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b
in_fac_t_bxf.set_shape([None, used_in_factors_dim])
if ext_input_dim > 0 and not hps.inject_ext_input_to_gen:
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
enc_input_t_bxfpe = tf.concat(
axis=1, values=[in_fac_t_bxf, ext_input_t_bxi])
else:
enc_input_t_bxfpe = in_fac_t_bxf
enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state)
enc_outs[t] = enc_out
return enc_outs
# Encode initial condition means and variances
# ([x_T, x_T-1, ... x_0] and [x_0, x_1, ... x_T] -> g0/c0)
self.ic_enc_fwd = [None] * num_steps
self.ic_enc_rev = [None] * num_steps
if ic_dim > 0:
enc_ic_cell = cell_class(hps.ic_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "forward",
hps.num_steps_for_gen_ic)
ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "reverse",
hps.num_steps_for_gen_ic)
self.ic_enc_fwd = ic_enc_fwd
self.ic_enc_rev = ic_enc_rev
# Encoder control input means and variances, bi-directional encoding so:
# ([x_T, x_T-1, ..., x_0] and [x_0, x_1 ... x_T] -> u_t)
self.ci_enc_fwd = [None] * num_steps
self.ci_enc_rev = [None] * num_steps
if co_dim > 0:
enc_ci_cell = cell_class(hps.ci_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "forward",
hps.num_steps)
if hps.do_causal_controller:
ci_enc_rev = None
else:
ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "reverse",
hps.num_steps)
self.ci_enc_fwd = ci_enc_fwd
self.ci_enc_rev = ci_enc_rev
# STOCHASTIC LATENT VARIABLES, priors and posteriors
# (initial conditions g0, and control inputs, u_t)
# Note that zs represent all the stochastic latent variables.
with tf.variable_scope("z", reuse=False):
self.prior_zs_g0 = None
self.posterior_zs_g0 = None
self.g0s_val = None
if ic_dim > 0:
self.prior_zs_g0 = \
LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0",
mean_init=0.0,
var_min=hps.ic_prior_var_min,
var_init=hps.ic_prior_var_scale,
var_max=hps.ic_prior_var_max)
ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]])
ic_enc = tf.nn.dropout(ic_enc, keep_prob)
self.posterior_zs_g0 = \
DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
var_min=hps.ic_post_var_min)
if kind in ["train", "posterior_sample_and_average"]:
zs_g0 = self.posterior_zs_g0
else:
zs_g0 = self.prior_zs_g0
if kind in ["train", "posterior_sample_and_average", "prior_sample"]:
self.g0s_val = zs_g0.sample
else:
self.g0s_val = zs_g0.mean
# Priors for controller, 'co' for controller output
self.prior_zs_co = prior_zs_co = [None] * num_steps
self.posterior_zs_co = posterior_zs_co = [None] * num_steps
self.zs_co = zs_co = [None] * num_steps
self.prior_zs_ar_con = None
if co_dim > 0:
# Controller outputs
autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)]
noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)]
self.prior_zs_ar_con = prior_zs_ar_con = \
LearnableAutoRegressive1Prior(batch_size, hps.co_dim,
autocorrelation_taus,
noise_variances,
hps.do_train_prior_ar_atau,
hps.do_train_prior_ar_nvar,
num_steps, "u_prior_ar1")
# CONTROLLER -> GENERATOR -> RATES
# (u(t) -> gen(t) -> factors(t) -> rates(t) -> p(x_t|z_t) )
self.controller_outputs = u_t = [None] * num_steps
self.con_ics = con_state = None
self.con_states = con_states = [None] * num_steps
self.con_outs = con_outs = [None] * num_steps
self.gen_inputs = gen_inputs = [None] * num_steps
if co_dim > 0:
# gen_cell_class here for l2 penalty recurrent weights
# didn't split the cell_weight scale here, because I doubt it matters
con_cell = gen_cell_class(hps.con_dim,
input_weight_scale=hps.cell_weight_scale,
rec_weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_con_reg'])
with tf.variable_scope("con", reuse=False):
self.con_ics = tf.tile(
tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]), \
name="c0"),
tf.stack([batch_size, 1]))
self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape
con_states[-1] = self.con_ics
gen_cell = gen_cell_class(hps.gen_dim,
input_weight_scale=hps.gen_cell_input_weight_scale,
rec_weight_scale=hps.gen_cell_rec_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_gen_reg'])
with tf.variable_scope("gen", reuse=False):
if ic_dim == 0:
self.gen_ics = tf.tile(
tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"),
tf.stack([batch_size, 1]))
else:
self.gen_ics = linear(self.g0s_val, gen_cell.state_size,
identity_if_possible=True,
name="g0_2_gen_ic")
self.gen_states = gen_states = [None] * num_steps
self.gen_outs = gen_outs = [None] * num_steps
gen_states[-1] = self.gen_ics
gen_outs[-1] = gen_cell.output_from_state(gen_states[-1])
self.factors = factors = [None] * num_steps
factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
self.rates = rates = [None] * num_steps
# rates[-1] is collected to potentially feed back to controller
with tf.variable_scope("glm", reuse=False):
if hps.output_dist == 'poisson':
log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
log_rates_t0.set_shape([None, None])
rates[-1] = tf.exp(log_rates_t0) # rate
rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[-1] = means_t_bxd
else:
assert False, "NIY"
# We support mulitple output distributions, for example Poisson, and also
# Gaussian. In these two cases respectively, there are one and two
# parameters (rates vs. mean and variance). So the output_dist_params
# tensor will variable sizes via tf.concat and tf.split, along the 1st
# dimension. So in the case of gaussian, for example, it'll be
# batch x (D+D), where each D dims is the mean, and then variances,
# respectively. For a distribution with 3 parameters, it would be
# batch x (D+D+D).
self.output_dist_params = dist_params = [None] * num_steps
self.log_p_xgz_b = log_p_xgz_b = 0.0 # log P(x|z)
for t in range(num_steps):
# Controller
if co_dim > 0:
# Build inputs for controller
tlag = t - hps.controller_input_lag
if tlag < 0:
con_in_f_t = tf.zeros_like(ci_enc_fwd[0])
else:
con_in_f_t = ci_enc_fwd[tlag]
if hps.do_causal_controller:
# If controller is causal (wrt to data generation process), then it
# cannot see future data. Thus, excluding ci_enc_rev[t] is obvious.
# Less obvious is the need to exclude factors[t-1]. This arises
# because information flows from g0 through factors to the controller
# input. The g0 encoding is backwards, so we must necessarily exclude
# the factors in order to keep the controller input purely from a
# forward encoding (however unlikely it is that
# g0->factors->controller channel might actually be used in this way).
con_in_list_t = [con_in_f_t]
else:
tlag_rev = t + hps.controller_input_lag
if tlag_rev >= num_steps:
# better than zeros
con_in_r_t = tf.zeros_like(ci_enc_rev[0])
else:
con_in_r_t = ci_enc_rev[tlag_rev]
con_in_list_t = [con_in_f_t, con_in_r_t]
if hps.do_feed_factors_to_controller:
if hps.feedback_factors_or_rates == "factors":
con_in_list_t.append(factors[t-1])
elif hps.feedback_factors_or_rates == "rates":
con_in_list_t.append(rates[t-1])
else:
assert False, "NIY"
con_in_t = tf.concat(axis=1, values=con_in_list_t)
con_in_t = tf.nn.dropout(con_in_t, keep_prob)
with tf.variable_scope("con", reuse=True if t > 0 else None):
con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1])
posterior_zs_co[t] = \
DiagonalGaussianFromInput(con_outs[t], co_dim,
name="con_to_post_co")
if kind == "train":
u_t[t] = posterior_zs_co[t].sample
elif kind == "posterior_sample_and_average":
u_t[t] = posterior_zs_co[t].sample
else:
u_t[t] = prior_zs_ar_con.samples_t[t]
# Inputs to the generator (controller output + external input)
if ext_input_dim > 0 and hps.inject_ext_input_to_gen:
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
if co_dim > 0:
gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi])
else:
gen_inputs[t] = ext_input_t_bxi
else:
gen_inputs[t] = u_t[t]
# Generator
data_t_bxd = dataset_ph[:,t,:]
with tf.variable_scope("gen", reuse=True if t > 0 else None):
gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1])
gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob)
with tf.variable_scope("gen", reuse=True): # ic defined it above
factors[t] = linear(gen_outs[t], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
with tf.variable_scope("glm", reuse=True if t > 0 else None):
if hps.output_dist == 'poisson':
log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
log_rates_t.set_shape([None, None])
rates[t] = dist_params[t] = tf.exp(log_rates_t) # rates feed back
rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[t] = means_t_bxd # rates feed back to controller
dist_params[t] = tf.concat(
axis=1, values=[means_t_bxd, tf.exp(logvars_t_bxd)])
loglikelihood_t = \
diag_gaussian_log_likelihood(data_t_bxd,
means_t_bxd, logvars_t_bxd)
else:
assert False, "NIY"
log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1])
# Correlation of inferred inputs cost.
self.corr_cost = tf.constant(0.0)
if hps.co_mean_corr_scale > 0.0:
all_sum_corr = []
for i in range(hps.co_dim):
for j in range(i+1, hps.co_dim):
sum_corr_ij = tf.constant(0.0)
for t in range(num_steps):
u_mean_t = posterior_zs_co[t].mean
sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j]
all_sum_corr.append(0.5 * tf.square(sum_corr_ij))
self.corr_cost = tf.reduce_mean(all_sum_corr) # div by batch and by n*(n-1)/2 pairs
# Variational Lower Bound on posterior, p(z|x), plus reconstruction cost.
# KL and reconstruction costs are normalized only by batch size, not by
# dimension, or by time steps.
kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32)
kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32)
self.kl_cost = tf.constant(0.0) # VAE KL cost
self.recon_cost = tf.constant(0.0) # VAE reconstruction cost
self.nll_bound_vae = tf.constant(0.0)
self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost.
if kind in ["train", "posterior_sample_and_average"]:
kl_cost_g0_b = 0.0
kl_cost_co_b = 0.0
if ic_dim > 0:
g0_priors = [self.prior_zs_g0]
g0_posts = [self.posterior_zs_g0]
kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b
kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b
if co_dim > 0:
kl_cost_co_b = \
KLCost_GaussianGaussianProcessSampled(
posterior_zs_co, prior_zs_ar_con).kl_cost_b
kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'reconstruction cost' is negative log likelihood
self.recon_cost = - tf.reduce_mean(log_p_xgz_b)
self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b)
lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b
# VAE error averages outside the log
self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b)
# IWAE error averages inside the log
k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32)
iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b)
self.nll_bound_iwae = -iwae_lb_on_ll
# L2 regularization on the generator, normalized by number of parameters.
self.l2_cost = tf.constant(0.0)
if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0:
l2_costs = []
l2_numels = []
l2_reg_var_lists = [tf.get_collection('l2_gen_reg'),
tf.get_collection('l2_con_reg')]
l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale]
for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales):
for v in l2_reg_vars:
numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v)))
numel_f = tf.cast(numel, tf.float32)
l2_numels.append(numel_f)
v_l2 = tf.reduce_sum(v*v)
l2_costs.append(0.5 * l2_scale * v_l2)
self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels)
# Compute the cost for training, part of the graph regardless.
# The KL cost can be problematic at the beginning of optimization,
# so we allow an exponential increase in weighting the KL from 0
# to 1.
self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0)
self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0)
kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32)
l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32)
kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32)
l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32)
self.kl_weight = kl_weight = \
tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0)
self.l2_weight = l2_weight = \
tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0)
self.timed_kl_cost = kl_weight * self.kl_cost
self.timed_l2_cost = l2_weight * self.l2_cost
self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost
self.cost = self.recon_cost + self.timed_kl_cost + \
self.timed_l2_cost + self.weight_corr_cost
if kind != "train":
# save every so often
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep_lve)
return
# OPTIMIZATION
if not self.hps.do_train_io_only:
self.train_vars = tvars = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=tf.get_variable_scope().name)
else:
self.train_vars = tvars = \
tf.get_collection('IO_transformations',
scope=tf.get_variable_scope().name)
print("done.")
print("Model Variables (to be optimized): ")
total_params = 0
for i in range(len(tvars)):
shape = tvars[i].get_shape().as_list()
print(" ", i, tvars[i].name, shape)
total_params += np.prod(shape)
print("Total model parameters: ", total_params)
grads = tf.gradients(self.cost, tvars)
grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm)
opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999,
epsilon=1e-01)
self.grads = grads
self.grad_global_norm = grad_global_norm
self.train_op = opt.apply_gradients(
zip(grads, tvars), global_step=self.train_step)
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# SUMMARIES, used only during training.
# example summary
self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3],
name='image_tensor')
self.example_summ = tf.summary.image("LFADS example", self.example_image,
collections=["example_summaries"])
# general training summaries
self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight)
self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight)
self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost)
self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm",
self.grad_global_norm)
if hps.co_dim > 0:
self.atau_summ = [None] * hps.co_dim
self.pvar_summ = [None] * hps.co_dim
for c in range(hps.co_dim):
self.atau_summ[c] = \
tf.summary.scalar("AR Autocorrelation taus " + str(c),
tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c]))
self.pvar_summ[c] = \
tf.summary.scalar("AR Variances " + str(c),
tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c]))
# cost summaries, separated into different collections for
# training vs validation. We make placeholders for these, because
# even though the graph computes these costs on a per-batch basis,
# we want to report the more reliable metric of per-epoch cost.
kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph')
self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph,
collections=["train_summaries"])
self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph,
collections=["valid_summaries"])
l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph')
self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph,
collections=["train_summaries"])
recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph')
self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)",
recon_cost_ph,
collections=["train_summaries"])
self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)",
recon_cost_ph,
collections=["valid_summaries"])
total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph')
self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph,
collections=["train_summaries"])
self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph,
collections=["valid_summaries"])
self.kl_cost_ph = kl_cost_ph
self.l2_cost_ph = l2_cost_ph
self.recon_cost_ph = recon_cost_ph
self.total_cost_ph = total_cost_ph
# Merged summaries, for easy coding later.
self.merged_examples = tf.summary.merge_all(key="example_summaries")
self.merged_generic = tf.summary.merge_all() # default key is 'summaries'
self.merged_train = tf.summary.merge_all(key="train_summaries")
self.merged_valid = tf.summary.merge_all(key="valid_summaries")
session = tf.get_default_session()
self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
self.writer = tf.summary.FileWriter(self.logfile, session.graph)
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
keep_prob=None):
"""Build the feed dictionary, handles cases where there is no value defined.
Args:
train_name: The key into the datasets, to set the tf.case statement for
the proper readin / readout matrices.
data_bxtxd: The data tensor
ext_input_bxtxi (optional): The external input tensor
keep_prob: The drop out keep probability.
Returns:
The feed dictionary with TF tensors as keys and data as values, for use
with tf.Session.run()
"""
feed_dict = {}
B, T, _ = data_bxtxd.shape
feed_dict[self.dataName] = train_name
feed_dict[self.dataset_ph] = data_bxtxd
if self.ext_input is not None and ext_input_bxtxi is not None:
feed_dict[self.ext_input] = ext_input_bxtxi
if keep_prob is None:
feed_dict[self.keep_prob] = self.hps.keep_prob
else:
feed_dict[self.keep_prob] = keep_prob
return feed_dict
@staticmethod
def get_batch(data_extxd, ext_input_extxi=None, batch_size=None,
example_idxs=None):
"""Get a batch of data, either randomly chosen, or specified directly.
Args:
data_extxd: The data to model, numpy tensors with shape:
# examples x # time steps x # dimensions
ext_input_extxi (optional): The external inputs, numpy tensor with shape:
# examples x # time steps x # external input dimensions
batch_size: The size of the batch to return
example_idxs (optional): The example indices used to select examples.
Returns:
A tuple with two parts:
1. Batched data numpy tensor with shape:
batch_size x # time steps x # dimensions
2. Batched external input numpy tensor with shape:
batch_size x # time steps x # external input dims
"""
assert batch_size is not None or example_idxs is not None, "Problems"
E, T, D = data_extxd.shape
if example_idxs is None:
example_idxs = np.random.choice(E, batch_size)
ext_input_bxtxi = None
if ext_input_extxi is not None:
ext_input_bxtxi = ext_input_extxi[example_idxs,:,:]
return data_extxd[example_idxs,:,:], ext_input_bxtxi
@staticmethod
def example_idxs_mod_batch_size(nexamples, batch_size):
"""Given a number of examples, E, and a batch_size, B, generate indices
[0, 1, 2, ... B-1;
[B, B+1, ... 2*B-1;
...
]
returning those indices as a 2-dim tensor shaped like E/B x B. Note that
shape is only correct if E % B == 0. If not, then an extra row is generated
so that the remainder of examples is included. The extra examples are
explicitly to to the zero index (see randomize_example_idxs_mod_batch_size)
for randomized behavior.
Args:
nexamples: The number of examples to batch up.
batch_size: The size of the batch.
Returns:
2-dim tensor as described above.
"""
bmrem = batch_size - (nexamples % batch_size)
bmrem_examples = []
if bmrem < batch_size:
#bmrem_examples = np.zeros(bmrem, dtype=np.int32)
ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32)
bmrem_examples = np.sort(ridxs)
example_idxs = range(nexamples) + list(bmrem_examples)
example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size])
return example_idxs_e_x_edivb, bmrem
@staticmethod
def randomize_example_idxs_mod_batch_size(nexamples, batch_size):
"""Indices 1:nexamples, randomized, in 2D form of
shape = (nexamples / batch_size) x batch_size. The remainder
is managed by drawing randomly from 1:nexamples.
Args:
nexamples: number of examples to randomize
batch_size: number of elements in batch
Returns:
The randomized, properly shaped indicies.
"""
assert nexamples > batch_size, "Problems"
bmrem = batch_size - nexamples % batch_size
bmrem_examples = []
if bmrem < batch_size:
bmrem_examples = np.random.choice(range(nexamples),
size=bmrem, replace=False)
example_idxs = range(nexamples) + list(bmrem_examples)
mixed_example_idxs = np.random.permutation(example_idxs)
example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size])
return example_idxs_e_x_edivb, bmrem
def shuffle_spikes_in_time(self, data_bxtxd):
"""Shuffle the spikes in the temporal dimension. This is useful to
help the LFADS system avoid overfitting to individual spikes or fast
oscillations found in the data that are irrelevant to behavior. A
pure 'tabula rasa' approach would avoid this, but LFADS is sensitive
enough to pick up dynamics that you may not want.
Args:
data_bxtxd: numpy array of spike count data to be shuffled.
Returns:
S_bxtxd, a numpy array with the same dimensions and contents as
data_bxtxd, but shuffled appropriately.
"""
B, T, N = data_bxtxd.shape
w = self.hps.temporal_spike_jitter_width
if w == 0:
return data_bxtxd
max_counts = np.max(data_bxtxd)
S_bxtxd = np.zeros([B,T,N])
# Intuitively, shuffle spike occurances, 0 or 1, but since we have counts,
# Do it over and over again up to the max count.
for mc in range(1,max_counts+1):
idxs = np.nonzero(data_bxtxd >= mc)
data_ones = np.zeros_like(data_bxtxd)
data_ones[data_bxtxd >= mc] = 1
nfound = len(idxs[0])
shuffles_incrs_in_time = np.random.randint(-w, w, size=nfound)
shuffle_tidxs = idxs[1].copy()
shuffle_tidxs += shuffles_incrs_in_time
# Reflect on the boundaries to not lose mass.
shuffle_tidxs[shuffle_tidxs < 0] = -shuffle_tidxs[shuffle_tidxs < 0]
shuffle_tidxs[shuffle_tidxs > T-1] = \
(T-1)-(shuffle_tidxs[shuffle_tidxs > T-1] -(T-1))
for iii in zip(idxs[0], shuffle_tidxs, idxs[2]):
S_bxtxd[iii] += 1
return S_bxtxd
def shuffle_and_flatten_datasets(self, datasets, kind='train'):
"""Since LFADS supports multiple datasets in the same dynamical model,
we have to be careful to use all the data in a single training epoch. But
since the datasets my have different data dimensionality, we cannot batch
examples from data dictionaries together. Instead, we generate random
batches within each data dictionary, and then randomize these batches
while holding onto the dataname, so that when it's time to feed
the graph, the correct in/out matrices can be selected, per batch.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
kind: 'train' or 'valid'
Returns:
A flat list, in which each element is a pair ('name', indices).
"""
batch_size = self.hps.batch_size
ndatasets = len(datasets)
random_example_idxs = {}
epoch_idxs = {}
all_name_example_idx_pairs = []
kind_data = kind + '_data'
for name, data_dict in datasets.items():
nexamples, ntime, data_dim = data_dict[kind_data].shape
epoch_idxs[name] = 0
random_example_idxs, _ = \
self.randomize_example_idxs_mod_batch_size(nexamples, batch_size)
epoch_size = random_example_idxs.shape[0]
names = [name] * epoch_size
all_name_example_idx_pairs += zip(names, random_example_idxs)
np.random.shuffle(all_name_example_idx_pairs) # shuffle in place
return all_name_example_idx_pairs
def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True):
"""Train the model through the entire dataset once.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
batch_size (optional): The batch_size to use
do_save_ckpt (optional): Should the routine save a checkpoint on this
training epoch?
Returns:
A tuple with 6 float values:
(total cost of the epoch, epoch reconstruction cost,
epoch kl cost, KL weight used this training epoch,
total l2 cost on generator, and the corresponding weight).
"""
ops_to_eval = [self.cost, self.recon_cost,
self.kl_cost, self.kl_weight,
self.l2_cost, self.l2_weight,
self.train_op]
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind="train")
total_cost = total_recon_cost = total_kl_cost = 0.0
# normalizing by batch done in distributions.py
epoch_size = len(collected_op_values)
for op_values in collected_op_values:
total_cost += op_values[0]
total_recon_cost += op_values[1]
total_kl_cost += op_values[2]
kl_weight = collected_op_values[-1][3]
l2_cost = collected_op_values[-1][4]
l2_weight = collected_op_values[-1][5]
epoch_total_cost = total_cost / epoch_size
epoch_recon_cost = total_recon_cost / epoch_size
epoch_kl_cost = total_kl_cost / epoch_size
if do_save_ckpt:
session = tf.get_default_session()
checkpoint_path = os.path.join(self.hps.lfads_save_dir,
self.hps.checkpoint_name + '.ckpt')
self.seso_saver.save(session, checkpoint_path,
global_step=self.train_step)
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost, \
kl_weight, l2_cost, l2_weight
def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None,
do_collect=True, keep_prob=None):
"""Run the model through the entire dataset once.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
ops_to_eval: A list of tensorflow operations that will be evaluated in
the tf.session.run() call.
batch_size (optional): The batch_size to use
do_collect (optional): Should the routine collect all session.run
output as a list, and return it?
keep_prob (optional): The dropout keep probability.
Returns:
A list of lists, the internal list is the return for the ops for each
session.run() call. The outer list collects over the epoch.
"""
hps = self.hps
all_name_example_idx_pairs = \
self.shuffle_and_flatten_datasets(datasets, kind)
kind_data = kind + '_data'
kind_ext_input = kind + '_ext_input'
total_cost = total_recon_cost = total_kl_cost = 0.0
session = tf.get_default_session()
epoch_size = len(all_name_example_idx_pairs)
evaled_ops_list = []
for name, example_idxs in all_name_example_idx_pairs:
data_dict = datasets[name]
data_extxd = data_dict[kind_data]
if hps.output_dist == 'poisson' and hps.temporal_spike_jitter_width > 0:
data_extxd = self.shuffle_spikes_in_time(data_extxd)
ext_input_extxi = data_dict[kind_ext_input]
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi,
example_idxs=example_idxs)
feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi,
keep_prob=keep_prob)
evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict)
if do_collect:
evaled_ops_list.append(evaled_ops_np)
return evaled_ops_list
def summarize_all(self, datasets, summary_values):
"""Plot and summarize stuff in tensorboard.
Note that everything done in the current function is otherwise done on
a single, randomly selected dataset (except for summary_values, which are
passed in.)
Args:
datasets, the dictionary of datasets used in the study.
summary_values: These summary values are created from the training loop,
and so summarize the entire set of datasets.
"""
hps = self.hps
tr_kl_cost = summary_values['tr_kl_cost']
tr_recon_cost = summary_values['tr_recon_cost']
tr_total_cost = summary_values['tr_total_cost']
kl_weight = summary_values['kl_weight']
l2_weight = summary_values['l2_weight']
l2_cost = summary_values['l2_cost']
has_any_valid_set = summary_values['has_any_valid_set']
i = summary_values['nepochs']
session = tf.get_default_session()
train_summ, train_step = session.run([self.merged_train,
self.train_step],
feed_dict={self.l2_cost_ph:l2_cost,
self.kl_cost_ph:tr_kl_cost,
self.recon_cost_ph:tr_recon_cost,
self.total_cost_ph:tr_total_cost})
self.writer.add_summary(train_summ, train_step)
if has_any_valid_set:
ev_kl_cost = summary_values['ev_kl_cost']
ev_recon_cost = summary_values['ev_recon_cost']
ev_total_cost = summary_values['ev_total_cost']
eval_summ = session.run(self.merged_valid,
feed_dict={self.kl_cost_ph:ev_kl_cost,
self.recon_cost_ph:ev_recon_cost,
self.total_cost_ph:ev_total_cost})
self.writer.add_summary(eval_summ, train_step)
print("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f\
recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f,\
kl weight: %.2f, l2 weight: %.2f" % \
(i, train_step, tr_total_cost, ev_total_cost,
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
l2_cost, kl_weight, l2_weight))
csv_outstr = "epoch,%d, step,%d, total,%.2f,%.2f, \
recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, \
klweight,%.2f, l2weight,%.2f\n"% \
(i, train_step, tr_total_cost, ev_total_cost,
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
l2_cost, kl_weight, l2_weight)
else:
print("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f,\
l2: %.5f, kl weight: %.2f, l2 weight: %.2f" % \
(i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost,
l2_cost, kl_weight, l2_weight))
csv_outstr = "epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, \
l2,%.5f, klweight,%.2f, l2weight,%.2f\n"% \
(i, train_step, tr_total_cost, tr_recon_cost,
tr_kl_cost, l2_cost, kl_weight, l2_weight)
if self.hps.csv_log:
csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv')
with open(csv_file, "a") as myfile:
myfile.write(csv_outstr)
def plot_single_example(self, datasets):
"""Plot an image relating to a randomly chosen, specific example. We use
posterior sample and average by taking one example, and filling a whole
batch with that example, sample from the posterior, and then average the
quantities.
"""
hps = self.hps
all_data_names = datasets.keys()
data_name = np.random.permutation(all_data_names)[0]
data_dict = datasets[data_name]
has_valid_set = True if data_dict['valid_data'] is not None else False
cf = 1.0 # plotting concern
# posterior sample and average here
E, _, _ = data_dict['train_data'].shape
eidx = np.random.choice(E)
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
train_data_bxtxd, train_ext_input_bxtxi = \
self.get_batch(data_dict['train_data'], data_dict['train_ext_input'],
example_idxs=example_idxs)
truth_train_data_bxtxd = None
if 'train_truth' in data_dict and data_dict['train_truth'] is not None:
truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'],
example_idxs=example_idxs)
cf = data_dict['conversion_factor']
# plotter does averaging
train_model_values = self.eval_model_runs_batch(data_name,
train_data_bxtxd,
train_ext_input_bxtxi,
do_average_batch=False)
train_step = train_model_values['train_steps']
feed_dict = self.build_feed_dict(data_name, train_data_bxtxd,
train_ext_input_bxtxi, keep_prob=1.0)
session = tf.get_default_session()
generic_summ = session.run(self.merged_generic, feed_dict=feed_dict)
self.writer.add_summary(generic_summ, train_step)
valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None
truth_valid_data_bxtxd = None
if has_valid_set:
E, _, _ = data_dict['valid_data'].shape
eidx = np.random.choice(E)
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
valid_data_bxtxd, valid_ext_input_bxtxi = \
self.get_batch(data_dict['valid_data'],
data_dict['valid_ext_input'],
example_idxs=example_idxs)
if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None:
truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'],
example_idxs=example_idxs)
else:
truth_valid_data_bxtxd = None
# plotter does averaging
valid_model_values = self.eval_model_runs_batch(data_name,
valid_data_bxtxd,
valid_ext_input_bxtxi,
do_average_batch=False)
example_image = plot_lfads(train_bxtxd=train_data_bxtxd,
train_model_vals=train_model_values,
train_ext_input_bxtxi=train_ext_input_bxtxi,
train_truth_bxtxd=truth_train_data_bxtxd,
valid_bxtxd=valid_data_bxtxd,
valid_model_vals=valid_model_values,
valid_ext_input_bxtxi=valid_ext_input_bxtxi,
valid_truth_bxtxd=truth_valid_data_bxtxd,
bidx=None, cf=cf, output_dist=hps.output_dist)
example_image = np.expand_dims(example_image, axis=0)
example_summ = session.run(self.merged_examples,
feed_dict={self.example_image : example_image})
self.writer.add_summary(example_summ)
def train_model(self, datasets):
"""Train the model, print per-epoch information, and save checkpoints.
Loop over training epochs. The function that actually does the
training is train_epoch. This function iterates over the training
data, one epoch at a time. The learning rate schedule is such
that it will stay the same until the cost goes up in comparison to
the last few values, then it will drop.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
"""
hps = self.hps
has_any_valid_set = False
for data_dict in datasets.values():
if data_dict['valid_data'] is not None:
has_any_valid_set = True
break
session = tf.get_default_session()
lr = session.run(self.learning_rate)
lr_stop = hps.learning_rate_stop
i = -1
train_costs = []
valid_costs = []
ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0
lowest_ev_cost = np.Inf
while True:
i += 1
do_save_ckpt = True if i % 10 ==0 else False
tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \
self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)
# Evaluate the validation cost, and potentially save. Note that this
# routine will not save a validation checkpoint until the kl weight and
# l2 weights are equal to 1.0.
if has_any_valid_set:
ev_total_cost, ev_recon_cost, ev_kl_cost = \
self.eval_cost_epoch(datasets, kind='valid')
valid_costs.append(ev_total_cost)
# > 1 may give more consistent results, but not the actual lowest vae.
# == 1 gives the lowest vae seen so far.
n_lve = 1
run_avg_lve = np.mean(valid_costs[-n_lve:])
# conditions for saving checkpoints:
# KL weight must have finished stepping (>=1.0), AND
# L2 weight must have finished stepping OR L2 is not being used, AND
# the current run has a lower LVE than previous runs AND
# len(valid_costs > n_lve) (not sure what that does)
if kl_weight >= 1.0 and \
(l2_weight >= 1.0 or \
(self.hps.l2_gen_scale == 0.0 and self.hps.l2_con_scale == 0.0)) \
and (len(valid_costs) > n_lve and run_avg_lve < lowest_ev_cost):
lowest_ev_cost = run_avg_lve
checkpoint_path = os.path.join(self.hps.lfads_save_dir,
self.hps.checkpoint_name + '_lve.ckpt')
self.lve_saver.save(session, checkpoint_path,
global_step=self.train_step,
latest_filename='checkpoint_lve')
# Plot and summarize.
values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set,
'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost,
'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost,
'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost,
'l2_weight':l2_weight, 'kl_weight':kl_weight,
'l2_cost':l2_cost}
self.summarize_all(datasets, values)
self.plot_single_example(datasets)
# Manage learning rate.
train_res = tr_total_cost
n_lr = hps.learning_rate_n_to_compare
if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]):
_ = session.run(self.learning_rate_decay_op)
lr = session.run(self.learning_rate)
print(" Decreasing learning rate to %f." % lr)
# Force the system to run n_lr times while at this lr.
train_costs.append(np.inf)
else:
train_costs.append(train_res)
if lr < lr_stop:
print("Stopping optimization based on learning rate criteria.")
break
def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None,
batch_size=None):
"""Evaluate the cost of the epoch.
Args:
data_dict: The dictionary of data (training and validation) used for
training and evaluation of the model, respectively.
Returns:
a 3 tuple of costs:
(epoch total cost, epoch reconstruction cost, epoch KL cost)
"""
ops_to_eval = [self.cost, self.recon_cost, self.kl_cost]
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind=kind,
keep_prob=1.0)
total_cost = total_recon_cost = total_kl_cost = 0.0
# normalizing by batch done in distributions.py
epoch_size = len(collected_op_values)
for op_values in collected_op_values:
total_cost += op_values[0]
total_recon_cost += op_values[1]
total_kl_cost += op_values[2]
epoch_total_cost = total_cost / epoch_size
epoch_recon_cost = total_recon_cost / epoch_size
epoch_kl_cost = total_kl_cost / epoch_size
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost
def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
do_eval_cost=False, do_average_batch=False):
"""Returns all the goodies for the entire model, per batch.
Args:
data_name: The name of the data dict, to select which in/out matrices
to use.
data_bxtxd: Numpy array training data with shape:
batch_size x # time steps x # dimensions
ext_input_bxtxi: Numpy array training external input with shape:
batch_size x # time steps x # external input dims
do_eval_cost (optional): If true, the IWAE (Importance Weighted
Autoencoder) log likeihood bound, instead of the VAE version.
do_average_batch (optional): average over the batch, useful for getting
good IWAE costs, and model outputs for a single data point.
Returns:
A dictionary with the outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx
posterior mean, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the rates.
"""
session = tf.get_default_session()
feed_dict = self.build_feed_dict(data_name, data_bxtxd,
ext_input_bxtxi, keep_prob=1.0)
# Non-temporal signals will be batch x dim.
# Temporal signals are list length T with elements batch x dim.
tf_vals = [self.gen_ics, self.gen_states, self.factors,
self.output_dist_params]
tf_vals.append(self.cost)
tf_vals.append(self.nll_bound_vae)
tf_vals.append(self.nll_bound_iwae)
tf_vals.append(self.train_step) # not train_op!
if self.hps.ic_dim > 0:
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
if self.hps.co_dim > 0:
tf_vals.append(self.controller_outputs)
tf_vals_flat, fidxs = flatten(tf_vals)
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
ff = 0
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
if self.hps.ic_dim > 0:
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if self.hps.co_dim > 0:
controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
# [0] are to take out the non-temporal items from lists
gen_ics = gen_ics[0]
costs = costs[0]
nll_bound_vaes = nll_bound_vaes[0]
nll_bound_iwaes = nll_bound_iwaes[0]
train_steps = train_steps[0]
# Convert to full tensors, not lists of tensors in time dim.
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
factors = list_t_bxn_to_tensor_bxtxn(factors)
out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
if self.hps.ic_dim > 0:
prior_g0_mean = prior_g0_mean[0]
prior_g0_logvar = prior_g0_logvar[0]
post_g0_mean = post_g0_mean[0]
post_g0_logvar = post_g0_logvar[0]
if self.hps.co_dim > 0:
controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)
if do_average_batch:
gen_ics = np.mean(gen_ics, axis=0)
gen_states = np.mean(gen_states, axis=0)
factors = np.mean(factors, axis=0)
out_dist_params = np.mean(out_dist_params, axis=0)
if self.hps.ic_dim > 0:
prior_g0_mean = np.mean(prior_g0_mean, axis=0)
prior_g0_logvar = np.mean(prior_g0_logvar, axis=0)
post_g0_mean = np.mean(post_g0_mean, axis=0)
post_g0_logvar = np.mean(post_g0_logvar, axis=0)
if self.hps.co_dim > 0:
controller_outputs = np.mean(controller_outputs, axis=0)
model_vals = {}
model_vals['gen_ics'] = gen_ics
model_vals['gen_states'] = gen_states
model_vals['factors'] = factors
model_vals['output_dist_params'] = out_dist_params
model_vals['costs'] = costs
model_vals['nll_bound_vaes'] = nll_bound_vaes
model_vals['nll_bound_iwaes'] = nll_bound_iwaes
model_vals['train_steps'] = train_steps
if self.hps.ic_dim > 0:
model_vals['prior_g0_mean'] = prior_g0_mean
model_vals['prior_g0_logvar'] = prior_g0_logvar
model_vals['post_g0_mean'] = post_g0_mean
model_vals['post_g0_logvar'] = post_g0_logvar
if self.hps.co_dim > 0:
model_vals['controller_outputs'] = controller_outputs
return model_vals
def eval_model_runs_avg_epoch(self, data_name, data_extxd,
ext_input_extxi=None):
"""Returns all the expected value for goodies for the entire model.
The expected value is taken over hidden (z) variables, namely the initial
conditions and the control inputs. The expected value is approximate, and
accomplished via sampling (batch_size) samples for every examples.
Args:
data_name: The name of the data dict, to select which in/out matrices
to use.
data_extxd: Numpy array training data with shape:
# examples x # time steps x # dimensions
ext_input_extxi (optional): Numpy array training external input with
shape: # examples x # time steps x # external input dims
Returns:
A dictionary with the averaged outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx
posterior mean, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the output
distribution parameters, e.g. (rates or mean and variances).
"""
hps = self.hps
batch_size = hps.batch_size
E, T, D = data_extxd.shape
E_to_process = hps.ps_nexamples_to_process
if E_to_process > E:
print("Setting number of posterior samples to process to : ", E)
E_to_process = E
if hps.ic_dim > 0:
prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
if hps.co_dim > 0:
controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
gen_ics = np.zeros([E_to_process, hps.gen_dim])
gen_states = np.zeros([E_to_process, T, hps.gen_dim])
factors = np.zeros([E_to_process, T, hps.factors_dim])
if hps.output_dist == 'poisson':
out_dist_params = np.zeros([E_to_process, T, D])
elif hps.output_dist == 'gaussian':
out_dist_params = np.zeros([E_to_process, T, D+D])
else:
assert False, "NIY"
costs = np.zeros(E_to_process)
nll_bound_vaes = np.zeros(E_to_process)
nll_bound_iwaes = np.zeros(E_to_process)
train_steps = np.zeros(E_to_process)
for es_idx in range(E_to_process):
print("Running %d of %d." % (es_idx+1, E_to_process))
example_idxs = es_idx * np.ones(batch_size, dtype=np.int32)
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
ext_input_extxi,
batch_size=batch_size,
example_idxs=example_idxs)
model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
ext_input_bxtxi,
do_eval_cost=True,
do_average_batch=True)
if self.hps.ic_dim > 0:
prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
post_g0_mean[es_idx,:] = model_values['post_g0_mean']
post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
gen_ics[es_idx,:] = model_values['gen_ics']
if self.hps.co_dim > 0:
controller_outputs[es_idx,:,:] = model_values['controller_outputs']
gen_states[es_idx,:,:] = model_values['gen_states']
factors[es_idx,:,:] = model_values['factors']
out_dist_params[es_idx,:,:] = model_values['output_dist_params']
costs[es_idx] = model_values['costs']
nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
train_steps[es_idx] = model_values['train_steps']
print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \
% (nll_bound_vaes[es_idx], nll_bound_iwaes[es_idx]))
model_runs = {}
if self.hps.ic_dim > 0:
model_runs['prior_g0_mean'] = prior_g0_mean
model_runs['prior_g0_logvar'] = prior_g0_logvar
model_runs['post_g0_mean'] = post_g0_mean
model_runs['post_g0_logvar'] = post_g0_logvar
model_runs['gen_ics'] = gen_ics
if self.hps.co_dim > 0:
model_runs['controller_outputs'] = controller_outputs
model_runs['gen_states'] = gen_states
model_runs['factors'] = factors
model_runs['output_dist_params'] = out_dist_params
model_runs['costs'] = costs
model_runs['nll_bound_vaes'] = nll_bound_vaes
model_runs['nll_bound_iwaes'] = nll_bound_iwaes
model_runs['train_steps'] = train_steps
return model_runs
def write_model_runs(self, datasets, output_fname=None):
"""Run the model on the data in data_dict, and save the computed values.
LFADS generates a number of outputs for each examples, and these are all
saved. They are:
The mean and variance of the prior of g0.
The mean and variance of approximate posterior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
datasets: a dictionary of named data_dictionaries, see top of lfads.py
output_fname: a file name stem for the output files.
"""
hps = self.hps
kind = hps.kind
for data_name, data_dict in datasets.items():
data_tuple = [('train', data_dict['train_data'],
data_dict['train_ext_input']),
('valid', data_dict['valid_data'],
data_dict['valid_ext_input'])]
for data_kind, data_extxd, ext_input_extxi in data_tuple:
if not output_fname:
fname = "model_runs_" + data_name + '_' + data_kind + '_' + kind
else:
fname = output_fname + data_name + '_' + data_kind + '_' + kind
print("Writing data for %s data and kind %s." % (data_name, data_kind))
model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
ext_input_extxi)
full_fname = os.path.join(hps.lfads_save_dir, fname)
write_data(full_fname, model_runs, compression='gzip')
print("Done.")
def write_model_samples(self, dataset_name, output_fname=None):
"""Use the prior distribution to generate batch_size number of samples
from the model.
LFADS generates a number of outputs for each sample, and these are all
saved. They are:
The mean and variance of the prior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
dataset_name: The name of the dataset to grab the factors -> rates
alignment matrices from.
output_fname: The name of the file in which to save the generated
samples.
"""
hps = self.hps
batch_size = hps.batch_size
print("Generating %d samples" % (batch_size))
tf_vals = [self.factors, self.gen_states, self.gen_ics,
self.cost, self.output_dist_params]
if hps.ic_dim > 0:
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar]
if hps.co_dim > 0:
tf_vals += [self.prior_zs_ar_con.samples_t]
tf_vals_flat, fidxs = flatten(tf_vals)
session = tf.get_default_session()
feed_dict = {}
feed_dict[self.dataName] = dataset_name
feed_dict[self.keep_prob] = 1.0
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
ff = 0
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
output_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if hps.ic_dim > 0:
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if hps.co_dim > 0:
prior_zs_ar_con = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
# [0] are to take out the non-temporal items from lists
gen_ics = gen_ics[0]
costs = costs[0]
# Convert to full tensors, not lists of tensors in time dim.
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
factors = list_t_bxn_to_tensor_bxtxn(factors)
output_dist_params = list_t_bxn_to_tensor_bxtxn(output_dist_params)
if hps.ic_dim > 0:
prior_g0_mean = prior_g0_mean[0]
prior_g0_logvar = prior_g0_logvar[0]
if hps.co_dim > 0:
prior_zs_ar_con = list_t_bxn_to_tensor_bxtxn(prior_zs_ar_con)
model_vals = {}
model_vals['gen_ics'] = gen_ics
model_vals['gen_states'] = gen_states
model_vals['factors'] = factors
model_vals['output_dist_params'] = output_dist_params
model_vals['costs'] = costs.reshape(1)
if hps.ic_dim > 0:
model_vals['prior_g0_mean'] = prior_g0_mean
model_vals['prior_g0_logvar'] = prior_g0_logvar
if hps.co_dim > 0:
model_vals['prior_zs_ar_con'] = prior_zs_ar_con
full_fname = os.path.join(hps.lfads_save_dir, output_fname)
write_data(full_fname, model_vals, compression='gzip')
print("Done.")
@staticmethod
def eval_model_parameters(use_nested=True, include_strs=None):
"""Evaluate and return all of the TF variables in the model.
Args:
use_nested (optional): For returning values, use a nested dictoinary, based
on variable scoping, or return all variables in a flat dictionary.
include_strs (optional): A list of strings to use as a filter, to reduce the
number of variables returned. A variable name must contain at least one
string in include_strs as a sub-string in order to be returned.
Returns:
The parameters of the model. This can be in a flat
dictionary, or a nested dictionary, where the nesting is by variable
scope.
"""
all_tf_vars = tf.global_variables()
session = tf.get_default_session()
all_tf_vars_eval = session.run(all_tf_vars)
vars_dict = {}
strs = ["LFADS"]
if include_strs:
strs += include_strs
for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
if any(s in include_strs for s in var.name):
if not isinstance(var_eval, np.ndarray): # for H5PY
print(var.name, """ is not numpy array, saving as numpy array
with value: """, var_eval, type(var_eval))
e = np.array(var_eval)
print(e, type(e))
else:
e = var_eval
vars_dict[var.name] = e
if not use_nested:
return vars_dict
var_names = vars_dict.keys()
nested_vars_dict = {}
current_dict = nested_vars_dict
for v, var_name in enumerate(var_names):
var_split_name_list = var_name.split('/')
split_name_list_len = len(var_split_name_list)
current_dict = nested_vars_dict
for p, part in enumerate(var_split_name_list):
if p < split_name_list_len - 1:
if part in current_dict:
current_dict = current_dict[part]
else:
current_dict[part] = {}
current_dict = current_dict[part]
else:
current_dict[part] = vars_dict[var_name]
return nested_vars_dict
@staticmethod
def spikify_rates(rates_bxtxd):
"""Randomly spikify underlying rates according a Poisson distribution
Args:
rates_bxtxd: a numpy tensor with shape:
Returns:
A numpy array with the same shape as rates_bxtxd, but with the event
counts.
"""
B,T,N = rates_bxtxd.shape
assert all([B > 0, N > 0]), "problems"
# Because the rates are changing, there is nesting
spikes_bxtxd = np.zeros([B,T,N], dtype=np.int32)
for b in range(B):
for t in range(T):
for n in range(N):
rate = rates_bxtxd[b,t,n]
count = np.random.poisson(rate)
spikes_bxtxd[b,t,n] = count
return spikes_bxtxd
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
def _plot_item(W, name, full_name, nspaces):
plt.figure()
if W.shape == ():
print(name, ": ", W)
elif W.shape[0] == 1:
plt.stem(W.T)
plt.title(full_name)
elif W.shape[1] == 1:
plt.stem(W)
plt.title(full_name)
else:
plt.imshow(np.abs(W), interpolation='nearest', cmap='jet');
plt.colorbar()
plt.title(full_name)
def all_plot(d, full_name="", exclude="", nspaces=0):
"""Recursively plot all the LFADS model parameters in the nested
dictionary."""
for k, v in d.iteritems():
this_name = full_name+"/"+k
if isinstance(v, dict):
all_plot(v, full_name=this_name, exclude=exclude, nspaces=nspaces+4)
else:
if exclude == "" or exclude not in this_name:
_plot_item(v, name=k, full_name=full_name+"/"+k, nspaces=nspaces+4)
def plot_priors():
g0s_prior_mean_bxn = train_modelvals['prior_g0_mean']
g0s_prior_var_bxn = train_modelvals['prior_g0_var']
g0s_post_mean_bxn = train_modelvals['posterior_g0_mean']
g0s_post_var_bxn = train_modelvals['posterior_g0_var']
plt.figure(figsize=(10,4), tight_layout=True);
plt.subplot(1,2,1)
plt.hist(g0s_post_mean_bxn.flatten(), bins=20, color='b');
plt.hist(g0s_prior_mean_bxn.flatten(), bins=20, color='g');
plt.title('Histogram of Prior/Posterior Mean Values')
plt.subplot(1,2,2)
plt.hist((g0s_post_var_bxn.flatten()), bins=20, color='b');
plt.hist((g0s_prior_var_bxn.flatten()), bins=20, color='g');
plt.title('Histogram of Prior/Posterior Log Variance Values')
plt.figure(figsize=(10,10), tight_layout=True)
plt.subplot(2,2,1)
plt.imshow(g0s_prior_mean_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Prior g0 means')
plt.subplot(2,2,2)
plt.imshow(g0s_post_mean_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Posterior g0 means');
plt.subplot(2,2,3)
plt.imshow(g0s_prior_var_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Prior g0 variance Values')
plt.subplot(2,2,4)
plt.imshow(g0s_post_var_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Posterior g0 variance Values')
plt.figure(figsize=(10,5))
plt.stem(np.sort(np.log(g0s_post_mean_bxn.std(axis=0))));
plt.title('Log standard deviation of h0 means');
def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
color='r', title=None):
if bidx is None:
vals_txn = np.mean(vals_bxtxn, axis=0)
else:
vals_txn = vals_bxtxn[bidx,:,:]
T, N = vals_txn.shape
if n_to_plot > N:
n_to_plot = N
plt.plot(vals_txn[:,0:n_to_plot] + scale*np.array(range(n_to_plot)),
color=color, lw=1.0)
plt.axis('tight')
if title:
plt.title(title)
def plot_lfads_timeseries(data_bxtxn, model_vals, ext_input_bxtxi=None,
truth_bxtxn=None, bidx=None, output_dist="poisson",
conversion_factor=1.0, subplot_cidx=0,
col_title=None):
n_to_plot = 10
scale = 1.0
nrows = 7
plt.subplot(nrows,2,1+subplot_cidx)
if output_dist == 'poisson':
rates = means = conversion_factor * model_vals['output_dist_params']
plot_time_series(rates, bidx, n_to_plot=n_to_plot, scale=scale,
title=col_title + " rates (LFADS - red, Truth - black)")
elif output_dist == 'gaussian':
means_vars = model_vals['output_dist_params']
means, vars = np.split(means_vars,2, axis=2) # bxtxn
stds = np.sqrt(vars)
plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
title=col_title + " means (LFADS - red, Truth - black)")
plot_time_series(means+stds, bidx, n_to_plot=n_to_plot, scale=scale,
color='c')
plot_time_series(means-stds, bidx, n_to_plot=n_to_plot, scale=scale,
color='c')
else:
assert 'NIY'
if truth_bxtxn is not None:
plot_time_series(truth_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
scale=scale)
input_title = ""
if "controller_outputs" in model_vals.keys():
input_title += " Controller Output"
plt.subplot(nrows,2,3+subplot_cidx)
u_t = model_vals['controller_outputs'][0:-1]
plot_time_series(u_t, bidx, n_to_plot=n_to_plot, color='c', scale=1.0,
title=col_title + input_title)
if ext_input_bxtxi is not None:
input_title += " External Input"
plot_time_series(ext_input_bxtxi, n_to_plot=n_to_plot, color='b',
scale=scale, title=col_title + input_title)
plt.subplot(nrows,2,5+subplot_cidx)
plot_time_series(means, bidx,
n_to_plot=n_to_plot, scale=1.0,
title=col_title + " Spikes (LFADS - red, Spikes - black)")
plot_time_series(data_bxtxn, bidx, n_to_plot=n_to_plot, color='k', scale=1.0)
plt.subplot(nrows,2,7+subplot_cidx)
plot_time_series(model_vals['factors'], bidx, n_to_plot=n_to_plot, color='b',
scale=2.0, title=col_title + " Factors")
plt.subplot(nrows,2,9+subplot_cidx)
plot_time_series(model_vals['gen_states'], bidx, n_to_plot=n_to_plot,
color='g', scale=1.0, title=col_title + " Generator State")
if bidx is not None:
data_nxt = data_bxtxn[bidx,:,:].T
params_nxt = model_vals['output_dist_params'][bidx,:,:].T
else:
data_nxt = np.mean(data_bxtxn, axis=0).T
params_nxt = np.mean(model_vals['output_dist_params'], axis=0).T
if output_dist == 'poisson':
means_nxt = params_nxt
elif output_dist == 'gaussian': # (means+vars) x time
means_nxt = np.vsplit(params_nxt,2)[0] # get means
else:
assert "NIY"
plt.subplot(nrows,2,11+subplot_cidx)
plt.imshow(data_nxt, aspect='auto', interpolation='nearest')
plt.title(col_title + ' Data')
plt.subplot(nrows,2,13+subplot_cidx)
plt.imshow(means_nxt, aspect='auto', interpolation='nearest')
plt.title(col_title + ' Means')
def plot_lfads(train_bxtxd, train_model_vals,
train_ext_input_bxtxi=None, train_truth_bxtxd=None,
valid_bxtxd=None, valid_model_vals=None,
valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
bidx=None, cf=1.0, output_dist='poisson'):
# Plotting
f = plt.figure(figsize=(18,20), tight_layout=True)
plot_lfads_timeseries(train_bxtxd, train_model_vals,
train_ext_input_bxtxi,
truth_bxtxn=train_truth_bxtxd,
conversion_factor=cf, bidx=bidx,
output_dist=output_dist, col_title='Train')
plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
valid_ext_input_bxtxi,
truth_bxtxn=valid_truth_bxtxd,
conversion_factor=cf, bidx=bidx,
output_dist=output_dist,
subplot_cidx=1, col_title='Valid')
# Convert from figure to an numpy array width x height x 3 (last for RGB)
f.canvas.draw()
data = np.fromstring(f.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data_wxhx3 = data.reshape(f.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data_wxhx3
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from lfads import LFADS
import numpy as np
import os
import tensorflow as tf
import re
import utils
# Lots of hyperparameters, but most are pretty insensitive. The
# explanation of these hyperparameters is found below, in the flags
# session.
CHECKPOINT_PB_LOAD_NAME = "checkpoint"
CHECKPOINT_NAME = "lfads_vae"
CSV_LOG = "fitlog"
OUTPUT_FILENAME_STEM = ""
DEVICE = "gpu:0" # "cpu:0", or other gpus, e.g. "gpu:1"
MAX_CKPT_TO_KEEP = 5
MAX_CKPT_TO_KEEP_LVE = 5
PS_NEXAMPLES_TO_PROCESS = 1e8 # if larger than number of examples, process all
EXT_INPUT_DIM = 0
IC_DIM = 64
FACTORS_DIM = 50
IC_ENC_DIM = 128
GEN_DIM = 200
GEN_CELL_INPUT_WEIGHT_SCALE = 1.0
GEN_CELL_REC_WEIGHT_SCALE = 1.0
CELL_WEIGHT_SCALE = 1.0
BATCH_SIZE = 128
LEARNING_RATE_INIT = 0.01
LEARNING_RATE_DECAY_FACTOR = 0.95
LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE = 6
INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False
DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors"
# Calibrated just above the average value for the rnn synthetic data.
MAX_GRAD_NORM = 200.0
CELL_CLIP_VALUE = 5.0
KEEP_PROB = 0.95
TEMPORAL_SPIKE_JITTER_WIDTH = 0
OUTPUT_DISTRIBUTION = 'poisson' # 'poisson' or 'gaussian'
NUM_STEPS_FOR_GEN_IC = np.inf # set to num_steps if greater than num_steps
DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
LFADS_SAVE_DIR = "/tmp/lfads_chaotic_rnn_inputs_g1p5/"
CO_DIM = 1
DO_CAUSAL_CONTROLLER = False
DO_FEED_FACTORS_TO_CONTROLLER = True
CONTROLLER_INPUT_LAG = 1
PRIOR_AR_AUTOCORRELATION = 10.0
PRIOR_AR_PROCESS_VAR = 0.1
DO_TRAIN_PRIOR_AR_ATAU = True
DO_TRAIN_PRIOR_AR_NVAR = True
CI_ENC_DIM = 128
CON_DIM = 128
CO_PRIOR_VAR_SCALE = 0.1
KL_INCREASE_STEPS = 2000
L2_INCREASE_STEPS = 2000
L2_GEN_SCALE = 2000.0
L2_CON_SCALE = 0.0
# scale of regularizer on time correlation of inferred inputs
CO_MEAN_CORR_SCALE = 0.0
KL_IC_WEIGHT = 1.0
KL_CO_WEIGHT = 1.0
KL_START_STEP = 0
L2_START_STEP = 0
IC_PRIOR_VAR_MIN = 0.1
IC_PRIOR_VAR_SCALE = 0.1
IC_PRIOR_VAR_MAX = 0.1
IC_POST_VAR_MIN = 0.0001 # protection from KL blowing up
flags = tf.app.flags
flags.DEFINE_string("kind", "train",
"Type of model to build {train, \
posterior_sample_and_average, \
prior_sample, write_model_params")
flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION,
"Type of output distribution, 'poisson' or 'gaussian'")
flags.DEFINE_boolean("allow_gpu_growth", False,
"If true, only allocate amount of memory needed for \
Session. Otherwise, use full GPU memory.")
# DATA
flags.DEFINE_string("data_dir", DATA_DIR, "Data for training")
flags.DEFINE_string("data_filename_stem", DATA_FILENAME_STEM,
"Filename stem for data dictionaries.")
flags.DEFINE_string("lfads_save_dir", LFADS_SAVE_DIR, "model save dir")
flags.DEFINE_string("checkpoint_pb_load_name", CHECKPOINT_PB_LOAD_NAME,
"Name of checkpoint files, use 'checkpoint_lve' for best \
error")
flags.DEFINE_string("checkpoint_name", CHECKPOINT_NAME,
"Name of checkpoint files (.ckpt appended)")
flags.DEFINE_string("output_filename_stem", OUTPUT_FILENAME_STEM,
"Name of output file (postfix will be added)")
flags.DEFINE_string("device", DEVICE,
"Which device to use (default: \"gpu:0\", can also be \
\"cpu:0\", \"gpu:1\", etc)")
flags.DEFINE_string("csv_log", CSV_LOG,
"Name of file to keep running log of fit likelihoods, \
etc (.csv appended)")
flags.DEFINE_integer("max_ckpt_to_keep", MAX_CKPT_TO_KEEP,
"Max # of checkpoints to keep (rolling)")
flags.DEFINE_integer("ps_nexamples_to_process", PS_NEXAMPLES_TO_PROCESS,
"Number of examples to process for posterior sample and \
average (not number of samples to average over).")
flags.DEFINE_integer("max_ckpt_to_keep_lve", MAX_CKPT_TO_KEEP_LVE,
"Max # of checkpoints to keep for lowest validation error \
models (rolling)")
flags.DEFINE_integer("ext_input_dim", EXT_INPUT_DIM, "Dimension of external \
inputs")
flags.DEFINE_integer("num_steps_for_gen_ic", NUM_STEPS_FOR_GEN_IC,
"Number of steps to train the generator initial conditon.")
# If there are observed inputs, there are two ways to add that observed
# input to the model. The first is by treating as something to be
# inferred, and thus encoding the observed input via the encoders, and then
# input to the generator via the "inferred inputs" channel. Second, one
# can input the input directly into the generator. This has the downside
# of making the generation process strictly dependent on knowing the
# observed input for any generated trial.
flags.DEFINE_boolean("inject_ext_input_to_gen",
INJECT_EXT_INPUT_TO_GEN,
"Should observed inputs be input to model via encoders, \
or injected directly into generator?")
# CELL
# The combined recurrent and input weights of the encoder and
# controller cells are by default set to scale at ws/sqrt(#inputs),
# with ws=1.0. You can change this scaling with this parameter.
flags.DEFINE_float("cell_weight_scale", CELL_WEIGHT_SCALE,
"Input scaling for input weights in generator.")
# GENERATION
# Note that the dimension of the initial conditions is separated from the
# dimensions of the generator initial conditions (and a linear matrix will
# adapt the shapes if necessary). This is just another way to control
# complexity. In all likelihood, setting the ic dims to the size of the
# generator hidden state is just fine.
flags.DEFINE_integer("ic_dim", IC_DIM, "Dimension of h0")
# Setting the dimensions of the factors to something smaller than the data
# dimension is a way to get a reduced dimensionality representation of your
# data.
flags.DEFINE_integer("factors_dim", FACTORS_DIM,
"Number of factors from generator")
flags.DEFINE_integer("ic_enc_dim", IC_ENC_DIM,
"Cell hidden size, encoder of h0")
# Controlling the size of the generator is one way to control complexity of
# the dynamics (there is also l2, which will squeeze out unnecessary
# dynamics also). The modern deep learning approach is to make these cells
# as large as tolerable (from a waiting perspective), and then regularize
# them to death with drop out or whatever. I don't know if this is correct
# for the LFADS application or not.
flags.DEFINE_integer("gen_dim", GEN_DIM,
"Cell hidden size, generator.")
# The weights of the generator cell by default set to scale at
# ws/sqrt(#inputs), with ws=1.0. You can change ws for
# the input weights or the recurrent weights with these hyperparameters.
flags.DEFINE_float("gen_cell_input_weight_scale", GEN_CELL_INPUT_WEIGHT_SCALE,
"Input scaling for input weights in generator.")
flags.DEFINE_float("gen_cell_rec_weight_scale", GEN_CELL_REC_WEIGHT_SCALE,
"Input scaling for rec weights in generator.")
# KL DISTRIBUTIONS
# If you don't know what you are donig here, please leave alone, the
# defaults should be fine for most cases, irregardless of other parameters.
#
# If you don't want the prior variance to be learned, set the
# following values to the same thing: ic_prior_var_min,
# ic_prior_var_scale, ic_prior_var_max. The prior mean will be
# learned regardless.
flags.DEFINE_float("ic_prior_var_min", IC_PRIOR_VAR_MIN,
"Minimum variance in posterior h0 codes.")
flags.DEFINE_float("ic_prior_var_scale", IC_PRIOR_VAR_SCALE,
"Variance of ic prior distribution")
flags.DEFINE_float("ic_prior_var_max", IC_PRIOR_VAR_MAX,
"Maximum variance of IC prior distribution.")
# If you really want to limit the information from encoder to decoder,
# Increase ic_post_var_min above 0.0.
flags.DEFINE_float("ic_post_var_min", IC_POST_VAR_MIN,
"Minimum variance of IC posterior distribution.")
flags.DEFINE_float("co_prior_var_scale", CO_PRIOR_VAR_SCALE,
"Variance of control input prior distribution.")
flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION,
"Initial autocorrelation of AR(1) priors.")
flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR,
"Initial noise variance for AR(1) priors.")
flags.DEFINE_boolean("do_train_prior_ar_atau", DO_TRAIN_PRIOR_AR_ATAU,
"Is the value for atau an init, or the constant value?")
flags.DEFINE_boolean("do_train_prior_ar_nvar", DO_TRAIN_PRIOR_AR_NVAR,
"Is the value for noise variance an init, or the constant \
value?")
# CONTROLLER
# This parameter critically controls whether or not there is a controller
# (along with controller encoders placed into the LFADS graph. If CO_DIM >
# 1, that means there is a 1 dimensional controller outputs, if equal to 0,
# then no controller.
flags.DEFINE_integer("co_dim", CO_DIM,
"Number of control net outputs (>0 builds that graph).")
# The controller will be more powerful if it can see the encoding of the entire
# trial. However, this allows the controller to create inferred inputs that are
# acausal with respect to the actual data generation process. E.g. the data
# generator could have an input at time t, but the controller, after seeing the
# entirety of the trial could infer that the input is coming a little before
# time t, because there are no restrictions on the data the controller sees.
# One can force the controller to be causal (with respect to perturbations in
# the data generator) so that it only sees forward encodings of the data at time
# t that originate at times before or at time t. One can also control the data
# the controller sees by using an input lag (forward encoding at time [t-tlag]
# for controller input at time t. The same can be done in the reverse direction
# (controller input at time t from reverse encoding at time [t+tlag], in the
# case of an acausal controller). Setting this lag > 0 (even lag=1) can be a
# powerful way of avoiding very spiky decodes. Finally, one can manually control
# whether the factors at time t-1 are fed to the controller at time t.
#
# If you don't care about any of this, and just want to smooth your data, set
# do_causal_controller = False
# do_feed_factors_to_controller = True
# causal_input_lag = 0
flags.DEFINE_boolean("do_causal_controller",
DO_CAUSAL_CONTROLLER,
"Restrict the controller create only causal inferred \
inputs?")
# Strictly speaking, feeding either the factors or the rates to the controller
# violates causality, since the g0 gets to see all the data. This may or may not
# be only a theoretical concern.
flags.DEFINE_boolean("do_feed_factors_to_controller",
DO_FEED_FACTORS_TO_CONTROLLER,
"Should factors[t-1] be input to controller at time t?")
flags.DEFINE_string("feedback_factors_or_rates", FEEDBACK_FACTORS_OR_RATES,
"Feedback the factors or the rates to the controller? \
Acceptable values: 'factors' or 'rates'.")
flags.DEFINE_integer("controller_input_lag", CONTROLLER_INPUT_LAG,
"Time lag on the encoding to controller t-lag for \
forward, t+lag for reverse.")
flags.DEFINE_integer("ci_enc_dim", CI_ENC_DIM,
"Cell hidden size, encoder of control inputs")
flags.DEFINE_integer("con_dim", CON_DIM,
"Cell hidden size, controller")
# OPTIMIZATION
flags.DEFINE_integer("batch_size", BATCH_SIZE,
"Batch size to use during training.")
flags.DEFINE_float("learning_rate_init", LEARNING_RATE_INIT,
"Learning rate initial value")
flags.DEFINE_float("learning_rate_decay_factor", LEARNING_RATE_DECAY_FACTOR,
"Learning rate decay, decay by this fraction every so \
often.")
flags.DEFINE_float("learning_rate_stop", LEARNING_RATE_STOP,
"The lr is adaptively reduced, stop training at this value.")
# Rather put the learning rate on an exponentially decreasiong schedule,
# the current algorithm pays attention to the learning rate, and if it
# isn't regularly decreasing, it will decrease the learning rate. So far,
# it works fine, though it is not perfect.
flags.DEFINE_integer("learning_rate_n_to_compare", LEARNING_RATE_N_TO_COMPARE,
"Number of previous costs current cost has to be worse \
than, to lower learning rate.")
# This sets a value, above which, the gradients will be clipped. This hp
# is extremely useful to avoid an infrequent, but highly pathological
# problem whereby the gradient is so large that it destroys the
# optimziation by setting parameters too large, leading to a vicious cycle
# that ends in NaNs. If it's too large, it's useless, if it's too small,
# it essentially becomes the learning rate. It's pretty insensitive, though.
flags.DEFINE_float("max_grad_norm", MAX_GRAD_NORM,
"Max norm of gradient before clipping.")
# If your optimizations start "NaN-ing out", reduce this value so that
# the values of the network don't grow out of control. Typically, once
# this parameter is set to a reasonable value, one stops having numerical
# problems.
flags.DEFINE_float("cell_clip_value", CELL_CLIP_VALUE,
"Max value recurrent cell can take before being clipped.")
# This flag is used for an experiment where one sees if training a model with
# many days data can be used to learn the dynamics from a held-out days data.
# If you don't care about that particular experiment, this flag should always be
# false.
flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
"Train only the input (readin) and output (readout) \
affine functions.")
flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
"Reset the learning rate to initial value.")
# OVERFITTING
# Dropout is done on the input data, on controller inputs (from
# encoder), on outputs from generator to factors.
flags.DEFINE_float("keep_prob", KEEP_PROB, "Dropout keep probability.")
# It appears that the system will happily fit spikes (blessing or
# curse, depending). You may not want this. Jittering the spikes a
# bit will help (-/+ bin size, as specified here).
flags.DEFINE_integer("temporal_spike_jitter_width",
TEMPORAL_SPIKE_JITTER_WIDTH,
"Shuffle spikes around this window.")
# General note about helping ascribe controller inputs vs dynamics:
#
# If controller is heavily penalized, then it won't have any output.
# If dynamics are heavily penalized, then generator won't make
# dynamics. Note this l2 penalty is only on the recurrent portion of
# the RNNs, as dropout is also available, penalizing the feed-forward
# connections.
flags.DEFINE_float("l2_gen_scale", L2_GEN_SCALE,
"L2 regularization cost for the generator only.")
flags.DEFINE_float("l2_con_scale", L2_CON_SCALE,
"L2 regularization cost for the controller only.")
flags.DEFINE_float("co_mean_corr_scale", CO_MEAN_CORR_SCALE,
"Cost of correlation (thru time)in the means of \
controller output.")
# UNDERFITTING
# If the primary task of LFADS is "filtering" of data and not
# generation, then it is possible that the KL penalty is too strong.
# Empirically, we have found this to be the case. So we add a
# hyperparameter in front of the the two KL terms (one for the initial
# conditions to the generator, the other for the controller outputs).
# You should always think of the the default values as 1.0, and that
# leads to a standard VAE formulation whereby the numbers that are
# optimized are a lower-bound on the log-likelihood of the data. When
# these 2 HPs deviate from 1.0, one cannot make any statement about
# what those LL lower bounds mean anymore, and they cannot be compared
# (AFAIK).
flags.DEFINE_float("kl_ic_weight", KL_IC_WEIGHT,
"Strength of KL weight on initial conditions KL penatly.")
flags.DEFINE_float("kl_co_weight", KL_CO_WEIGHT,
"Strength of KL weight on controller output KL penalty.")
# Sometimes the task can be sufficiently hard to learn that the
# optimizer takes the 'easy route', and simply minimizes the KL
# divergence, setting it to near zero, and the optimization gets
# stuck. These two parameters will help avoid that by by getting the
# optimization to 'latch' on to the main optimization, and only
# turning in the regularizers later.
flags.DEFINE_integer("kl_start_step", KL_START_STEP,
"Start increasing weight after this many steps.")
# training passes, not epochs, increase by 0.5 every kl_increase_steps
flags.DEFINE_integer("kl_increase_steps", KL_INCREASE_STEPS,
"Increase weight of kl cost to avoid local minimum.")
# Same story for l2 regularizer. One wants a simple generator, for scientific
# reasons, but not at the expense of hosing the optimization.
flags.DEFINE_integer("l2_start_step", L2_START_STEP,
"Start increasing l2 weight after this many steps.")
flags.DEFINE_integer("l2_increase_steps", L2_INCREASE_STEPS,
"Increase weight of l2 cost to avoid local minimum.")
FLAGS = flags.FLAGS
def build_model(hps, kind="train", datasets=None):
"""Builds a model from either random initialization, or saved parameters.
Args:
hps: The hyper parameters for the model.
kind: (optional) The kind of model to build. Training vs inference require
different graphs.
datasets: The datasets structure (see top of lfads.py).
Returns:
an LFADS model.
"""
build_kind = kind
if build_kind == "write_model_params":
build_kind = "train"
with tf.variable_scope("LFADS", reuse=None):
model = LFADS(hps, kind=build_kind, datasets=datasets)
if not os.path.exists(hps.lfads_save_dir):
print("Save directory %s does not exist, creating it." % hps.lfads_save_dir)
os.makedirs(hps.lfads_save_dir)
cp_pb_ln = hps.checkpoint_pb_load_name
cp_pb_ln = 'checkpoint' if cp_pb_ln == "" else cp_pb_ln
if cp_pb_ln == 'checkpoint':
print("Loading latest training checkpoint in: ", hps.lfads_save_dir)
saver = model.seso_saver
elif cp_pb_ln == 'checkpoint_lve':
print("Loading lowest validation checkpoint in: ", hps.lfads_save_dir)
saver = model.lve_saver
else:
print("Loading checkpoint: ", cp_pb_ln, ", in: ", hps.lfads_save_dir)
saver = model.seso_saver
ckpt = tf.train.get_checkpoint_state(hps.lfads_save_dir,
latest_filename=cp_pb_ln)
session = tf.get_default_session()
print("ckpt: ", ckpt)
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
saver.restore(session, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
if kind in ["posterior_sample_and_average", "prior_sample",
"write_model_params"]:
print("Possible error!!! You are running ", kind, " on a newly \
initialized model!")
print("Are you sure you sure ", ckpt.model_checkpoint_path, " exists?")
tf.global_variables_initializer().run()
if ckpt:
train_step_str = re.search('-[0-9]+$', ckpt.model_checkpoint_path).group()
else:
train_step_str = '-0'
fname = 'hyperparameters' + train_step_str + '.txt'
hp_fname = os.path.join(hps.lfads_save_dir, fname)
hps_for_saving = jsonify_dict(hps)
utils.write_data(hp_fname, hps_for_saving, use_json=True)
return model
def jsonify_dict(d):
"""Turns python booleans into strings so hps dict can be written in json.
Creates a shallow-copied dictionary first, then accomplishes string
conversion.
Args:
d: hyperparameter dictionary
Returns: hyperparameter dictionary with bool's as strings
"""
d2 = d.copy() # shallow copy is fine by assumption of d being shallow
def jsonify_bool(boolean_value):
if boolean_value:
return "true"
else:
return "false"
for key in d2.keys():
if isinstance(d2[key], bool):
d2[key] = jsonify_bool(d2[key])
return d2
def build_hyperparameter_dict(flags):
"""Simple script for saving hyper parameters. Under the hood the
flags structure isn't a dictionary, so it has to be simplified since we
want to be able to view file as text.
Args:
flags: From tf.app.flags
Returns:
dictionary of hyper parameters (ignoring other flag types).
"""
d = {}
# Data
d['output_dist'] = flags.output_dist
d['data_dir'] = flags.data_dir
d['lfads_save_dir'] = flags.lfads_save_dir
d['checkpoint_pb_load_name'] = flags.checkpoint_pb_load_name
d['checkpoint_name'] = flags.checkpoint_name
d['output_filename_stem'] = flags.output_filename_stem
d['max_ckpt_to_keep'] = flags.max_ckpt_to_keep
d['max_ckpt_to_keep_lve'] = flags.max_ckpt_to_keep_lve
d['ps_nexamples_to_process'] = flags.ps_nexamples_to_process
d['ext_input_dim'] = flags.ext_input_dim
d['data_filename_stem'] = flags.data_filename_stem
d['device'] = flags.device
d['csv_log'] = flags.csv_log
d['num_steps_for_gen_ic'] = flags.num_steps_for_gen_ic
d['inject_ext_input_to_gen'] = flags.inject_ext_input_to_gen
# Cell
d['cell_weight_scale'] = flags.cell_weight_scale
# Generation
d['ic_dim'] = flags.ic_dim
d['factors_dim'] = flags.factors_dim
d['ic_enc_dim'] = flags.ic_enc_dim
d['gen_dim'] = flags.gen_dim
d['gen_cell_input_weight_scale'] = flags.gen_cell_input_weight_scale
d['gen_cell_rec_weight_scale'] = flags.gen_cell_rec_weight_scale
# KL distributions
d['ic_prior_var_min'] = flags.ic_prior_var_min
d['ic_prior_var_scale'] = flags.ic_prior_var_scale
d['ic_prior_var_max'] = flags.ic_prior_var_max
d['ic_post_var_min'] = flags.ic_post_var_min
d['co_prior_var_scale'] = flags.co_prior_var_scale
d['prior_ar_atau'] = flags.prior_ar_atau
d['prior_ar_nvar'] = flags.prior_ar_nvar
d['do_train_prior_ar_atau'] = flags.do_train_prior_ar_atau
d['do_train_prior_ar_nvar'] = flags.do_train_prior_ar_nvar
# Controller
d['do_causal_controller'] = flags.do_causal_controller
d['controller_input_lag'] = flags.controller_input_lag
d['do_feed_factors_to_controller'] = flags.do_feed_factors_to_controller
d['feedback_factors_or_rates'] = flags.feedback_factors_or_rates
d['co_dim'] = flags.co_dim
d['ci_enc_dim'] = flags.ci_enc_dim
d['con_dim'] = flags.con_dim
d['co_mean_corr_scale'] = flags.co_mean_corr_scale
# Optimization
d['batch_size'] = flags.batch_size
d['learning_rate_init'] = flags.learning_rate_init
d['learning_rate_decay_factor'] = flags.learning_rate_decay_factor
d['learning_rate_stop'] = flags.learning_rate_stop
d['learning_rate_n_to_compare'] = flags.learning_rate_n_to_compare
d['max_grad_norm'] = flags.max_grad_norm
d['cell_clip_value'] = flags.cell_clip_value
d['do_train_io_only'] = flags.do_train_io_only
d['do_reset_learning_rate'] = flags.do_reset_learning_rate
# Overfitting
d['keep_prob'] = flags.keep_prob
d['temporal_spike_jitter_width'] = flags.temporal_spike_jitter_width
d['l2_gen_scale'] = flags.l2_gen_scale
d['l2_con_scale'] = flags.l2_con_scale
# Underfitting
d['kl_ic_weight'] = flags.kl_ic_weight
d['kl_co_weight'] = flags.kl_co_weight
d['kl_start_step'] = flags.kl_start_step
d['kl_increase_steps'] = flags.kl_increase_steps
d['l2_start_step'] = flags.l2_start_step
d['l2_increase_steps'] = flags.l2_increase_steps
return d
class hps_dict_to_obj(dict):
"""Helper class allowing us to access hps dictionary more easily."""
def __getattr__(self, key):
if key in self:
return self[key]
else:
assert False, ("%s does not exist." % key)
def __setattr__(self, key, value):
self[key] = value
def train(hps, datasets):
"""Train the LFADS model.
Args:
hps: The dictionary of hyperparameters.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
"""
model = build_model(hps, kind="train", datasets=datasets)
if hps.do_reset_learning_rate:
sess = tf.get_default_session()
sess.run(model.learning_rate.initializer)
model.train_model(datasets)
def write_model_runs(hps, datasets, output_fname=None):
"""Run the model on the data in data_dict, and save the computed values.
LFADS generates a number of outputs for each examples, and these are all
saved. They are:
The mean and variance of the prior of g0.
The mean and variance of approximate posterior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The rates for all time.
Args:
hps: The dictionary of hyperparameters.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
output_fname (optional): output filename stem to write the model runs.
"""
model = build_model(hps, kind=hps.kind, datasets=datasets)
model.write_model_runs(datasets, output_fname)
def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
"""Use the prior distribution to generate samples from the model.
Generates batch_size number of samples (set through FLAGS).
LFADS generates a number of outputs for each examples, and these are all
saved. They are:
The mean and variance of the prior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
hps: The dictionary of hyperparameters.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
dataset_name: The name of the dataset to grab the factors -> rates
alignment matrices from. Only a concern with models trained on
multi-session data. By default, uses the first dataset in the data dict.
output_fname: The name prefix of the file in which to save the generated
samples.
"""
if not output_fname:
output_fname = "model_runs_" + hps.kind
else:
output_fname = output_fname + "model_runs_" + hps.kind
if not dataset_name:
dataset_name = datasets.keys()[0]
else:
if dataset_name not in datasets.keys():
raise ValueError("Invalid dataset name '%s'."%(dataset_name))
model = build_model(hps, kind=hps.kind, datasets=datasets)
model.write_model_samples(dataset_name, output_fname)
def write_model_parameters(hps, output_fname=None, datasets=None):
"""Save all the model parameters
Save all the parameters to hps.lfads_save_dir.
Args:
hps: The dictionary of hyperparameters.
output_fname: The prefix of the file in which to save the generated
samples.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
"""
if not output_fname:
output_fname = "model_params"
else:
output_fname = output_fname + "_model_params"
fname = os.path.join(hps.lfads_save_dir, output_fname)
print("Writing model parameters to: ", fname)
# save the optimizer params as well
model = build_model(hps, kind="write_model_params", datasets=datasets)
model_params = model.eval_model_parameters(use_nested=False,
include_strs="LFADS")
utils.write_data(fname, model_params, compression=None)
print("Done.")
def clean_data_dict(data_dict):
"""Add some key/value pairs to the data dict, if they are missing.
Args:
data_dict - dictionary containing data for LFADS
Returns:
data_dict with some keys filled in, if they are absent.
"""
keys = ['train_truth', 'train_ext_input', 'valid_data',
'valid_truth', 'valid_ext_input', 'valid_train']
for k in keys:
if k not in data_dict:
data_dict[k] = None
return data_dict
def load_datasets(data_dir, data_filename_stem):
"""Load the datasets from a specified directory.
Example files look like
>data_dir/my_dataset_first_day
>data_dir/my_dataset_second_day
If my_dataset (filename) stem is in the directory, the read routine will try
and load it. The datasets dictionary will then look like
dataset['first_day'] -> (first day data dictionary)
dataset['second_day'] -> (first day data dictionary)
Args:
data_dir: The directory from which to load the datasets.
data_filename_stem: The stem of the filename for the datasets.
Returns:
datasets: a dataset dictionary, with one name->data dictionary pair for
each dataset file.
"""
print("Reading data from ", data_dir)
datasets = utils.read_datasets(data_dir, data_filename_stem)
for k, data_dict in datasets.items():
datasets[k] = clean_data_dict(data_dict)
train_total_size = len(data_dict['train_data'])
if train_total_size == 0:
print("Did not load training set.")
else:
print("Found training set with number examples: ", train_total_size)
valid_total_size = len(data_dict['valid_data'])
if valid_total_size == 0:
print("Did not load validation set.")
else:
print("Found validation set with number examples: ", valid_total_size)
return datasets
def main(_):
"""Get this whole shindig off the ground."""
d = build_hyperparameter_dict(FLAGS)
hps = hps_dict_to_obj(d) # hyper parameters
kind = FLAGS.kind
# Read the data, if necessary.
train_set = valid_set = None
if kind in ["train", "posterior_sample_and_average", "prior_sample",
"write_model_params"]:
datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
else:
raise ValueError('Kind {} is not supported.'.format(kind))
# infer the dataset names and dataset dimensions from the loaded files
hps.kind = kind # needs to be added here, cuz not saved as hyperparam
hps.dataset_names = []
hps.dataset_dims = {}
for key in datasets:
hps.dataset_names.append(key)
hps.dataset_dims[key] = datasets[key]['data_dim']
# also store down the dimensionality of the data
# - just pull from one set, required to be same for all sets
hps.num_steps = datasets.values()[0]['num_steps']
hps.ndatasets = len(hps.dataset_names)
if hps.num_steps_for_gen_ic > hps.num_steps:
hps.num_steps_for_gen_ic = hps.num_steps
# Build and run the model, for varying purposes.
config = tf.ConfigProto(allow_soft_placement=True,
log_device_placement=False)
if FLAGS.allow_gpu_growth:
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
with sess.as_default():
with tf.device(hps.device):
if kind == "train":
train(hps, datasets)
elif kind == "posterior_sample_and_average":
write_model_runs(hps, datasets, hps.output_filename_stem)
elif kind == "prior_sample":
write_model_samples(hps, datasets, hps.output_filename_stem)
elif kind == "write_model_params":
write_model_parameters(hps, hps.output_filename_stem, datasets)
else:
assert False, ("Kind %s is not implemented. " % kind)
if __name__ == "__main__":
tf.app.run()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf # used for flags here
from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
matplotlib.rcParams['image.interpolation'] = 'nearest'
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 40,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
"Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("input_magnitude", 20.0,
"For the input case, what is the value of the input?")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
# Note that with N small, (as it is 25 above), the finite size effects
# will have pretty dramatic effects on the dynamics of the random RNN.
# If you want more complex dynamics, you'll have to run the script a
# lot, or increase N (or g).
# Getting hard vs. easy data can be a little stochastic, so we set the seed.
# Pull out some commonly used parameters.
# These are user parameters (configuration)
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
S = FLAGS.S
input_magnitude = FLAGS.input_magnitude
nspikifications = FLAGS.nspikifications
E = nspikifications * C # total number of trials
# S is the number of measurements in each datasets, w/ each
# dataset having a different set of observations.
ndatasets = N/S # ok if rounded down
train_percentage = FLAGS.train_percentage
ntime_steps = int(T / FLAGS.dt)
# End of user parameters
rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
# Check to make sure the RNN is the one we used in the paper.
if N == 50:
assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
rem_check = nspikifications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \
'Train percentage * nspikifications should be integral number.'
# Initial condition generation, and condition label generation. This
# happens outside of the dataset loop, so that all datasets have the
# same conditions, which is similar to a neurophys setup.
condition_number = 0
x0s = []
condition_labels = []
for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications)) # replicate x0 nspikifications times
# replicate the condition label nspikifications times
for ns in range(nspikifications):
condition_labels.append(condition_number)
condition_number += 1
x0s = np.concatenate(x0s, axis=1)
# Containers for storing data across data.
datasets = {}
for n in range(ndatasets):
print(n+1, " of ", ndatasets)
# First generate all firing rates. in the next loop, generate all
# spikifications this allows the random state for rate generation to be
# independent of n_spikifications.
dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
if S < N:
dataset_name += '_n' + str(n+1)
# Sample neuron subsets. The assumption is the PC axes of the RNN
# are not unit aligned, so sampling units is adequate to sample all
# the high-variance PCs.
P_sxn = np.eye(S,N)
for m in range(n):
P_sxn = np.roll(P_sxn, S, axis=1)
if input_magnitude > 0.0:
# time of "hits" randomly chosen between [1/4 and 3/4] of total time
input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
else:
input_times = None
rates, x0s, inputs = \
generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
input_magnitude=input_magnitude,
input_times=input_times)
spikes = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
# Split the data, inputs, labels and times into train vs. validation.
rates_train, rates_valid = \
split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = \
split_list_by_inds(spikes, train_inds, valid_inds)
input_train, inputs_valid = \
split_list_by_inds(inputs, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = \
split_list_by_inds(condition_labels, train_inds, valid_inds)
input_times_train, input_times_valid = \
split_list_by_inds(input_times, train_inds, valid_inds)
# Turn rates, spikes, and input into numpy arrays.
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)
input_train = nparray_and_transpose(input_train)
inputs_valid = nparray_and_transpose(inputs_valid)
# Note that we put these 'truth' rates and input into this
# structure, the only data that is used in LFADS are the spike
# trains. The rest is either for printing or posterity.
data = {'train_truth': rates_train,
'valid_truth': rates_valid,
'input_train_truth' : input_train,
'input_valid_truth' : inputs_valid,
'train_data' : spikes_train,
'valid_data' : spikes_valid,
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : rnn['dt'],
'input_magnitude' : input_magnitude,
'input_times_train' : input_times_train,
'input_times_valid' : input_times_valid,
'P_sxn' : P_sxn,
'condition_labels_train' : condition_labels_train,
'condition_labels_valid' : condition_labels_valid,
'conversion_factor': 1.0 / rnn['conversion_factor']}
datasets[dataset_name] = data
if S < N:
# Note that this isn't necessary for this synthetic example, but
# it's useful to see how the input factor matrices were initialized
# for actual neurophysiology data.
datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)
# Write out the datasets.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf
from utils import write_datasets
from synthetic_data_utils import normalize_rates
from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 5,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0,
"Map 1.0 of RNN to a spikes per second")
flags.DEFINE_float("u_std", 0.25,
"Std dev of input to integration to bound model")
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
"""Path to directory with checkpoints of model
trained on integration to bound task. Currently this
is a placeholder which tells the code to grab the
checkpoint that is provided with the code
(in /trained_itb/..). If you have your own checkpoint
you would like to restore, you would point it to
that path.""")
FLAGS = flags.FLAGS
class IntegrationToBoundModel:
def __init__(self, N):
scale = 0.8 / float(N**0.5)
self.N = N
self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
self.b_1xn = tf.Variable(tf.zeros([1, N]))
self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
self.bro_o = tf.Variable(tf.zeros([1]))
def call(self, h_tm1_bxn, u_bx1):
act_t_bxn = tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn + u_bx1 * self.Bu_1xn
h_t_bxn = tf.nn.tanh(act_t_bxn)
z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
return z_t, h_t_bxn
def get_data_batch(batch_size, T, rng, u_std):
u_bxt = rng.randn(batch_size, T) * u_std
running_sum_b = np.zeros([batch_size])
labels_bxt = np.zeros([batch_size, T])
for t in xrange(T):
running_sum_b += u_bxt[:, t]
labels_bxt[:, t] += running_sum_b
labels_bxt = np.clip(labels_bxt, -1, 1)
return u_bxt, labels_bxt
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N # must be same N as in trained model (provided example is N = 50)
nspikifications = FLAGS.nspikifications
E = nspikifications * C # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
batch_size = 1 # gives one example per ntrial
model = IntegrationToBoundModel(N)
inputs_ph_t = [tf.placeholder(tf.float32,
shape=[None, 1]) for _ in range(ntimesteps)]
state = tf.zeros([batch_size, N])
saver = tf.train.Saver()
P_nxn = rng.randn(N,N) / np.sqrt(N) # random projections
# unroll RNN for T timesteps
outputs_t = []
states_t = []
for inp in inputs_ph_t:
output, state = model.call(state, inp)
outputs_t.append(output)
states_t.append(state)
with tf.Session() as sess:
# restore the latest model ckpt
if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
dir_path = os.path.dirname(os.path.realpath(__file__))
model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
else:
model_checkpoint_path = FLAGS.checkpoint_path
try:
saver.restore(sess, model_checkpoint_path)
print ('Model restored from', model_checkpoint_path)
except:
assert False, ("No checkpoints to restore from, is the path %s correct?"
%model_checkpoint_path)
# generate data for trials
data_e = []
u_e = []
outs_e = []
for c in range(C):
u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng, FLAGS.u_std)
feed_dict = {}
for t in xrange(ntimesteps):
feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:,t], (batch_size,-1))
states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
feed_dict=feed_dict)
states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
r_sxt = np.dot(P_nxn, states_nxt)
for s in xrange(nspikifications):
data_e.append(r_sxt)
u_e.append(u_1xt)
outs_e.append(outputs_t_bxn)
truth_data_e = normalize_rates(data_e, E, N)
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
train_inds,
valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
train_inds,
valid_inds)
data_train_truth = nparray_and_transpose(data_train_truth)
data_valid_truth = nparray_and_transpose(data_valid_truth)
data_train_spiking = nparray_and_transpose(data_train_spiking)
data_valid_spiking = nparray_and_transpose(data_valid_spiking)
# save down the inputs used to generate this data
train_inputs_u, valid_inputs_u = split_list_by_inds(u_e,
train_inds,
valid_inds)
train_inputs_u = nparray_and_transpose(train_inputs_u)
valid_inputs_u = nparray_and_transpose(valid_inputs_u)
# save down the network outputs (may be useful later)
train_outputs_u, valid_outputs_u = split_list_by_inds(outs_e,
train_inds,
valid_inds)
train_outputs_u = np.array(train_outputs_u)
valid_outputs_u = np.array(valid_outputs_u)
data = { 'train_truth': data_train_truth,
'valid_truth': data_valid_truth,
'train_data' : data_train_spiking,
'valid_data' : data_valid_spiking,
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : FLAGS.dt,
'u_std' : FLAGS.u_std,
'max_firing_rate': FLAGS.max_firing_rate,
'train_inputs_u': train_inputs_u,
'valid_inputs_u': valid_inputs_u,
'train_outputs_u': train_outputs_u,
'valid_outputs_u': valid_outputs_u,
'conversion_factor' : FLAGS.max_firing_rate/(1.0/FLAGS.dt) }
# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data
# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print ('Saved to ', os.path.join(FLAGS.save_dir,
FLAGS.datafile_name + '_' + dataset_name))
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import os
import h5py
import numpy as np
from synthetic_data_utils import generate_data, generate_rnn
from synthetic_data_utils import get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import tensorflow as tf
from utils import write_datasets
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "conditioned_rnn_data",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 400, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 10,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
"Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
np.random.RandomState(seed=FLAGS.synth_data_seed+2)]
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
nspikifications = FLAGS.nspikifications
E = nspikifications * C
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
rnn_a = generate_rnn(rnn_rngs[0], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
FLAGS.max_firing_rate)
rnn_b = generate_rnn(rnn_rngs[1], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
FLAGS.max_firing_rate)
rnns = [rnn_a, rnn_b]
# pick which RNN is used on each trial
rnn_to_use = rng.randint(2, size=E)
ext_input = np.repeat(np.expand_dims(rnn_to_use, axis=1), ntimesteps, axis=1)
ext_input = np.expand_dims(ext_input, axis=2) # these are "a's" in the paper
x0s = []
condition_labels = []
condition_number = 0
for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications))
for ns in range(nspikifications):
condition_labels.append(condition_number)
condition_number += 1
x0s = np.concatenate(x0s, axis=1)
P_nxn = rng.randn(N, N) / np.sqrt(N)
# generate trials for both RNNs
rates_a, x0s_a, _ = generate_data(rnn_a, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
input_magnitude=0.0, input_times=None)
spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])
rates_b, x0s_b, _ = generate_data(rnn_b, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
input_magnitude=0.0, input_times=None)
spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])
# not the best way to do this but E is small enough
rates = []
spikes = []
for trial in xrange(E):
if rnn_to_use[trial] == 0:
rates.append(rates_a[trial])
spikes.append(spikes_a[trial])
else:
rates.append(rates_b[trial])
spikes.append(spikes_b[trial])
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
ext_input, train_inds, valid_inds)
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)
# add train_ext_input and valid_ext input
data = {'train_truth': rates_train,
'valid_truth': rates_valid,
'train_data' : spikes_train,
'valid_data' : spikes_valid,
'train_ext_input' : np.array(ext_input_train),
'valid_ext_input': np.array(ext_input_valid),
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : FLAGS.dt,
'P_sxn' : P_nxn,
'condition_labels_train' : condition_labels_train,
'condition_labels_valid' : condition_labels_valid,
'conversion_factor': 1.0 / rnn_a['conversion_factor']}
# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data
# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print ('Saved to ', os.path.join(FLAGS.save_dir,
FLAGS.datafile_name + '_' + dataset_name))
#!/bin/bash
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/
echo "Generating chaotic rnn data with no input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0
echo "Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0
echo "Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0
echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0
echo "Generating Integration-to-bound RNN data"
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
from utils import write_datasets
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
"""Create a (vanilla) RNN with a bunch of hyper parameters for generating
chaotic data.
Args:
rng: numpy random number generator
N: number of hidden units
g: scaling of recurrent weight matrix in g W, with W ~ N(0,1/N)
tau: time scale of individual unit dynamics
dt: time step for equation updates
max_firing_rate: how to resecale the -1,1 firing rates
Returns:
the dictionary of these parameters, plus some others.
"""
rnn = {}
rnn['N'] = N
rnn['W'] = rng.randn(N,N)/np.sqrt(N)
rnn['Bin'] = rng.randn(N)/np.sqrt(1.0)
rnn['Bin2'] = rng.randn(N)/np.sqrt(1.0)
rnn['b'] = np.zeros(N)
rnn['g'] = g
rnn['tau'] = tau
rnn['dt'] = dt
rnn['max_firing_rate'] = max_firing_rate
mfr = rnn['max_firing_rate'] # spikes / sec
nbins_per_sec = 1.0/rnn['dt'] # bins / sec
# Used for plotting in LFADS
rnn['conversion_factor'] = mfr / nbins_per_sec # spikes / bin
return rnn
def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
input_times=None):
""" Generates data from an randomly initialized RNN.
Args:
rnn: the rnn
T: Time in seconds to run (divided by rnn['dt'] to get steps, rounded down.
E: total number of examples
S: number of samples (subsampling N)
Returns:
A list of length E of NxT tensors of the network being run.
"""
N = rnn['N']
def run_rnn(rnn, x0, ntime_steps, input_time=None):
rs = np.zeros([N,ntime_steps])
x_tm1 = x0
r_tm1 = np.tanh(x0)
tau = rnn['tau']
dt = rnn['dt']
alpha = (1.0-dt/tau)
W = dt/tau*rnn['W']*rnn['g']
Bin = dt/tau*rnn['Bin']
Bin2 = dt/tau*rnn['Bin2']
b = dt/tau*rnn['b']
us = np.zeros([1, ntime_steps])
for t in range(ntime_steps):
x_t = alpha*x_tm1 + np.dot(W,r_tm1) + b
if input_time is not None and t == input_time:
us[0,t] = input_magnitude
x_t += Bin * us[0,t] # DCS is this what was used?
r_t = np.tanh(x_t)
x_tm1 = x_t
r_tm1 = r_t
rs[:,t] = r_t
return rs, us
if P_sxn is None:
P_sxn = np.eye(N)
ntime_steps = int(T / rnn['dt'])
data_e = []
inputs_e = []
for e in range(E):
input_time = input_times[e] if input_times is not None else None
r_nxt, u_uxt = run_rnn(rnn, x0s[:,e], ntime_steps, input_time)
r_sxt = np.dot(P_sxn, r_nxt)
inputs_e.append(u_uxt)
data_e.append(r_sxt)
S = P_sxn.shape[0]
data_e = normalize_rates(data_e, E, S)
return data_e, x0s, inputs_e
def normalize_rates(data_e, E, S):
# Normalization, made more complex because of the P matrices.
# Normalize by min and max in each channel. This normalization will
# cause offset differences between identical rnn runs, but different
# t hits.
for e in range(E):
r_sxt = data_e[e]
for i in range(S):
rmin = np.min(r_sxt[i,:])
rmax = np.max(r_sxt[i,:])
assert rmax - rmin != 0, 'Something wrong'
r_sxt[i,:] = (r_sxt[i,:] - rmin)/(rmax-rmin)
data_e[e] = r_sxt
return data_e
def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
""" Apply spikes to a continuous dataset whose values are between 0.0 and 1.0
Args:
data_e: nexamples length list of NxT trials
dt: how often the data are sampled
max_firing_rate: the firing rate that is associated with a value of 1.0
Returns:
spikified_data_e: a list of length b of the data represented as spikes,
sampled from the underlying poisson process.
"""
spikifies_data_e = []
E = len(data_e)
spikes_e = []
for e in range(E):
data = data_e[e]
N,T = data.shape
data_s = np.zeros([N,T]).astype(np.int)
for n in range(N):
f = data[n,:]
s = rng.poisson(f*max_firing_rate*dt, size=T)
data_s[n,:] = s
spikes_e.append(data_s)
return spikes_e
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
"""Split the numbers between 0 and num_trials-1 into two portions for
training and validation, based on the train fraction.
Args:
num_trials: the number of trials
train_fraction: (e.g. .80)
nspikifications: the number of spiking trials per initial condition
Returns:
a 2-tuple of two lists: the training indices and validation indices
"""
train_inds = []
valid_inds = []
for i in range(num_trials):
# This line divides up the trials so that within one initial condition,
# the randomness of spikifying the condition is shared among both
# training and validation data splits.
if (i % nspikifications)+1 > train_fraction * nspikifications:
valid_inds.append(i)
else:
train_inds.append(i)
return train_inds, valid_inds
def split_list_by_inds(data, inds1, inds2):
"""Take the data, a list, and split it up based on the indices in inds1 and
inds2.
Args:
data: the list of data to split
inds1, the first list of indices
inds2, the second list of indices
Returns: a 2-tuple of two lists.
"""
if data is None or len(data) == 0:
return [], []
else:
dout1 = [data[i] for i in inds1]
dout2 = [data[i] for i in inds2]
return dout1, dout2
def nparray_and_transpose(data_a_b_c):
"""Convert the list of items in data to a numpy array, and transpose it
Args:
data: data_asbsc: a nested, nested list of length a, with sublist length
b, with sublist length c.
Returns:
a numpy 3-tensor with dimensions a x c x b
"""
data_axbxc = np.array([datum_b_c for datum_b_c in data_a_b_c])
data_axcxb = np.transpose(data_axbxc, axes=[0,2,1])
return data_axcxb
def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
"""Create a matrix that aligns the datasets a bit, under
the assumption that each dataset is observing the same underlying dynamical
system.
Args:
datasets: The dictionary of dataset structures.
npcs: The number of pcs for each, basically like lfads factors.
nsamples (optional): Number of samples to take for each dataset.
ntime (optional): Number of time steps to take in each sample.
Returns:
The dataset structures, with the field alignment_matrix_cxf added.
This is # channels x npcs dimension
"""
nchannels_all = 0
channel_idxs = {}
conditions_all = {}
nconditions_all = 0
for name, dataset in datasets.items():
cidxs = np.where(dataset['P_sxn'])[1] # non-zero entries in columns
channel_idxs[name] = [cidxs[0], cidxs[-1]+1]
nchannels_all += cidxs[-1]+1 - cidxs[0]
conditions_all[name] = np.unique(dataset['condition_labels_train'])
all_conditions_list = \
np.unique(np.ndarray.flatten(np.array(conditions_all.values())))
nconditions_all = all_conditions_list.shape[0]
if ntime is None:
ntime = dataset['train_data'].shape[1]
if nsamples is None:
nsamples = dataset['train_data'].shape[0]
# In the data workup in the paper, Chethan did intra condition
# averaging, so let's do that here.
avg_data_all = {}
for name, conditions in conditions_all.items():
dataset = datasets[name]
avg_data_all[name] = {}
for cname in conditions:
td_idxs = np.argwhere(np.array(dataset['condition_labels_train'])==cname)
data = np.squeeze(dataset['train_data'][td_idxs,:,:], axis=1)
avg_data = np.mean(data, axis=0)
avg_data_all[name][cname] = avg_data
# Visualize this in the morning.
all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
for name, dataset in datasets.items():
cidx_s = channel_idxs[name][0]
cidx_f = channel_idxs[name][1]
for cname in conditions_all[name]:
cidxs = np.argwhere(all_conditions_list == cname)
if cidxs.shape[0] > 0:
cidx = cidxs[0][0]
all_tidxs = np.arange(0, ntime+1) + cidx*ntime
all_data_nxtc[cidx_s:cidx_f, all_tidxs[0]:all_tidxs[-1]] = \
avg_data_all[name][cname].T
# A bit of filtering. We don't care about spectral properties, or
# filtering artifacts, simply correlate time steps a bit.
filt_len = 6
bc_filt = np.ones([filt_len])/float(filt_len)
for c in range(nchannels_all):
all_data_nxtc[c,:] = scipy.signal.filtfilt(bc_filt, [1.0], all_data_nxtc[c,:])
# Compute the PCs.
all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
sidxs = np.flipud(np.argsort(evals_n)) # sort such that 0th is highest
evals_n = evals_n[sidxs]
evecs_nxn = evecs_nxn[:,sidxs]
# Project all the channels data onto the low-D PCA basis, where
# low-d is the npcs parameter.
all_data_pca_pxtc = np.dot(evecs_nxn[:, 0:npcs].T, all_data_zm_nxtc)
# Now for each dataset, we regress the channel data onto the top
# pcs, and this will be our alignment matrix for that dataset.
# |B - A*W|^2
for name, dataset in datasets.items():
cidx_s = channel_idxs[name][0]
cidx_f = channel_idxs[name][1]
all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f,:] # ch for channel
W_chxp, _, _, _ = \
np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T)
dataset['alignment_matrix_cxf'] = W_chxp
do_debug_plot = False
if do_debug_plot:
pc_vecs = evecs_nxn[:,0:npcs]
ntoplot = 400
plt.figure()
plt.plot(np.log10(evals_n), '-x')
plt.figure()
plt.subplot(311)
plt.imshow(all_data_pca_pxtc)
plt.colorbar()
plt.subplot(312)
plt.imshow(np.dot(W_chxp.T, all_data_zm_chxtc))
plt.colorbar()
plt.subplot(313)
plt.imshow(np.dot(all_data_zm_chxtc.T, W_chxp).T - all_data_pca_pxtc)
plt.colorbar()
import pdb
pdb.set_trace()
return datasets
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import os
import h5py
import json
import numpy as np
import tensorflow as tf
def log_sum_exp(x_k):
"""Computes log \sum exp in a numerically stable way.
log ( sum_i exp(x_i) )
log ( sum_i exp(x_i - m + m) ), with m = max(x_i)
log ( sum_i exp(x_i - m)*exp(m) )
log ( sum_i exp(x_i - m) + m
Args:
x_k - k -dimensional list of arguments to log_sum_exp.
Returns:
log_sum_exp of the arguments.
"""
m = tf.reduce_max(x_k)
x1_k = x_k - m
u_k = tf.exp(x1_k)
z = tf.reduce_sum(u_k)
return tf.log(z) + m
def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
normalized=False, name=None, collections=None):
"""Linear (affine) transformation, y = x W + b, for a variety of
configurations.
Args:
x: input The tensor to tranformation.
out_size: The integer size of non-batch output dimension.
do_bias (optional): Add a learnable bias vector to the operation.
alpha (optional): A multiplicative scaling for the weight initialization
of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
identity_if_possible (optional): just return identity,
if x.shape[1] == out_size.
normalized (optional): Option to divide out by the norms of the rows of W.
name (optional): The name prefix to add to variables.
collections (optional): List of additional collections. (Placed in
tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
Returns:
In the equation, y = x W + b, returns the tensorflow op that yields y.
"""
in_size = int(x.get_shape()[1]) # from Dimension(10) -> 10
stddev = alpha/np.sqrt(float(in_size))
mat_init = tf.random_normal_initializer(0.0, stddev)
wname = (name + "/W") if name else "/W"
if identity_if_possible and in_size == out_size:
# Sometimes linear layers are nothing more than size adapters.
return tf.identity(x, name=(wname+'_ident'))
W,b = init_linear(in_size, out_size, do_bias=do_bias, alpha=alpha,
normalized=normalized, name=name, collections=collections)
if do_bias:
return tf.matmul(x, W) + b
else:
return tf.matmul(x, W)
def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
identity_if_possible=False, normalized=False,
name=None, collections=None):
"""Linear (affine) transformation, y = x W + b, for a variety of
configurations.
Args:
in_size: The integer size of the non-batc input dimension. [(x),y]
out_size: The integer size of non-batch output dimension. [x,(y)]
do_bias (optional): Add a learnable bias vector to the operation.
mat_init_value (optional): numpy constant for matrix initialization, if None
, do random, with additional parameters.
alpha (optional): A multiplicative scaling for the weight initialization
of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
identity_if_possible (optional): just return identity,
if x.shape[1] == out_size.
normalized (optional): Option to divide out by the norms of the rows of W.
name (optional): The name prefix to add to variables.
collections (optional): List of additional collections. (Placed in
tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
Returns:
In the equation, y = x W + b, returns the pair (W, b).
"""
if mat_init_value is not None and mat_init_value.shape != (in_size, out_size):
raise ValueError(
'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
if mat_init_value is None:
stddev = alpha/np.sqrt(float(in_size))
mat_init = tf.random_normal_initializer(0.0, stddev)
wname = (name + "/W") if name else "/W"
if identity_if_possible and in_size == out_size:
return (tf.constant(np.eye(in_size).astype(np.float32)),
tf.zeros(in_size))
# Note the use of get_variable vs. tf.Variable. this is because get_variable
# does not allow the initialization of the variable with a value.
if normalized:
w_collections = [tf.GraphKeys.GLOBAL_VARIABLES, "norm-variables"]
if collections:
w_collections += collections
if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections)
else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections)
w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij
else:
w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
if collections:
w_collections += collections
if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections)
else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections)
if do_bias:
b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
if collections:
b_collections += collections
bname = (name + "/b") if name else "/b"
b = tf.get_variable(bname, [1, out_size],
initializer=tf.zeros_initializer(),
collections=b_collections)
else:
b = None
return (w, b)
def write_data(data_fname, data_dict, use_json=False, compression=None):
"""Write data in HD5F format.
Args:
data_fname: The filename of teh file in which to write the data.
data_dict: The dictionary of data to write. The keys are strings
and the values are numpy arrays.
use_json (optional): human readable format for simple items
compression (optional): The compression to use for h5py (disabled by
default because the library borks on scalars, otherwise try 'gzip').
"""
dir_name = os.path.dirname(data_fname)
if not os.path.exists(dir_name):
os.makedirs(dir_name)
if use_json:
the_file = open(data_fname,'w')
json.dump(data_dict, the_file)
the_file.close()
else:
try:
with h5py.File(data_fname, 'w') as hf:
for k, v in data_dict.items():
clean_k = k.replace('/', '_')
if clean_k is not k:
print('Warning: saving variable with name: ', k, ' as ', clean_k)
else:
print('Saving variable with name: ', clean_k)
hf.create_dataset(clean_k, data=v, compression=compression)
except IOError:
print("Cannot open %s for writing.", data_fname)
raise
def read_data(data_fname):
""" Read saved data in HDF5 format.
Args:
data_fname: The filename of the file from which to read the data.
Returns:
A dictionary whose keys will vary depending on dataset (but should
always contain the keys 'train_data' and 'valid_data') and whose
values are numpy arrays.
"""
try:
with h5py.File(data_fname, 'r') as hf:
data_dict = {k: np.array(v) for k, v in hf.items()}
return data_dict
except IOError:
print("Cannot open %s for reading." % data_fname)
raise
def write_datasets(data_path, data_fname_stem, dataset_dict, compression=None):
"""Write datasets in HD5F format.
This function assumes the dataset_dict is a mapping ( string ->
to data_dict ). It calls write_data for each data dictionary,
post-fixing the data filename with the key of the dataset.
Args:
data_path: The path to the save directory.
data_fname_stem: The filename stem of the file in which to write the data.
dataset_dict: The dictionary of datasets. The keys are strings
and the values data dictionaries (str -> numpy arrays) associations.
compression (optional): The compression to use for h5py (disabled by
default because the library borks on scalars, otherwise try 'gzip').
"""
full_name_stem = os.path.join(data_path, data_fname_stem)
for s, data_dict in dataset_dict.items():
write_data(full_name_stem + "_" + s, data_dict, compression=compression)
def read_datasets(data_path, data_fname_stem):
"""Read dataset sin HD5F format.
This function assumes the dataset_dict is a mapping ( string ->
to data_dict ). It calls write_data for each data dictionary,
post-fixing the data filename with the key of the dataset.
Args:
data_path: The path to the save directory.
data_fname_stem: The filename stem of the file in which to write the data.
"""
dataset_dict = {}
fnames = os.listdir(data_path)
print ('loading data from ' + data_path + ' with stem ' + data_fname_stem)
for fname in fnames:
if fname.startswith(data_fname_stem):
data_dict = read_data(os.path.join(data_path,fname))
idx = len(data_fname_stem) + 1
key = fname[idx:]
data_dict['data_dim'] = data_dict['train_data'].shape[2]
data_dict['num_steps'] = data_dict['train_data'].shape[1]
dataset_dict[key] = data_dict
if len(dataset_dict) == 0:
raise ValueError("Failed to load any datasets, are you sure that the "
"'--data_dir' and '--data_filename_stem' flag values "
"are correct?")
print (str(len(dataset_dict)) + ' datasets loaded')
return dataset_dict
# NUMPY utility functions
def list_t_bxn_to_list_b_txn(values_t_bxn):
"""Convert a length T list of BxN numpy tensors of length B list of TxN numpy
tensors.
Args:
values_t_bxn: The length T list of BxN numpy tensors.
Returns:
The length B list of TxN numpy tensors.
"""
T = len(values_t_bxn)
B, N = values_t_bxn[0].shape
values_b_txn = []
for b in range(B):
values_pb_txn = np.zeros([T,N])
for t in range(T):
values_pb_txn[t,:] = values_t_bxn[t][b,:]
values_b_txn.append(values_pb_txn)
return values_b_txn
def list_t_bxn_to_tensor_bxtxn(values_t_bxn):
"""Convert a length T list of BxN numpy tensors to single numpy tensor with
shape BxTxN.
Args:
values_t_bxn: The length T list of BxN numpy tensors.
Returns:
values_bxtxn: The BxTxN numpy tensor.
"""
T = len(values_t_bxn)
B, N = values_t_bxn[0].shape
values_bxtxn = np.zeros([B,T,N])
for t in range(T):
values_bxtxn[:,t,:] = values_t_bxn[t]
return values_bxtxn
def tensor_bxtxn_to_list_t_bxn(tensor_bxtxn):
"""Convert a numpy tensor with shape BxTxN to a length T list of numpy tensors
with shape BxT.
Args:
tensor_bxtxn: The BxTxN numpy tensor.
Returns:
A length T list of numpy tensors with shape BxT.
"""
values_t_bxn = []
B, T, N = tensor_bxtxn.shape
for t in range(T):
values_t_bxn.append(np.squeeze(tensor_bxtxn[:,t,:]))
return values_t_bxn
def flatten(list_of_lists):
"""Takes a list of lists and returns a list of the elements.
Args:
list_of_lists: List of lists.
Returns:
flat_list: Flattened list.
flat_list_idxs: Flattened list indices.
"""
flat_list = []
flat_list_idxs = []
start_idx = 0
for item in list_of_lists:
if isinstance(item, list):
flat_list += item
l = len(item)
idxs = range(start_idx, start_idx+l)
start_idx = start_idx+l
else: # a value
flat_list.append(item)
idxs = [start_idx]
start_idx += 1
flat_list_idxs.append(idxs)
return flat_list, flat_list_idxs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment