Commit f906646c authored by Duc Nguyen, committed by GitHub

Merge branch 'master' into patch-2

parents 2f3666ed a6df5573
...@@ -22,6 +22,7 @@ running TensorFlow 0.12 or earlier, please
 - [im2txt](im2txt): image-to-text neural network for image captioning.
 - [inception](inception): deep convolutional networks for computer vision.
 - [learning_to_remember_rare_events](learning_to_remember_rare_events): a large-scale life-long memory module for use in deep learning.
+- [lfads](lfads): sequential variational autoencoder for analyzing neuroscience data.
 - [lm_1b](lm_1b): language modeling on the one billion word benchmark.
 - [namignizer](namignizer): recognize and generate names.
 - [neural_gpu](neural_gpu): highly parallel neural computer.
...
...@@ -118,7 +118,7 @@ class AdversarialCrypto(object):
   def model(self, collection, message, key=None):
     """The model for Alice, Bob, and Eve. If key=None, the first FC layer
-    takes only the Key as inputs. Otherwise, it uses both the key
+    takes only the message as inputs. Otherwise, it uses both the key
     and the message.
     Args:
...
...@@ -145,7 +145,7 @@ def visit_count_fc(visit_count, last_visit, embed_neurons, wt_decay, fc_dropout)
                           on_value=10., off_value=0.)
   last_visit = tf.one_hot(last_visit, depth=16, axis=1, dtype=tf.float32,
                           on_value=10., off_value=0.)
-  f = tf.concat_v2([visit_count, last_visit], 1)
+  f = tf.concat([visit_count, last_visit], 1)
   x, _ = tf_utils.fc_network(
       f, neurons=embed_neurons, wt_decay=wt_decay, name='visit_count_embed',
       offset=0, batch_norm_param=None, dropout_ratio=fc_dropout,
...@@ -201,7 +201,7 @@ def combine_setup(name, combine_type, embed_img, embed_goal, num_img_neuorons=No
 def preprocess_egomotion(locs, thetas):
   with tf.name_scope('pre_ego'):
-    pre_ego = tf.concat_v2([locs, tf.sin(thetas), tf.cos(thetas)], 2)
+    pre_ego = tf.concat([locs, tf.sin(thetas), tf.cos(thetas)], 2)
     sh = pre_ego.get_shape().as_list()
     pre_ego = tf.reshape(pre_ego, [-1, sh[-1]])
   return pre_ego
...
# LFADS - Latent Factor Analysis via Dynamical Systems
This code implements the model from the paper "[LFADS - Latent Factor Analysis via Dynamical Systems](http://biorxiv.org/content/early/2017/06/20/152884)". It is a sequential variational auto-encoder designed specifically for investigating neuroscience data, but it can be applied widely to any time series data. In an unsupervised setting, LFADS decomposes time series data into several components: an initial condition, a generative dynamical system, control inputs to that generator, and a low-dimensional description of the observed data, called the factors. Additionally, the observation model is a loss on a probability distribution, so when LFADS processes a dataset, a denoised version of the dataset is also created. For example, if the dataset consists of raw spike counts, then under a Poisson negative log-likelihood loss, the denoised data would be the inferred Poisson rates.
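For intuition, here is a minimal sketch (illustrative only, with made-up numbers; not part of the LFADS code) of what the Poisson observation model means in practice: the denoised output is the inferred rate in each time bin, and the training loss is the negative Poisson log-likelihood of the observed spike counts.
```python
import numpy as np
from scipy.special import gammaln

spike_counts = np.array([0., 2., 1., 4.])   # hypothetical binned spike counts
log_rates = np.array([0.1, 0.6, 0.2, 1.3])  # hypothetical inferred log-rates

denoised_rates = np.exp(log_rates)          # the "denoised" data under a Poisson model
# log Poisson(k; r) = k*log(r) - r - log(k!)
log_lik = spike_counts * log_rates - np.exp(log_rates) - gammaln(spike_counts + 1)
negative_log_likelihood = -np.sum(log_lik)
```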
## Prerequisites
The code is written in Python 2.7.6. You will also need:
* **TensorFlow** version 1.1 ([install](http://tflearn.org/installation/)) -
there is an incompatibility between LFADS and TF v1.2, which we are in the
process of resolving
* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
* **h5py** ([install](https://pypi.python.org/pypi/h5py))
## Getting started
Before starting, run the following:
<pre>
$ export PYTHONPATH=$PYTHONPATH:/<b>path/to/your/directory</b>/lfads/
</pre>
where "path/to/your/directory" is replaced with the path to the LFADS repository (you can get this path by using the `pwd` command). This allows the nested directories to access modules from their parent directory.
## Generate synthetic data
To generate the synthetic datasets, run the following from the top-level lfads directory:
```sh
$ cd synth_data
$ ./run_generate_synth_data.sh
$ cd ..
```
These synthetic datasets are provided (1) to give insight into how the LFADS algorithm operates, and (2) to provide reasonable starting points for analyses you might be interested in applying to your own data.
## Train an LFADS model
Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 8) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters.
```sh
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5)
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
--co_dim=0 \
--factors_dim=20 \
--ext_input_dim=0 \
--controller_input_lag=1 \
--output_dist=poisson \
--do_causal_controller=false \
--batch_size=128 \
--learning_rate_init=0.01 \
--learning_rate_stop=1e-05 \
--learning_rate_decay_factor=0.95 \
--learning_rate_n_to_compare=6 \
--do_reset_learning_rate=false \
--keep_prob=0.95 \
--con_dim=128 \
--gen_dim=200 \
--ci_enc_dim=128 \
--ic_dim=64 \
--ic_enc_dim=128 \
--ic_prior_var_min=0.1 \
--gen_cell_input_weight_scale=1.0 \
--cell_weight_scale=1.0 \
--do_feed_factors_to_controller=true \
--kl_start_step=0 \
--kl_increase_steps=2000 \
--kl_ic_weight=1.0 \
--l2_con_scale=0.0 \
--l2_gen_scale=2000.0 \
--l2_start_step=0 \
--l2_increase_steps=2000 \
--ic_prior_var_scale=0.1 \
--ic_post_var_min=0.0001 \
--kl_co_weight=1.0 \
--prior_ar_nvar=0.1 \
--cell_clip_value=5.0 \
--max_ckpt_to_keep_lve=5 \
--do_train_prior_ar_atau=true \
--co_prior_var_scale=0.1 \
--csv_log=fitlog \
--feedback_factors_or_rates=factors \
--do_train_prior_ar_nvar=true \
--max_grad_norm=200.0 \
--device=gpu:0 \
--num_steps_for_gen_ic=100000000 \
--ps_nexamples_to_process=100000000 \
--checkpoint_name=lfads_vae \
--temporal_spike_jitter_width=0 \
--checkpoint_pb_load_name=checkpoint \
--inject_ext_input_to_gen=false \
--co_mean_corr_scale=0.0 \
--gen_cell_rec_weight_scale=1.0 \
--max_ckpt_to_keep=5 \
--output_filename_stem="" \
--ic_prior_var_max=0.1 \
--prior_ar_atau=10.0 \
--do_train_io_only=false
# Run LFADS on chaotic rnn data with input pulses (g = 2.5)
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20
# Run LFADS on multi-session RNN data
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_multisession \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_multisession \
--factors_dim=10
# Run LFADS on integration to bound model data
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=itb_rnn \
--lfads_save_dir=/tmp/lfads_itb_rnn \
--co_dim=1 \
--factors_dim=20 \
--controller_input_lag=0
# Run LFADS on chaotic RNN data with labels
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnns_labeled \
--lfads_save_dir=/tmp/lfads_chaotic_rnns_labeled \
--co_dim=0 \
--factors_dim=20 \
--controller_input_lag=0 \
--ext_input_dim=1
```
**Tip**: If you are running LFADS on a GPU and would like to run more than one model concurrently, set the `--allow_gpu_growth=True` flag on each job; otherwise a single model will claim the entire GPU for performance reasons. Note that you will also need to install the TensorFlow libraries with GPU support.
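As a sketch (paths and flags taken from the training examples above; adapt to your own runs), two concurrent training jobs sharing one GPU might be launched as:
```sh
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
--allow_gpu_growth=True &
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--allow_gpu_growth=True &
```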
## Visualize a training model
To visualize training curves and various other metrics while training an LFADS model, run TensorBoard on your model directory. For example, to launch TensorBoard on the chaotic RNN data with input pulses:
```sh
tensorboard --logdir=/tmp/lfads_chaotic_rnn_inputs_g2p5
```
## Evaluate a trained model
Once your model is finished training, there are multiple ways you can evaluate
it. Below are some sample commands to evaluate an LFADS model trained on the
chaotic rnn data with input pulses (g = 2.5). The key differences here are
setting the `--kind` flag to the appropriate mode, as well as the
`--checkpoint_pb_load_name` flag to `checkpoint_lve` and the `--batch_size` flag
(if you'd like to make it larger or smaller). All other flags should be the
same as used in training, so that the same model architecture is built.
```sh
# Take samples from posterior then average (denoising operation)
$ python run_lfads.py --kind=posterior_sample_and_average \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--batch_size=1024 \
--checkpoint_pb_load_name=checkpoint_lve
# Sample from prior (generation of completely new samples)
$ python run_lfads.py --kind=prior_sample \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--batch_size=50 \
--checkpoint_pb_load_name=checkpoint_lve
# Write down model parameters
$ python run_lfads.py --kind=write_model_params \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--checkpoint_pb_load_name=checkpoint_lve
```
## Contact
File any issues with the [issue tracker](https://github.com/tensorflow/models/issues). For any questions or problems, this code is maintained by [@sussillo](https://github.com/sussillo) and [@jazcollins](https://github.com/jazcollins).
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import tensorflow as tf
from utils import linear, log_sum_exp
class Poisson(object):
"""Poisson distributon
Computes the log probability under the model.
"""
def __init__(self, log_rates):
""" Create Poisson distributions with log_rates parameters.
Args:
log_rates: a tensor-like list of log rates underlying the Poisson dist.
"""
self.logr = log_rates
def logp(self, bin_counts):
"""Compute the log probability for the counts in the bin, under the model.
Args:
bin_counts: array-like integer counts
Returns:
The log-probability under the Poisson models for each element of
bin_counts.
"""
k = tf.to_float(bin_counts)
# log poisson(k, r) = log(r^k * e^(-r) / k!) = k log(r) - r - log k!
# log poisson(k, r=exp(x)) = k * x - exp(x) - lgamma(k + 1)
return k * self.logr - tf.exp(self.logr) - tf.lgamma(k + 1)
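# Example usage (illustrative comment only; the tensor names are hypothetical):
#   dist = Poisson(log_rates=log_rates_bxtxn)      # log Poisson rates from the model
#   log_p_bxtxn = dist.logp(spike_counts_bxtxn)    # elementwise log-probabilities
#   rates_bxtxn = tf.exp(log_rates_bxtxn)          # inferred (denoised) rates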
def diag_gaussian_log_likelihood(z, mu=0.0, logvar=0.0):
"""Log-likelihood under a Gaussian distribution with diagonal covariance.
Returns the log-likelihood for each dimension. One should sum the
results for the log-likelihood under the full multidimensional model.
Args:
z: The value to compute the log-likelihood.
mu: The mean of the Gaussian
logvar: The log variance of the Gaussian.
Returns:
The log-likelihood under the Gaussian model.
"""
return -0.5 * (logvar + np.log(2*np.pi) + \
tf.square((z-mu)/tf.exp(0.5*logvar)))
def gaussian_pos_log_likelihood(unused_mean, logvar, noise):
"""Gaussian log-likelihood function for a posterior in VAE
Note: This function is specialized for a posterior distribution, that has the
form of z = mean + sigma * noise.
Args:
unused_mean: ignore
logvar: The log variance of the distribution
noise: The noise used in the sampling of the posterior.
Returns:
The log-likelihood under the Gaussian model.
"""
# ln N(z; mean, sigma) = - ln(sigma) - 0.5 ln 2pi - noise^2 / 2
return - 0.5 * (logvar + np.log(2 * np.pi) + tf.square(noise))
class Gaussian(object):
"""Base class for Gaussian distribution classes."""
pass
class DiagonalGaussian(Gaussian):
"""Diagonal Gaussian with different constant mean and variances in each
dimension.
"""
def __init__(self, batch_size, z_size, mean, logvar):
"""Create a diagonal gaussian distribution.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
mean: The N-D mean of the distribution.
logvar: The N-D log variance of the diagonal distribution.
"""
size__xz = [None, z_size]
self.mean = mean # bxn already
self.logvar = logvar # bxn already
self.noise = noise = tf.random_normal(tf.shape(logvar))
self.sample = mean + tf.exp(0.5 * logvar) * noise
mean.set_shape(size__xz)
logvar.set_shape(size__xz)
self.sample.set_shape(size__xz)
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample:
return gaussian_pos_log_likelihood(self.mean, self.logvar, self.noise)
return diag_gaussian_log_likelihood(z, self.mean, self.logvar)
class LearnableDiagonalGaussian(Gaussian):
"""Diagonal Gaussian whose mean and variance are learned parameters."""
def __init__(self, batch_size, z_size, name, mean_init=0.0,
var_init=1.0, var_min=0.0, var_max=1000000.0):
"""Create a learnable diagonal gaussian distribution.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
name: prefix name for the mean and log TF variables.
mean_init (optional): The N-D mean initialization of the distribution.
var_init (optional): The N-D variance initialization of the diagonal
distribution.
var_min (optional): The minimum value the learned variance can take in any
dimension.
var_max (optional): The maximum value the learned variance can take in any
dimension.
"""
size_1xn = [1, z_size]
size__xn = [None, z_size]
size_bx1 = tf.stack([batch_size, 1])
assert var_init > 0.0, "Problems"
assert var_max >= var_min, "Problems"
assert var_init >= var_min, "Problems"
assert var_max >= var_init, "Problems"
z_mean_1xn = tf.get_variable(name=name+"/mean", shape=size_1xn,
initializer=tf.constant_initializer(mean_init))
self.mean_bxn = mean_bxn = tf.tile(z_mean_1xn, size_bx1)
mean_bxn.set_shape(size__xn) # tile loses shape
log_var_init = np.log(var_init)
if var_max > var_min:
var_is_trainable = True
else:
var_is_trainable = False
z_logvar_1xn = \
tf.get_variable(name=(name+"/logvar"), shape=size_1xn,
initializer=tf.constant_initializer(log_var_init),
trainable=var_is_trainable)
if var_is_trainable:
z_logit_var_1xn = tf.exp(z_logvar_1xn)
z_var_1xn = tf.nn.sigmoid(z_logit_var_1xn)*(var_max-var_min) + var_min
z_logvar_1xn = tf.log(z_var_1xn)
logvar_bxn = tf.tile(z_logvar_1xn, size_bx1)
self.logvar_bxn = logvar_bxn
self.noise_bxn = noise_bxn = tf.random_normal(tf.shape(logvar_bxn))
self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample_bxn:
return gaussian_pos_log_likelihood(self.mean_bxn, self.logvar_bxn,
self.noise_bxn)
return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
@property
def mean(self):
return self.mean_bxn
@property
def logvar(self):
return self.logvar_bxn
@property
def sample(self):
return self.sample_bxn
class DiagonalGaussianFromInput(Gaussian):
"""Diagonal Gaussian whose mean and variance are conditioned on other
variables.
Note: the parameters to convert from input to the learned mean and log
variance are held in this class.
"""
def __init__(self, x_bxu, z_size, name, var_min=0.0):
"""Create an input dependent diagonal Gaussian distribution.
Args:
x: The input tensor from which the mean and variance are computed,
via a linear transformation of x. I.e.
mu = Wx + b, log(var) = Mx + c
z_size: The size of the distribution.
name: The name to prefix to learned variables.
var_min (optional): Minimal variance allowed. This is an additional
way to control the amount of information getting through the stochastic
layer.
"""
size_bxn = tf.stack([tf.shape(x_bxu)[0], z_size])
self.mean_bxn = mean_bxn = linear(x_bxu, z_size, name=(name+"/mean"))
logvar_bxn = linear(x_bxu, z_size, name=(name+"/logvar"))
if var_min > 0.0:
logvar_bxn = tf.log(tf.exp(logvar_bxn) + var_min)
self.logvar_bxn = logvar_bxn
self.noise_bxn = noise_bxn = tf.random_normal(size_bxn)
self.noise_bxn.set_shape([None, z_size])
self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample_bxn:
return gaussian_pos_log_likelihood(self.mean_bxn,
self.logvar_bxn, self.noise_bxn)
return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
@property
def mean(self):
return self.mean_bxn
@property
def logvar(self):
return self.logvar_bxn
@property
def sample(self):
return self.sample_bxn
class GaussianProcess:
"""Base class for Gaussian processes."""
pass
class LearnableAutoRegressive1Prior(GaussianProcess):
"""AR(1) model where autocorrelation and process variance are learned
parameters. Assumed zero mean.
"""
def __init__(self, batch_size, z_size,
autocorrelation_taus, noise_variances,
do_train_prior_ar_atau, do_train_prior_ar_nvar,
num_steps, name):
"""Create a learnable autoregressive (1) process.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
autocorrelation_taus: The auto correlation time constant of the AR(1)
process.
A value of 0 is uncorrelated gaussian noise.
noise_variances: The variance of the additive noise, *not* the process
variance.
do_train_prior_ar_atau: Train or leave as constant, the autocorrelation?
do_train_prior_ar_nvar: Train or leave as constant, the noise variance?
num_steps: Number of steps to run the process.
name: The name to prefix to learned TF variables.
"""
# Note the use of the plural in all of these quantities. This is intended
# to mark that even though a sample z_t from the posterior is thought of as a
# single sample of a multidimensional gaussian, the prior is actually
# thought of as U AR(1) processes, where U is the dimension of the inferred
# input.
size_bx1 = tf.stack([batch_size, 1])
size__xu = [None, z_size]
# process variance, the variance at time t over all instantiations of AR(1)
# with these parameters.
log_evar_inits_1xu = tf.expand_dims(tf.log(noise_variances), 0)
self.logevars_1xu = logevars_1xu = \
tf.Variable(log_evar_inits_1xu, name=name+"/logevars", dtype=tf.float32,
trainable=do_train_prior_ar_nvar)
self.logevars_bxu = logevars_bxu = tf.tile(logevars_1xu, size_bx1)
logevars_bxu.set_shape(size__xu) # tile loses shape
# \tau, which is the autocorrelation time constant of the AR(1) process
log_atau_inits_1xu = tf.expand_dims(tf.log(autocorrelation_taus), 0)
self.logataus_1xu = logataus_1xu = \
tf.Variable(log_atau_inits_1xu, name=name+"/logatau", dtype=tf.float32,
trainable=do_train_prior_ar_atau)
# phi in x_t = \mu + phi x_tm1 + \eps
# phi = exp(-1/tau)
# phi = exp(-1/exp(logtau))
# phi = exp(-exp(-logtau))
phis_1xu = tf.exp(-tf.exp(-logataus_1xu))
self.phis_bxu = phis_bxu = tf.tile(phis_1xu, size_bx1)
phis_bxu.set_shape(size__xu)
# process noise
# pvar = evar / (1- phi^2)
# logpvar = log ( exp(logevar) / (1 - phi^2) )
# logpvar = logevar - log(1-phi^2)
# logpvar = logevar - (log(1-phi) + log(1+phi))
self.logpvars_1xu = \
logevars_1xu - tf.log(1.0-phis_1xu) - tf.log(1.0+phis_1xu)
self.logpvars_bxu = logpvars_bxu = tf.tile(self.logpvars_1xu, size_bx1)
logpvars_bxu.set_shape(size__xu)
# process mean (zero, but included for completeness)
self.pmeans_bxu = pmeans_bxu = tf.zeros_like(phis_bxu)
# For sampling from the prior during de-novo generation.
self.means_t = means_t = [None] * num_steps
self.logvars_t = logvars_t = [None] * num_steps
self.samples_t = samples_t = [None] * num_steps
self.gaussians_t = gaussians_t = [None] * num_steps
sample_bxu = tf.zeros_like(phis_bxu)
for t in range(num_steps):
# process variance used here to make process completely stationary
if t == 0:
logvar_pt_bxu = self.logpvars_bxu
else:
logvar_pt_bxu = self.logevars_bxu
z_mean_pt_bxu = pmeans_bxu + phis_bxu * sample_bxu
gaussians_t[t] = DiagonalGaussian(batch_size, z_size,
mean=z_mean_pt_bxu,
logvar=logvar_pt_bxu)
sample_bxu = gaussians_t[t].sample
samples_t[t] = sample_bxu
logvars_t[t] = logvar_pt_bxu
means_t[t] = z_mean_pt_bxu
def logp_t(self, z_t_bxu, z_tm1_bxu=None):
"""Compute the log-likelihood under the distribution for a given time t,
not the whole sequence.
Args:
z_t_bxu: sample to compute likelihood for at time t.
z_tm1_bxu (optional): sample condition probability of z_t upon.
Returns:
The likelihood of p_t under the model at time t. i.e.
p(z_t|z_tm1) = N(z_tm1 * phis, eps^2)
"""
if z_tm1_bxu is None:
return diag_gaussian_log_likelihood(z_t_bxu, self.pmeans_bxu,
self.logpvars_bxu)
else:
means_t_bxu = self.pmeans_bxu + self.phis_bxu * z_tm1_bxu
logp_tgtm1_bxu = diag_gaussian_log_likelihood(z_t_bxu,
means_t_bxu,
self.logevars_bxu)
return logp_tgtm1_bxu
class KLCost_GaussianGaussian(object):
"""log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian prior. See
eqn 10 and Appendix B in VAE for latter term,
http://arxiv.org/abs/1312.6114
The log p(x|z) term is the reconstruction error under the model.
The KL term represents the penalty for passing information from the encoder
to the decoder.
To sample KL(q||p), we simply sample
ln q - ln p
by drawing samples from q and averaging.
"""
def __init__(self, zs, prior_zs):
"""Create a lower bound in three parts, normalized reconstruction
cost, normalized KL divergence cost, and their sum.
E_q[ln p(z_i | z_{i+1}) / q(z_i | x)]
\int q(z) ln p(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_p^2) + \
sigma_q^2 / sigma_p^2 + (mean_p - mean_q)^2 / sigma_p^2)
\int q(z) ln q(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_q^2) + 1)
Args:
zs: posterior z ~ q(z|x)
prior_zs: prior zs
"""
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'KL cost' is positive KL divergence
kl_b = 0.0
for z, prior_z in zip(zs, prior_zs):
assert isinstance(z, Gaussian)
assert isinstance(prior_z, Gaussian)
# ln(2pi) terms cancel
kl_b += 0.5 * tf.reduce_sum(
prior_z.logvar - z.logvar
+ tf.exp(z.logvar - prior_z.logvar)
+ tf.square((z.mean - prior_z.mean) / tf.exp(0.5 * prior_z.logvar))
- 1.0, [1])
self.kl_cost_b = kl_b
self.kl_cost = tf.reduce_mean(kl_b)
class KLCost_GaussianGaussianProcessSampled(object):
""" log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian process
prior via sampling.
The log p(x|z) term is the reconstruction error under the model.
The KL term represents the penalty for passing information from the encoder
to the decoder.
To sample KL(q||p), we simply sample
ln q - ln p
by drawing samples from q and averaging.
"""
def __init__(self, post_zs, prior_z_process):
"""Create a lower bound in three parts, normalized reconstruction
cost, normalized KL divergence cost, and their sum.
Args:
post_zs: posterior z ~ q(z|x)
prior_z_process: prior AR(1) process
"""
assert len(post_zs) > 1, "GP is for time, need more than 1 time step."
assert isinstance(prior_z_process, GaussianProcess), "Must use GP."
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'KL cost' is positive KL divergence
z0_bxu = post_zs[0].sample
logq_bxu = post_zs[0].logp(z0_bxu)
logp_bxu = prior_z_process.logp_t(z0_bxu)
z_tm1_bxu = z0_bxu
for z_t in post_zs[1:]:
# posterior is independent in time, prior is not
z_t_bxu = z_t.sample
logq_bxu += z_t.logp(z_t_bxu)
logp_bxu += prior_z_process.logp_t(z_t_bxu, z_tm1_bxu)
z_tm1_bxu = z_t_bxu
kl_bxu = logq_bxu - logp_bxu
kl_b = tf.reduce_sum(kl_bxu, [1])
self.kl_cost_b = kl_b
self.kl_cost = tf.reduce_mean(kl_b)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
def _plot_item(W, name, full_name, nspaces):
plt.figure()
if W.shape == ():
print(name, ": ", W)
elif W.shape[0] == 1:
plt.stem(W.T)
plt.title(full_name)
elif W.shape[1] == 1:
plt.stem(W)
plt.title(full_name)
else:
plt.imshow(np.abs(W), interpolation='nearest', cmap='jet');
plt.colorbar()
plt.title(full_name)
def all_plot(d, full_name="", exclude="", nspaces=0):
"""Recursively plot all the LFADS model parameters in the nested
dictionary."""
for k, v in d.iteritems():
this_name = full_name+"/"+k
if isinstance(v, dict):
all_plot(v, full_name=this_name, exclude=exclude, nspaces=nspaces+4)
else:
if exclude == "" or exclude not in this_name:
_plot_item(v, name=k, full_name=full_name+"/"+k, nspaces=nspaces+4)
def plot_priors():
g0s_prior_mean_bxn = train_modelvals['prior_g0_mean']
g0s_prior_var_bxn = train_modelvals['prior_g0_var']
g0s_post_mean_bxn = train_modelvals['posterior_g0_mean']
g0s_post_var_bxn = train_modelvals['posterior_g0_var']
plt.figure(figsize=(10,4), tight_layout=True);
plt.subplot(1,2,1)
plt.hist(g0s_post_mean_bxn.flatten(), bins=20, color='b');
plt.hist(g0s_prior_mean_bxn.flatten(), bins=20, color='g');
plt.title('Histogram of Prior/Posterior Mean Values')
plt.subplot(1,2,2)
plt.hist((g0s_post_var_bxn.flatten()), bins=20, color='b');
plt.hist((g0s_prior_var_bxn.flatten()), bins=20, color='g');
plt.title('Histogram of Prior/Posterior Log Variance Values')
plt.figure(figsize=(10,10), tight_layout=True)
plt.subplot(2,2,1)
plt.imshow(g0s_prior_mean_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Prior g0 means')
plt.subplot(2,2,2)
plt.imshow(g0s_post_mean_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Posterior g0 means');
plt.subplot(2,2,3)
plt.imshow(g0s_prior_var_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Prior g0 variance Values')
plt.subplot(2,2,4)
plt.imshow(g0s_post_var_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Posterior g0 variance Values')
plt.figure(figsize=(10,5))
plt.stem(np.sort(np.log(g0s_post_mean_bxn.std(axis=0))));
plt.title('Log standard deviation of h0 means');
def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
color='r', title=None):
if bidx is None:
vals_txn = np.mean(vals_bxtxn, axis=0)
else:
vals_txn = vals_bxtxn[bidx,:,:]
T, N = vals_txn.shape
if n_to_plot > N:
n_to_plot = N
plt.plot(vals_txn[:,0:n_to_plot] + scale*np.array(range(n_to_plot)),
color=color, lw=1.0)
plt.axis('tight')
if title:
plt.title(title)
def plot_lfads_timeseries(data_bxtxn, model_vals, ext_input_bxtxi=None,
truth_bxtxn=None, bidx=None, output_dist="poisson",
conversion_factor=1.0, subplot_cidx=0,
col_title=None):
n_to_plot = 10
scale = 1.0
nrows = 7
plt.subplot(nrows,2,1+subplot_cidx)
if output_dist == 'poisson':
rates = means = conversion_factor * model_vals['output_dist_params']
plot_time_series(rates, bidx, n_to_plot=n_to_plot, scale=scale,
title=col_title + " rates (LFADS - red, Truth - black)")
elif output_dist == 'gaussian':
means_vars = model_vals['output_dist_params']
means, vars = np.split(means_vars,2, axis=2) # bxtxn
stds = np.sqrt(vars)
plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
title=col_title + " means (LFADS - red, Truth - black)")
plot_time_series(means+stds, bidx, n_to_plot=n_to_plot, scale=scale,
color='c')
plot_time_series(means-stds, bidx, n_to_plot=n_to_plot, scale=scale,
color='c')
else:
assert False, 'NIY'
if truth_bxtxn is not None:
plot_time_series(truth_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
scale=scale)
input_title = ""
if "controller_outputs" in model_vals.keys():
input_title += " Controller Output"
plt.subplot(nrows,2,3+subplot_cidx)
u_t = model_vals['controller_outputs'][0:-1]
plot_time_series(u_t, bidx, n_to_plot=n_to_plot, color='c', scale=1.0,
title=col_title + input_title)
if ext_input_bxtxi is not None:
input_title += " External Input"
plot_time_series(ext_input_bxtxi, n_to_plot=n_to_plot, color='b',
scale=scale, title=col_title + input_title)
plt.subplot(nrows,2,5+subplot_cidx)
plot_time_series(means, bidx,
n_to_plot=n_to_plot, scale=1.0,
title=col_title + " Spikes (LFADS - red, Spikes - black)")
plot_time_series(data_bxtxn, bidx, n_to_plot=n_to_plot, color='k', scale=1.0)
plt.subplot(nrows,2,7+subplot_cidx)
plot_time_series(model_vals['factors'], bidx, n_to_plot=n_to_plot, color='b',
scale=2.0, title=col_title + " Factors")
plt.subplot(nrows,2,9+subplot_cidx)
plot_time_series(model_vals['gen_states'], bidx, n_to_plot=n_to_plot,
color='g', scale=1.0, title=col_title + " Generator State")
if bidx is not None:
data_nxt = data_bxtxn[bidx,:,:].T
params_nxt = model_vals['output_dist_params'][bidx,:,:].T
else:
data_nxt = np.mean(data_bxtxn, axis=0).T
params_nxt = np.mean(model_vals['output_dist_params'], axis=0).T
if output_dist == 'poisson':
means_nxt = params_nxt
elif output_dist == 'gaussian': # (means+vars) x time
means_nxt = np.vsplit(params_nxt,2)[0] # get means
else:
assert "NIY"
plt.subplot(nrows,2,11+subplot_cidx)
plt.imshow(data_nxt, aspect='auto', interpolation='nearest')
plt.title(col_title + ' Data')
plt.subplot(nrows,2,13+subplot_cidx)
plt.imshow(means_nxt, aspect='auto', interpolation='nearest')
plt.title(col_title + ' Means')
def plot_lfads(train_bxtxd, train_model_vals,
train_ext_input_bxtxi=None, train_truth_bxtxd=None,
valid_bxtxd=None, valid_model_vals=None,
valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
bidx=None, cf=1.0, output_dist='poisson'):
# Plotting
f = plt.figure(figsize=(18,20), tight_layout=True)
plot_lfads_timeseries(train_bxtxd, train_model_vals,
train_ext_input_bxtxi,
truth_bxtxn=train_truth_bxtxd,
conversion_factor=cf, bidx=bidx,
output_dist=output_dist, col_title='Train')
plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
valid_ext_input_bxtxi,
truth_bxtxn=valid_truth_bxtxd,
conversion_factor=cf, bidx=bidx,
output_dist=output_dist,
subplot_cidx=1, col_title='Valid')
# Convert from figure to a numpy array of width x height x 3 (last dim for RGB)
f.canvas.draw()
data = np.fromstring(f.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data_wxhx3 = data.reshape(f.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data_wxhx3
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf # used for flags here
from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
matplotlib.rcParams['image.interpolation'] = 'nearest'
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 40,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
"Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("input_magnitude", 20.0,
"For the input case, what is the value of the input?")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
# Note that with N small (as it is above), the finite size effects
# will have pretty dramatic effects on the dynamics of the random RNN.
# If you want more complex dynamics, you'll have to run the script a
# lot, or increase N (or g).
# Getting hard vs. easy data can be a little stochastic, so we set the seed.
# Pull out some commonly used parameters.
# These are user parameters (configuration)
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
S = FLAGS.S
input_magnitude = FLAGS.input_magnitude
nspikifications = FLAGS.nspikifications
E = nspikifications * C # total number of trials
# S is the number of measurements in each dataset, w/ each
# dataset having a different set of observations.
ndatasets = N/S # ok if rounded down
train_percentage = FLAGS.train_percentage
ntime_steps = int(T / FLAGS.dt)
# End of user parameters
rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
# Check to make sure the RNN is the one we used in the paper.
if N == 50:
assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
rem_check = nspikifications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \
'Train percentage * nspikifications should be integral number.'
# Initial condition generation, and condition label generation. This
# happens outside of the dataset loop, so that all datasets have the
# same conditions, which is similar to a neurophys setup.
condition_number = 0
x0s = []
condition_labels = []
for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications)) # replicate x0 nspikifications times
# replicate the condition label nspikifications times
for ns in range(nspikifications):
condition_labels.append(condition_number)
condition_number += 1
x0s = np.concatenate(x0s, axis=1)
# Containers for storing data across datasets.
datasets = {}
for n in range(ndatasets):
print(n+1, " of ", ndatasets)
# First generate all firing rates. In the next loop, generate all
# spikifications. This allows the random state for rate generation to be
# independent of n_spikifications.
dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
if S < N:
dataset_name += '_n' + str(n+1)
# Sample neuron subsets. The assumption is the PC axes of the RNN
# are not unit aligned, so sampling units is adequate to sample all
# the high-variance PCs.
P_sxn = np.eye(S,N)
for m in range(n):
P_sxn = np.roll(P_sxn, S, axis=1)
if input_magnitude > 0.0:
# time of "hits" randomly chosen between [1/4 and 3/4] of total time
input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
else:
input_times = None
rates, x0s, inputs = \
generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
input_magnitude=input_magnitude,
input_times=input_times)
spikes = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
# Split the data, inputs, labels and times into train vs. validation.
rates_train, rates_valid = \
split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = \
split_list_by_inds(spikes, train_inds, valid_inds)
input_train, inputs_valid = \
split_list_by_inds(inputs, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = \
split_list_by_inds(condition_labels, train_inds, valid_inds)
input_times_train, input_times_valid = \
split_list_by_inds(input_times, train_inds, valid_inds)
# Turn rates, spikes, and input into numpy arrays.
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)
input_train = nparray_and_transpose(input_train)
inputs_valid = nparray_and_transpose(inputs_valid)
# Note that we put these 'truth' rates and input into this
# structure, but the only data used by LFADS are the spike
# trains. The rest is either for printing or posterity.
data = {'train_truth': rates_train,
'valid_truth': rates_valid,
'input_train_truth' : input_train,
'input_valid_truth' : inputs_valid,
'train_data' : spikes_train,
'valid_data' : spikes_valid,
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : rnn['dt'],
'input_magnitude' : input_magnitude,
'input_times_train' : input_times_train,
'input_times_valid' : input_times_valid,
'P_sxn' : P_sxn,
'condition_labels_train' : condition_labels_train,
'condition_labels_valid' : condition_labels_valid,
'conversion_factor': 1.0 / rnn['conversion_factor']}
datasets[dataset_name] = data
if S < N:
# Note that this isn't necessary for this synthetic example, but
# it's useful to see how the input factor matrices were initialized
# for actual neurophysiology data.
datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)
# Write out the datasets.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf
from utils import write_datasets
from synthetic_data_utils import normalize_rates
from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 5,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0,
"Map 1.0 of RNN to a spikes per second")
flags.DEFINE_float("u_std", 0.25,
"Std dev of input to integration to bound model")
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
"""Path to directory with checkpoints of model
trained on integration to bound task. Currently this
is a placeholder which tells the code to grab the
checkpoint that is provided with the code
(in /trained_itb/..). If you have your own checkpoint
you would like to restore, you would point it to
that path.""")
FLAGS = flags.FLAGS
class IntegrationToBoundModel:
def __init__(self, N):
scale = 0.8 / float(N**0.5)
self.N = N
self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
self.b_1xn = tf.Variable(tf.zeros([1, N]))
self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
self.bro_o = tf.Variable(tf.zeros([1]))
def call(self, h_tm1_bxn, u_bx1):
act_t_bxn = tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn + u_bx1 * self.Bu_1xn
h_t_bxn = tf.nn.tanh(act_t_bxn)
z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
return z_t, h_t_bxn
def get_data_batch(batch_size, T, rng, u_std):
u_bxt = rng.randn(batch_size, T) * u_std
running_sum_b = np.zeros([batch_size])
labels_bxt = np.zeros([batch_size, T])
for t in xrange(T):
running_sum_b += u_bxt[:, t]
labels_bxt[:, t] += running_sum_b
labels_bxt = np.clip(labels_bxt, -1, 1)
return u_bxt, labels_bxt
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N # must be same N as in trained model (provided example is N = 50)
nspikifications = FLAGS.nspikifications
E = nspikifications * C # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
batch_size = 1 # gives one example per ntrial
model = IntegrationToBoundModel(N)
inputs_ph_t = [tf.placeholder(tf.float32,
shape=[None, 1]) for _ in range(ntimesteps)]
state = tf.zeros([batch_size, N])
saver = tf.train.Saver()
P_nxn = rng.randn(N,N) / np.sqrt(N) # random projections
# unroll RNN for T timesteps
outputs_t = []
states_t = []
for inp in inputs_ph_t:
output, state = model.call(state, inp)
outputs_t.append(output)
states_t.append(state)
with tf.Session() as sess:
# restore the latest model ckpt
if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
dir_path = os.path.dirname(os.path.realpath(__file__))
model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
else:
model_checkpoint_path = FLAGS.checkpoint_path
try:
saver.restore(sess, model_checkpoint_path)
print ('Model restored from', model_checkpoint_path)
except:
assert False, ("No checkpoints to restore from, is the path %s correct?"
%model_checkpoint_path)
# generate data for trials
data_e = []
u_e = []
outs_e = []
for c in range(C):
u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng, FLAGS.u_std)
feed_dict = {}
for t in xrange(ntimesteps):
feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:,t], (batch_size,-1))
states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
feed_dict=feed_dict)
states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
r_sxt = np.dot(P_nxn, states_nxt)
for s in xrange(nspikifications):
data_e.append(r_sxt)
u_e.append(u_1xt)
outs_e.append(outputs_t_bxn)
truth_data_e = normalize_rates(data_e, E, N)
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
train_inds,
valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
train_inds,
valid_inds)
data_train_truth = nparray_and_transpose(data_train_truth)
data_valid_truth = nparray_and_transpose(data_valid_truth)
data_train_spiking = nparray_and_transpose(data_train_spiking)
data_valid_spiking = nparray_and_transpose(data_valid_spiking)
# save down the inputs used to generate this data
train_inputs_u, valid_inputs_u = split_list_by_inds(u_e,
train_inds,
valid_inds)
train_inputs_u = nparray_and_transpose(train_inputs_u)
valid_inputs_u = nparray_and_transpose(valid_inputs_u)
# save down the network outputs (may be useful later)
train_outputs_u, valid_outputs_u = split_list_by_inds(outs_e,
train_inds,
valid_inds)
train_outputs_u = np.array(train_outputs_u)
valid_outputs_u = np.array(valid_outputs_u)
data = { 'train_truth': data_train_truth,
'valid_truth': data_valid_truth,
'train_data' : data_train_spiking,
'valid_data' : data_valid_spiking,
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : FLAGS.dt,
'u_std' : FLAGS.u_std,
'max_firing_rate': FLAGS.max_firing_rate,
'train_inputs_u': train_inputs_u,
'valid_inputs_u': valid_inputs_u,
'train_outputs_u': train_outputs_u,
'valid_outputs_u': valid_outputs_u,
'conversion_factor' : FLAGS.max_firing_rate/(1.0/FLAGS.dt) }
# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data
# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print ('Saved to ', os.path.join(FLAGS.save_dir,
FLAGS.datafile_name + '_' + dataset_name))
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import os
import h5py
import numpy as np
from synthetic_data_utils import generate_data, generate_rnn
from synthetic_data_utils import get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import tensorflow as tf
from utils import write_datasets
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "conditioned_rnn_data",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 400, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 10,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
"Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
np.random.RandomState(seed=FLAGS.synth_data_seed+2)]
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
nspikifications = FLAGS.nspikifications
E = nspikifications * C
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
rnn_a = generate_rnn(rnn_rngs[0], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
FLAGS.max_firing_rate)
rnn_b = generate_rnn(rnn_rngs[1], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
FLAGS.max_firing_rate)
rnns = [rnn_a, rnn_b]
# pick which RNN is used on each trial
rnn_to_use = rng.randint(2, size=E)
ext_input = np.repeat(np.expand_dims(rnn_to_use, axis=1), ntimesteps, axis=1)
ext_input = np.expand_dims(ext_input, axis=2) # these are "a's" in the paper
x0s = []
condition_labels = []
condition_number = 0
for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications))
for ns in range(nspikifications):
condition_labels.append(condition_number)
condition_number += 1
x0s = np.concatenate(x0s, axis=1)
P_nxn = rng.randn(N, N) / np.sqrt(N)
# generate trials for both RNNs
rates_a, x0s_a, _ = generate_data(rnn_a, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
input_magnitude=0.0, input_times=None)
spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])
rates_b, x0s_b, _ = generate_data(rnn_b, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
input_magnitude=0.0, input_times=None)
spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])
# not the best way to do this but E is small enough
rates = []
spikes = []
for trial in xrange(E):
if rnn_to_use[trial] == 0:
rates.append(rates_a[trial])
spikes.append(spikes_a[trial])
else:
rates.append(rates_b[trial])
spikes.append(spikes_b[trial])
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
ext_input, train_inds, valid_inds)
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)
# add train_ext_input and valid_ext input
data = {'train_truth': rates_train,
'valid_truth': rates_valid,
'train_data' : spikes_train,
'valid_data' : spikes_valid,
'train_ext_input' : np.array(ext_input_train),
'valid_ext_input': np.array(ext_input_valid),
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : FLAGS.dt,
'P_sxn' : P_nxn,
'condition_labels_train' : condition_labels_train,
'condition_labels_valid' : condition_labels_valid,
'conversion_factor': 1.0 / rnn_a['conversion_factor']}
# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data
# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print ('Saved to ', os.path.join(FLAGS.save_dir,
FLAGS.datafile_name + '_' + dataset_name))
#!/bin/bash
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/
echo "Generating chaotic rnn data with no input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0
echo "Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0
echo "Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0
echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0
echo "Generating Integration-to-bound RNN data"
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
from utils import write_datasets
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
"""Create a (vanilla) RNN with a bunch of hyper parameters for generating
chaotic data.
Args:
rng: numpy random number generator
N: number of hidden units
g: scaling of recurrent weight matrix in g W, with W ~ N(0,1/N)
tau: time scale of individual unit dynamics
dt: time step for equation updates
max_firing_rate: how to rescale the [-1,1] firing rates
Returns:
the dictionary of these parameters, plus some others.
"""
rnn = {}
rnn['N'] = N
rnn['W'] = rng.randn(N,N)/np.sqrt(N)
rnn['Bin'] = rng.randn(N)/np.sqrt(1.0)
rnn['Bin2'] = rng.randn(N)/np.sqrt(1.0)
rnn['b'] = np.zeros(N)
rnn['g'] = g
rnn['tau'] = tau
rnn['dt'] = dt
rnn['max_firing_rate'] = max_firing_rate
mfr = rnn['max_firing_rate'] # spikes / sec
nbins_per_sec = 1.0/rnn['dt'] # bins / sec
# Used for plotting in LFADS
rnn['conversion_factor'] = mfr / nbins_per_sec # spikes / bin
return rnn
def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
input_times=None):
""" Generates data from an randomly initialized RNN.
Args:
rnn: the rnn
T: Time in seconds to run (divided by rnn['dt'] to get steps, rounded down.
E: total number of examples
S: number of samples (subsampling N)
Returns:
A list of length E of NxT tensors of the network being run.
"""
N = rnn['N']
def run_rnn(rnn, x0, ntime_steps, input_time=None):
rs = np.zeros([N,ntime_steps])
x_tm1 = x0
r_tm1 = np.tanh(x0)
tau = rnn['tau']
dt = rnn['dt']
alpha = (1.0-dt/tau)
W = dt/tau*rnn['W']*rnn['g']
Bin = dt/tau*rnn['Bin']
Bin2 = dt/tau*rnn['Bin2']
b = dt/tau*rnn['b']
us = np.zeros([1, ntime_steps])
for t in range(ntime_steps):
x_t = alpha*x_tm1 + np.dot(W,r_tm1) + b
if input_time is not None and t == input_time:
us[0,t] = input_magnitude
x_t += Bin * us[0,t] # DCS is this what was used?
r_t = np.tanh(x_t)
x_tm1 = x_t
r_tm1 = r_t
rs[:,t] = r_t
return rs, us
if P_sxn is None:
P_sxn = np.eye(N)
ntime_steps = int(T / rnn['dt'])
data_e = []
inputs_e = []
for e in range(E):
input_time = input_times[e] if input_times is not None else None
r_nxt, u_uxt = run_rnn(rnn, x0s[:,e], ntime_steps, input_time)
r_sxt = np.dot(P_sxn, r_nxt)
inputs_e.append(u_uxt)
data_e.append(r_sxt)
S = P_sxn.shape[0]
data_e = normalize_rates(data_e, E, S)
return data_e, x0s, inputs_e
def normalize_rates(data_e, E, S):
# Normalization, made more complex because of the P matrices.
# Normalize by min and max in each channel. This normalization will
# cause offset differences between identical rnn runs, but different
# t hits.
for e in range(E):
r_sxt = data_e[e]
for i in range(S):
rmin = np.min(r_sxt[i,:])
rmax = np.max(r_sxt[i,:])
assert rmax - rmin != 0, 'Something wrong'
r_sxt[i,:] = (r_sxt[i,:] - rmin)/(rmax-rmin)
data_e[e] = r_sxt
return data_e
def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
""" Apply spikes to a continuous dataset whose values are between 0.0 and 1.0
Args:
data_e: E-length list of NxT trials
rng: numpy random number generator
dt: how often the data are sampled
max_firing_rate: the firing rate that is associated with a value of 1.0
Returns:
spikes_e: a list of length E of the data represented as spikes,
sampled from the underlying Poisson process.
"""
E = len(data_e)
spikes_e = []
for e in range(E):
data = data_e[e]
N,T = data.shape
data_s = np.zeros([N,T]).astype(np.int)
for n in range(N):
f = data[n,:]
s = rng.poisson(f*max_firing_rate*dt, size=T)
data_s[n,:] = s
spikes_e.append(data_s)
return spikes_e
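# Illustrative sketch (not part of the original file): how these helpers are
# typically chained together.  The parameter values below are hypothetical and
# chosen only for demonstration; the run_*.sh script above shows the values
# used for the paper's synthetic datasets.
#
#   rng = np.random.RandomState(0)
#   rnn = generate_rnn(rng, N=50, g=1.5, tau=0.025, dt=0.01, max_firing_rate=30.0)
#   x0s = 1.0 * rng.randn(rnn['N'], 4)             # 4 random initial conditions
#   rates_e, _, _ = generate_data(rnn, T=1.0, E=4, x0s=x0s)
#   spikes_e = spikify_data(rates_e, rng, dt=rnn['dt'],
#                           max_firing_rate=rnn['max_firing_rate'])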
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
"""Split the numbers between 0 and num_trials-1 into two portions for
training and validation, based on the train fraction.
Args:
num_trials: the number of trials
train_fraction: fraction of trials used for training (e.g. 0.80)
nspikifications: the number of spiking trials per initial condition
Returns:
a 2-tuple of two lists: the training indices and validation indices
"""
train_inds = []
valid_inds = []
for i in range(num_trials):
# This line divides up the trials so that within one initial condition,
# the randomness of spikifying the condition is shared among both
# training and validation data splits.
if (i % nspikifications)+1 > train_fraction * nspikifications:
valid_inds.append(i)
else:
train_inds.append(i)
return train_inds, valid_inds
def split_list_by_inds(data, inds1, inds2):
"""Take the data, a list, and split it up based on the indices in inds1 and
inds2.
Args:
data: the list of data to split
inds1: the first list of indices
inds2: the second list of indices
Returns: a 2-tuple of two lists.
"""
if data is None or len(data) == 0:
return [], []
else:
dout1 = [data[i] for i in inds1]
dout2 = [data[i] for i in inds2]
return dout1, dout2
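# Illustrative sketch (not part of the original file): splitting spikified
# trials into training / validation sets.  The numbers here are hypothetical
# (100 trials = 10 initial conditions x 10 spikifications, 80/20 split).
#
#   train_inds, valid_inds = get_train_n_valid_inds(num_trials=100,
#                                                   train_fraction=0.8,
#                                                   nspikifications=10)
#   train_spikes, valid_spikes = split_list_by_inds(spikes_e, train_inds,
#                                                   valid_inds)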
def nparray_and_transpose(data_a_b_c):
"""Convert the list of items in data to a numpy array, and transpose it
Args:
data_a_b_c: a nested list of length a, whose elements are lists of length b,
whose elements are lists of length c.
Returns:
a numpy 3-tensor with dimensions a x c x b
"""
data_axbxc = np.array([datum_b_c for datum_b_c in data_a_b_c])
data_axcxb = np.transpose(data_axbxc, axes=[0,2,1])
return data_axcxb
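# Illustrative sketch (not part of the original file), continuing the example
# above: nparray_and_transpose packs an E-length list of NxT trials into the
# E x T x N array used for the 'train_data' / 'valid_data' fields.
#
#   train_data_etn = nparray_and_transpose(train_spikes)   # shape E x T x N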
def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
"""Create a matrix that aligns the datasets a bit, under
the assumption that each dataset is observing the same underlying dynamical
system.
Args:
datasets: The dictionary of dataset structures.
npcs: The number of PCs for each dataset, which play a role analogous to the LFADS factors.
nsamples (optional): Number of samples to take for each dataset.
ntime (optional): Number of time steps to take in each sample.
Returns:
The dataset structures, with the field alignment_matrix_cxf added.
This matrix has dimension (number of channels) x npcs.
"""
nchannels_all = 0
channel_idxs = {}
conditions_all = {}
nconditions_all = 0
for name, dataset in datasets.items():
cidxs = np.where(dataset['P_sxn'])[1] # non-zero entries in columns
channel_idxs[name] = [cidxs[0], cidxs[-1]+1]
nchannels_all += cidxs[-1]+1 - cidxs[0]
conditions_all[name] = np.unique(dataset['condition_labels_train'])
all_conditions_list = \
np.unique(np.ndarray.flatten(np.array(conditions_all.values())))
nconditions_all = all_conditions_list.shape[0]
if ntime is None:
ntime = dataset['train_data'].shape[1]
if nsamples is None:
nsamples = dataset['train_data'].shape[0]
# In the data workup in the paper, Chethan did intra condition
# averaging, so let's do that here.
avg_data_all = {}
for name, conditions in conditions_all.items():
dataset = datasets[name]
avg_data_all[name] = {}
for cname in conditions:
td_idxs = np.argwhere(np.array(dataset['condition_labels_train'])==cname)
data = np.squeeze(dataset['train_data'][td_idxs,:,:], axis=1)
avg_data = np.mean(data, axis=0)
avg_data_all[name][cname] = avg_data
# Visualize this in the morning.
all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
for name, dataset in datasets.items():
cidx_s = channel_idxs[name][0]
cidx_f = channel_idxs[name][1]
for cname in conditions_all[name]:
cidxs = np.argwhere(all_conditions_list == cname)
if cidxs.shape[0] > 0:
cidx = cidxs[0][0]
all_tidxs = np.arange(0, ntime+1) + cidx*ntime
all_data_nxtc[cidx_s:cidx_f, all_tidxs[0]:all_tidxs[-1]] = \
avg_data_all[name][cname].T
# A bit of filtering. We don't care about spectral properties, or
# filtering artifacts, simply correlate time steps a bit.
filt_len = 6
bc_filt = np.ones([filt_len])/float(filt_len)
for c in range(nchannels_all):
all_data_nxtc[c,:] = scipy.signal.filtfilt(bc_filt, [1.0], all_data_nxtc[c,:])
# Compute the PCs.
all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
sidxs = np.flipud(np.argsort(evals_n)) # sort such that 0th is highest
evals_n = evals_n[sidxs]
evecs_nxn = evecs_nxn[:,sidxs]
# Project all the channels data onto the low-D PCA basis, where
# low-d is the npcs parameter.
all_data_pca_pxtc = np.dot(evecs_nxn[:, 0:npcs].T, all_data_zm_nxtc)
# Now for each dataset, we regress the channel data onto the top
# pcs, and this will be our alignment matrix for that dataset.
# |B - A*W|^2
for name, dataset in datasets.items():
cidx_s = channel_idxs[name][0]
cidx_f = channel_idxs[name][1]
all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f,:] # ch for channel
W_chxp, _, _, _ = \
np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T)
dataset['alignment_matrix_cxf'] = W_chxp
do_debug_plot = False
if do_debug_plot:
pc_vecs = evecs_nxn[:,0:npcs]
ntoplot = 400
plt.figure()
plt.plot(np.log10(evals_n), '-x')
plt.figure()
plt.subplot(311)
plt.imshow(all_data_pca_pxtc)
plt.colorbar()
plt.subplot(312)
plt.imshow(np.dot(W_chxp.T, all_data_zm_chxtc))
plt.colorbar()
plt.subplot(313)
plt.imshow(np.dot(all_data_zm_chxtc.T, W_chxp).T - all_data_pca_pxtc)
plt.colorbar()
import pdb
pdb.set_trace()
return datasets
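# Illustrative sketch (not part of the original file): in a multi-session
# workup each per-session dataset dict is expected to carry 'P_sxn',
# 'train_data' and 'condition_labels_train'; add_alignment_projections then
# attaches an 'alignment_matrix_cxf' (channels x npcs) to each of them.
# The npcs value below is hypothetical.
#
#   datasets = add_alignment_projections(datasets, npcs=10)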
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import os
import h5py
import json
import numpy as np
import tensorflow as tf
def log_sum_exp(x_k):
"""Computes log \sum exp in a numerically stable way.
log ( sum_i exp(x_i) )
log ( sum_i exp(x_i - m + m) ), with m = max(x_i)
log ( sum_i exp(x_i - m)*exp(m) )
log ( sum_i exp(x_i - m) ) + m
Args:
x_k: k-dimensional list of arguments to log_sum_exp.
Returns:
log_sum_exp of the arguments.
"""
m = tf.reduce_max(x_k)
x1_k = x_k - m
u_k = tf.exp(x1_k)
z = tf.reduce_sum(u_k)
return tf.log(z) + m
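# Illustrative sketch (not part of the original file): the subtract-max trick
# keeps the exponentials finite.  For large logits a naive
# tf.log(tf.reduce_sum(tf.exp(x))) overflows to inf, while log_sum_exp does not.
#
#   x = tf.constant([1000.0, 1000.0])
#   lse = log_sum_exp(x)        # ~ 1000.0 + log(2), rather than inf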
def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
normalized=False, name=None, collections=None):
"""Linear (affine) transformation, y = x W + b, for a variety of
configurations.
Args:
x: The input tensor to transform.
out_size: The integer size of non-batch output dimension.
do_bias (optional): Add a learnable bias vector to the operation.
alpha (optional): A multiplicative scaling for the weight initialization
of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
identity_if_possible (optional): just return identity,
if x.shape[1] == out_size.
normalized (optional): Option to divide out by the norms of the rows of W.
name (optional): The name prefix to add to variables.
collections (optional): List of additional collections. (Placed in
tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
Returns:
In the equation, y = x W + b, returns the tensorflow op that yields y.
"""
in_size = int(x.get_shape()[1]) # from Dimension(10) -> 10
stddev = alpha/np.sqrt(float(in_size))
mat_init = tf.random_normal_initializer(0.0, stddev)
wname = (name + "/W") if name else "/W"
if identity_if_possible and in_size == out_size:
# Sometimes linear layers are nothing more than size adapters.
return tf.identity(x, name=(wname+'_ident'))
W,b = init_linear(in_size, out_size, do_bias=do_bias, alpha=alpha,
normalized=normalized, name=name, collections=collections)
if do_bias:
return tf.matmul(x, W) + b
else:
return tf.matmul(x, W)
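# Illustrative sketch (not part of the original file): a typical call mapping a
# batch of 128-dimensional inputs to 32 outputs.  The shapes and the "readout"
# name are hypothetical.
#
#   x = tf.placeholder(tf.float32, [None, 128])
#   y = linear(x, 32, name="readout")   # creates variables readout/W and readout/b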
def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
identity_if_possible=False, normalized=False,
name=None, collections=None):
"""Linear (affine) transformation, y = x W + b, for a variety of
configurations.
Args:
in_size: The integer size of the non-batch input dimension. [(x),y]
out_size: The integer size of non-batch output dimension. [x,(y)]
do_bias (optional): Add a learnable bias vector to the operation.
mat_init_value (optional): numpy constant for matrix initialization; if None,
the matrix is randomly initialized using the remaining parameters.
alpha (optional): A multiplicative scaling for the weight initialization
of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
identity_if_possible (optional): just return identity,
if x.shape[1] == out_size.
normalized (optional): Option to divide out by the norms of the rows of W.
name (optional): The name prefix to add to variables.
collections (optional): List of additional collections. (Placed in
tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
Returns:
In the equation, y = x W + b, returns the pair (W, b).
"""
if mat_init_value is not None and mat_init_value.shape != (in_size, out_size):
raise ValueError(
'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
if mat_init_value is None:
stddev = alpha/np.sqrt(float(in_size))
mat_init = tf.random_normal_initializer(0.0, stddev)
wname = (name + "/W") if name else "/W"
if identity_if_possible and in_size == out_size:
return (tf.constant(np.eye(in_size).astype(np.float32)),
tf.zeros(in_size))
# Note the use of tf.Variable vs. tf.get_variable: get_variable does not allow
# the variable to be initialized directly with a value, so tf.Variable is used
# whenever mat_init_value is provided.
if normalized:
w_collections = [tf.GraphKeys.GLOBAL_VARIABLES, "norm-variables"]
if collections:
w_collections += collections
if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections)
else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections)
w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij
else:
w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
if collections:
w_collections += collections
if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections)
else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections)
if do_bias:
b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
if collections:
b_collections += collections
bname = (name + "/b") if name else "/b"
b = tf.get_variable(bname, [1, out_size],
initializer=tf.zeros_initializer(),
collections=b_collections)
else:
b = None
return (w, b)
def write_data(data_fname, data_dict, use_json=False, compression=None):
"""Write data in HD5F format.
Args:
data_fname: The filename of teh file in which to write the data.
data_dict: The dictionary of data to write. The keys are strings
and the values are numpy arrays.
use_json (optional): human readable format for simple items
compression (optional): The compression to use for h5py (disabled by
default because the library borks on scalars, otherwise try 'gzip').
"""
dir_name = os.path.dirname(data_fname)
if not os.path.exists(dir_name):
os.makedirs(dir_name)
if use_json:
the_file = open(data_fname,'w')
json.dump(data_dict, the_file)
the_file.close()
else:
try:
with h5py.File(data_fname, 'w') as hf:
for k, v in data_dict.items():
clean_k = k.replace('/', '_')
if clean_k != k:
print('Warning: saving variable with name: ', k, ' as ', clean_k)
else:
print('Saving variable with name: ', clean_k)
hf.create_dataset(clean_k, data=v, compression=compression)
except IOError:
print("Cannot open %s for writing.", data_fname)
raise
def read_data(data_fname):
""" Read saved data in HDF5 format.
Args:
data_fname: The filename of the file from which to read the data.
Returns:
A dictionary whose keys will vary depending on dataset (but should
always contain the keys 'train_data' and 'valid_data') and whose
values are numpy arrays.
"""
try:
with h5py.File(data_fname, 'r') as hf:
data_dict = {k: np.array(v) for k, v in hf.items()}
return data_dict
except IOError:
print("Cannot open %s for reading." % data_fname)
raise
def write_datasets(data_path, data_fname_stem, dataset_dict, compression=None):
"""Write datasets in HD5F format.
This function assumes the dataset_dict is a mapping ( string ->
to data_dict ). It calls write_data for each data dictionary,
post-fixing the data filename with the key of the dataset.
Args:
data_path: The path to the save directory.
data_fname_stem: The filename stem of the file in which to write the data.
dataset_dict: The dictionary of datasets. The keys are strings
and the values data dictionaries (str -> numpy arrays) associations.
compression (optional): The compression to use for h5py (disabled by
default because the library borks on scalars, otherwise try 'gzip').
"""
full_name_stem = os.path.join(data_path, data_fname_stem)
for s, data_dict in dataset_dict.items():
write_data(full_name_stem + "_" + s, data_dict, compression=compression)
def read_datasets(data_path, data_fname_stem):
"""Read dataset sin HD5F format.
This function assumes the dataset_dict is a mapping ( string ->
to data_dict ). It calls write_data for each data dictionary,
post-fixing the data filename with the key of the dataset.
Args:
data_path: The path to the save directory.
data_fname_stem: The filename stem of the file in which to write the data.
"""
dataset_dict = {}
fnames = os.listdir(data_path)
print ('loading data from ' + data_path + ' with stem ' + data_fname_stem)
for fname in fnames:
if fname.startswith(data_fname_stem):
data_dict = read_data(os.path.join(data_path,fname))
idx = len(data_fname_stem) + 1
key = fname[idx:]
data_dict['data_dim'] = data_dict['train_data'].shape[2]
data_dict['num_steps'] = data_dict['train_data'].shape[1]
dataset_dict[key] = data_dict
if len(dataset_dict) == 0:
raise ValueError("Failed to load any datasets, are you sure that the "
"'--data_dir' and '--data_filename_stem' flag values "
"are correct?")
print (str(len(dataset_dict)) + ' datasets loaded')
return dataset_dict
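# Illustrative sketch (not part of the original file): writing a dictionary of
# datasets and reading it back.  The path, stem and array shapes are
# hypothetical.
#
#   data = {'train_data': np.zeros([80, 100, 50], dtype=np.int32),
#           'valid_data': np.zeros([20, 100, 50], dtype=np.int32)}
#   write_datasets('/tmp/lfads_example', 'my_data', {'dataset_A': data})
#   datasets = read_datasets('/tmp/lfads_example', 'my_data')
#   # -> {'dataset_A': {..., 'data_dim': 50, 'num_steps': 100}}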
# NUMPY utility functions
def list_t_bxn_to_list_b_txn(values_t_bxn):
"""Convert a length T list of BxN numpy tensors of length B list of TxN numpy
tensors.
Args:
values_t_bxn: The length T list of BxN numpy tensors.
Returns:
The length B list of TxN numpy tensors.
"""
T = len(values_t_bxn)
B, N = values_t_bxn[0].shape
values_b_txn = []
for b in range(B):
values_pb_txn = np.zeros([T,N])
for t in range(T):
values_pb_txn[t,:] = values_t_bxn[t][b,:]
values_b_txn.append(values_pb_txn)
return values_b_txn
def list_t_bxn_to_tensor_bxtxn(values_t_bxn):
"""Convert a length T list of BxN numpy tensors to single numpy tensor with
shape BxTxN.
Args:
values_t_bxn: The length T list of BxN numpy tensors.
Returns:
values_bxtxn: The BxTxN numpy tensor.
"""
T = len(values_t_bxn)
B, N = values_t_bxn[0].shape
values_bxtxn = np.zeros([B,T,N])
for t in range(T):
values_bxtxn[:,t,:] = values_t_bxn[t]
return values_bxtxn
def tensor_bxtxn_to_list_t_bxn(tensor_bxtxn):
"""Convert a numpy tensor with shape BxTxN to a length T list of numpy tensors
with shape BxN.
Args:
tensor_bxtxn: The BxTxN numpy tensor.
Returns:
A length T list of numpy tensors with shape BxN.
"""
values_t_bxn = []
B, T, N = tensor_bxtxn.shape
for t in range(T):
values_t_bxn.append(np.squeeze(tensor_bxtxn[:,t,:]))
return values_t_bxn
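# Illustrative sketch (not part of the original file): the three conversion
# helpers move between a length-T list of BxN arrays and a single BxTxN array.
# Shapes below are hypothetical.
#
#   vals_t_bxn = [np.random.randn(8, 5) for _ in range(20)]   # T=20, B=8, N=5
#   vals_bxtxn = list_t_bxn_to_tensor_bxtxn(vals_t_bxn)       # shape (8, 20, 5)
#   vals_b_txn = list_t_bxn_to_list_b_txn(vals_t_bxn)         # 8 arrays of (20, 5)
#   back_t_bxn = tensor_bxtxn_to_list_t_bxn(vals_bxtxn)       # 20 arrays of (8, 5)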
def flatten(list_of_lists):
"""Takes a list of lists and returns a list of the elements.
Args:
list_of_lists: List of lists.
Returns:
flat_list: Flattened list.
flat_list_idxs: Flattened list indices.
"""
flat_list = []
flat_list_idxs = []
start_idx = 0
for item in list_of_lists:
if isinstance(item, list):
flat_list += item
l = len(item)
idxs = range(start_idx, start_idx+l)
start_idx = start_idx+l
else: # a value
flat_list.append(item)
idxs = [start_idx]
start_idx += 1
flat_list_idxs.append(idxs)
return flat_list, flat_list_idxs
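# Illustrative sketch (not part of the original file): flatten records where
# each original item landed so the flat list can be unpacked later.
#
#   flat, idxs = flatten([[1, 2], 3, [4]])
#   # flat -> [1, 2, 3, 4]
#   # idxs -> [[0, 1], [2], [3]]   (index ranges, one per original item)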
...@@ -111,9 +111,9 @@ def _build_regularizer(regularizer): ...@@ -111,9 +111,9 @@ def _build_regularizer(regularizer):
""" """
regularizer_oneof = regularizer.WhichOneof('regularizer_oneof') regularizer_oneof = regularizer.WhichOneof('regularizer_oneof')
if regularizer_oneof == 'l1_regularizer': if regularizer_oneof == 'l1_regularizer':
return slim.l1_regularizer(scale=regularizer.l1_regularizer.weight) return slim.l1_regularizer(scale=float(regularizer.l1_regularizer.weight))
if regularizer_oneof == 'l2_regularizer': if regularizer_oneof == 'l2_regularizer':
return slim.l2_regularizer(scale=regularizer.l2_regularizer.weight) return slim.l2_regularizer(scale=float(regularizer.l2_regularizer.weight))
raise ValueError('Unknown regularizer function: {}'.format(regularizer_oneof)) raise ValueError('Unknown regularizer function: {}'.format(regularizer_oneof))
......
...@@ -60,6 +60,6 @@ def build(input_reader_config): ...@@ -60,6 +60,6 @@ def build(input_reader_config):
capacity=input_reader_config.queue_capacity, capacity=input_reader_config.queue_capacity,
min_after_dequeue=input_reader_config.min_after_dequeue) min_after_dequeue=input_reader_config.min_after_dequeue)
return tf_example_decoder.TfExampleDecoder().Decode(string_tensor) return tf_example_decoder.TfExampleDecoder().decode(string_tensor)
raise ValueError('Unsupported input_reader_config.') raise ValueError('Unsupported input_reader_config.')
...@@ -80,11 +80,12 @@ class BatchQueue(object): ...@@ -80,11 +80,12 @@ class BatchQueue(object):
""" """
# Remember static shapes to set shapes of batched tensors. # Remember static shapes to set shapes of batched tensors.
static_shapes = collections.OrderedDict( static_shapes = collections.OrderedDict(
{key: tensor.get_shape() for key, tensor in tensor_dict.iteritems()}) {key: tensor.get_shape() for key, tensor in tensor_dict.items()})
# Remember runtime shapes to unpad tensors after batching. # Remember runtime shapes to unpad tensors after batching.
runtime_shapes = collections.OrderedDict( runtime_shapes = collections.OrderedDict(
{(key + rt_shape_str): tf.shape(tensor) {(key + rt_shape_str): tf.shape(tensor)
for key, tensor in tensor_dict.iteritems()}) for key, tensor in tensor_dict.iteritems()})
all_tensors = tensor_dict all_tensors = tensor_dict
all_tensors.update(runtime_shapes) all_tensors.update(runtime_shapes)
batched_tensors = tf.train.batch( batched_tensors = tf.train.batch(
...@@ -111,7 +112,7 @@ class BatchQueue(object): ...@@ -111,7 +112,7 @@ class BatchQueue(object):
# Separate input tensors from tensors containing their runtime shapes. # Separate input tensors from tensors containing their runtime shapes.
tensors = {} tensors = {}
shapes = {} shapes = {}
for key, batched_tensor in batched_tensors.iteritems(): for key, batched_tensor in batched_tensors.items():
unbatched_tensor_list = tf.unstack(batched_tensor) unbatched_tensor_list = tf.unstack(batched_tensor)
for i, unbatched_tensor in enumerate(unbatched_tensor_list): for i, unbatched_tensor in enumerate(unbatched_tensor_list):
if rt_shape_str in key: if rt_shape_str in key:
......