Commit 30aeec75 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Merge pull request #2 from tensorflow/master

Sync to tensorflow-master
parents 68a18b70 78007443
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the various loss functions in use by the PIXELDA model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
def add_domain_classifier_losses(end_points, hparams):
"""Adds losses related to the domain-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
hparams: The hyperparameters struct.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
if hparams.domain_loss_weight == 0:
tf.logging.info(
'Domain classifier loss weight is 0, so not creating losses.')
return 0
# The domain prediction loss is minimized with respect to the domain
# classifier features only. Its aim is to predict the domain of the images.
# Note: 1 = 'real image' label, 0 = 'fake image' label
transferred_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']),
logits=end_points['transferred_domain_logits'])
tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss)
target_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.ones_like(end_points['target_domain_logits']),
logits=end_points['target_domain_logits'])
tf.summary.scalar('Domain_loss_target', target_domain_loss)
# Compute the total domain loss:
total_domain_loss = transferred_domain_loss + target_domain_loss
total_domain_loss *= hparams.domain_loss_weight
tf.summary.scalar('Domain_loss_total', total_domain_loss)
return total_domain_loss
def log_quaternion_loss_batch(predictions, labels, params):
"""A helper function to compute the error between quaternions.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size [batch_size], denoting the error between the quaternions.
"""
use_logging = params['use_logging']
assertions = []
if use_logging:
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
1e-4)),
['The l2 norm of each prediction quaternion vector should be 1.']))
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
['The l2 norm of each label quaternion vector should be 1.']))
with tf.control_dependencies(assertions):
product = tf.multiply(predictions, labels)
internal_dot_products = tf.reduce_sum(product, [1])
if use_logging:
internal_dot_products = tf.Print(internal_dot_products, [
internal_dot_products,
tf.shape(internal_dot_products)
], 'internal_dot_products:')
logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
return logcost
def log_quaternion_loss(predictions, labels, params):
"""A helper function to compute the mean error between batches of quaternions.
The caller is expected to add the loss to the graph.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size 1, denoting the mean error between batches of quaternions.
"""
use_logging = params['use_logging']
logcost = log_quaternion_loss_batch(predictions, labels, params)
logcost = tf.reduce_sum(logcost, [0])
batch_size = params['batch_size']
logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
if use_logging:
logcost = tf.Print(
logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
return logcost
def _quaternion_loss(labels, predictions, weight, batch_size, domain,
add_summaries):
"""Creates a Quaternion Loss.
Args:
labels: The true quaternions.
predictions: The predicted quaternions.
weight: A scalar weight.
batch_size: The size of the batches.
domain: The name of the domain from which the labels were taken.
add_summaries: Whether or not to add summaries for the losses.
Returns:
A `Tensor` representing the loss.
"""
assert domain in ['Source', 'Transferred']
params = {'use_logging': False, 'batch_size': batch_size}
loss = weight * log_quaternion_loss(labels, predictions, params)
if add_summaries:
assert_op = tf.Assert(tf.is_finite(loss), [loss])
with tf.control_dependencies([assert_op]):
tf.summary.histogram(
'Log_Quaternion_Loss_%s' % domain, loss, collections='losses')
tf.summary.scalar(
'Task_Quaternion_Loss_%s' % domain, loss, collections='losses')
return loss
def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
add_summaries=False):
"""Adds losses related to the task-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
add_summaries: Whether or not to add the summaries.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
# TODO(ddohan): Make sure the l2 regularization is added to the loss
one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
total_loss = 0
if 'source_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['source_task_logits'],
weights=hparams.source_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Source', loss)
total_loss += loss
if 'transferred_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['transferred_task_logits'],
weights=hparams.transferred_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
total_loss += loss
#########################
# Pose specific losses. #
#########################
if 'quaternion' in source_labels:
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['source_quaternion'],
hparams.source_pose_weight,
hparams.batch_size,
'Source',
add_summaries)
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['transferred_quaternion'],
hparams.transferred_pose_weight,
hparams.batch_size,
'Transferred',
add_summaries)
if add_summaries:
tf.summary.scalar('Task_Loss_Total', total_loss)
return total_loss
def _transferred_similarity_loss(reconstructions,
source_images,
weight=1.0,
method='mse',
max_diff=0.4,
name='similarity'):
"""Computes a loss encouraging similarity between source and transferred.
Args:
reconstructions: A `Tensor` of shape [batch_size, height, width, channels]
source_images: A `Tensor` of shape [batch_size, height, width, channels].
weight: Multiple similarity loss by this weight before returning
method: One of:
mpse = Mean Pairwise Squared Error
mse = Mean Squared Error
hinged_mse = Computes the mean squared error using squared differences
greater than hparams.transferred_similarity_max_diff
hinged_mae = Computes the mean absolute error using absolute
differences greater than hparams.transferred_similarity_max_diff.
max_diff: Maximum unpenalized difference for hinged losses
name: Identifying name to use for creating summaries
Returns:
A `Tensor` representing the transferred similarity loss.
Raises:
ValueError: if `method` is not recognized.
"""
if weight == 0:
return 0
source_channels = source_images.shape.as_list()[-1]
reconstruction_channels = reconstructions.shape.as_list()[-1]
# Convert grayscale source to RGB if target is RGB
if source_channels == 1 and reconstruction_channels != 1:
source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
if reconstruction_channels == 1 and source_channels != 1:
reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])
if method == 'mpse':
reconstruction_similarity_loss_fn = (
tf.contrib.losses.mean_pairwise_squared_error)
elif method == 'masked_mpse':
def masked_mpse(predictions, labels, weight):
"""Masked mpse assuming we have a depth to create a mask from."""
assert labels.shape.as_list()[-1] == 4
mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
mask = tf.tile(mask, [1, 1, 1, 4])
predictions *= mask
labels *= mask
tf.image_summary('masked_pred', predictions)
tf.image_summary('masked_label', labels)
return tf.contrib.losses.mean_pairwise_squared_error(
predictions, labels, weight)
reconstruction_similarity_loss_fn = masked_mpse
elif method == 'mse':
reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
elif method == 'hinged_mse':
def hinged_mse(predictions, labels, weight):
diffs = tf.square(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mse
elif method == 'hinged_mae':
def hinged_mae(predictions, labels, weight):
diffs = tf.abs(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mae
else:
raise ValueError('Unknown reconstruction loss %s' % method)
reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
reconstructions, source_images, weight)
name = '%s_Similarity_(%s)' % (name, method)
tf.summary.scalar(name, reconstruction_similarity_loss)
return reconstruction_similarity_loss
def g_step_loss(source_images, source_labels, end_points, hparams, num_classes):
"""Configures the loss function which runs during the g-step.
Args:
source_images: A `Tensor` of shape [batch_size, height, width, channels].
source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys
are 'class' and 'quaternion'.
end_points: A map of the network end points.
hparams: The hyperparameters struct.
num_classes: Number of classes for classifier loss
Returns:
A `Tensor` representing a loss function.
Raises:
ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
hparams.transferred_similarity_loss is invalid.
"""
generator_loss = 0
################################################################
# Adds a loss which encourages the discriminator probabilities #
# to be high (near one).
################################################################
# As per the GAN paper, maximize the log probs, instead of minimizing
# log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is
# the same thing.
style_transfer_loss = tf.losses.sigmoid_cross_entropy(
logits=end_points['transferred_domain_logits'],
multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']),
weights=hparams.style_transfer_loss_weight)
tf.summary.scalar('Style_transfer_loss', style_transfer_loss)
generator_loss += style_transfer_loss
# Optimizes the style transfer network to produce transferred images similar
# to the source images.
generator_loss += _transferred_similarity_loss(
end_points['transferred_images'],
source_images,
weight=hparams.transferred_similarity_loss_weight,
method=hparams.transferred_similarity_loss,
name='transferred_similarity')
# Optimizes the style transfer network to maximize classification accuracy.
if source_labels is not None and hparams.task_tower_in_g_step:
generator_loss += _add_task_specific_losses(
end_points, source_labels, num_classes,
hparams) * hparams.task_loss_in_g_weight
return generator_loss
def d_step_loss(end_points, source_labels, num_classes, hparams):
"""Configures the losses during the D-Step.
Note that during the D-step, the model optimizes both the domain (binary)
classifier and the task classifier.
Args:
end_points: A map of the network end points.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
Returns:
A `Tensor` representing the value of the D-step loss.
"""
domain_classifier_loss = add_domain_classifier_losses(end_points, hparams)
task_classifier_loss = 0
if source_labels is not None:
task_classifier_loss = _add_task_specific_losses(
end_points, source_labels, num_classes, hparams, add_summaries=True)
return domain_classifier_loss + task_classifier_loss
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the Domain Adaptation via Style Transfer (PixelDA) model components.
A number of details in the implementation make reference to one of the following
works:
- "Unsupervised Representation Learning with Deep Convolutional
Generative Adversarial Networks""
https://arxiv.org/abs/1511.06434
This paper makes several architecture recommendations:
1. Use strided convs in discriminator, fractional-strided convs in generator
2. batchnorm everywhere
3. remove fully connected layers for deep models
4. ReLu for all layers in generator, except tanh on output
5. LeakyReLu for everything in discriminator
"""
import functools
import math
# Dependency imports
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
def create_model(hparams,
target_images,
source_images=None,
source_labels=None,
is_training=False,
noise=None,
num_classes=None):
"""Create a GAN model.
Arguments:
hparams: HParam object specifying model params
target_images: A `Tensor` of size [batch_size, height, width, channels]. It
is assumed that the images are [-1, 1] normalized.
source_images: A `Tensor` of size [batch_size, height, width, channels]. It
is assumed that the images are [-1, 1] normalized.
source_labels: A `Tensor` of size [batch_size] of categorical labels between
[0, num_classes]
is_training: whether model is currently training
noise: If None, model generates its own noise. Otherwise use provided.
num_classes: Number of classes for classification
Returns:
end_points dict with model outputs
Raises:
ValueError: unknown hparams.arch setting
"""
if num_classes is None and hparams.arch in ['resnet', 'simple']:
raise ValueError('Num classes must be provided to create task classifier')
if target_images.dtype != tf.float32:
raise ValueError('target_images must be tf.float32 and [-1, 1] normalized.')
if source_images is not None and source_images.dtype != tf.float32:
raise ValueError('source_images must be tf.float32 and [-1, 1] normalized.')
###########################
# Create latent variables #
###########################
latent_vars = dict()
if hparams.noise_channel:
noise_shape = [hparams.batch_size, hparams.noise_dims]
if noise is not None:
assert noise.shape.as_list() == noise_shape
tf.logging.info('Using provided noise')
else:
tf.logging.info('Using random noise')
noise = tf.random_uniform(
shape=noise_shape,
minval=-1,
maxval=1,
dtype=tf.float32,
name='random_noise')
latent_vars['noise'] = noise
####################
# Create generator #
####################
with slim.arg_scope(
[slim.conv2d, slim.conv2d_transpose, slim.fully_connected],
normalizer_params=batch_norm_params(is_training,
hparams.batch_norm_decay),
weights_initializer=tf.random_normal_initializer(
stddev=hparams.normal_init_std),
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if hparams.arch == 'dcgan':
end_points = dcgan(
target_images, latent_vars, hparams, scope='generator')
elif hparams.arch == 'resnet':
end_points = resnet_generator(
source_images,
target_images.shape.as_list()[1:4],
hparams=hparams,
latent_vars=latent_vars)
elif hparams.arch == 'residual_interpretation':
end_points = residual_interpretation_generator(
source_images, is_training=is_training, hparams=hparams)
elif hparams.arch == 'simple':
end_points = simple_generator(
source_images,
target_images,
is_training=is_training,
hparams=hparams,
latent_vars=latent_vars)
elif hparams.arch == 'identity':
# Pass through unmodified, besides changing # channels
# Used to calculate baseline numbers
# Also set `generator_steps=0` for baseline
if hparams.generator_steps:
raise ValueError('Must set generator_steps=0 for identity arch. Is %s'
% hparams.generator_steps)
transferred_images = source_images
source_channels = source_images.shape.as_list()[-1]
target_channels = target_images.shape.as_list()[-1]
if source_channels == 1 and target_channels == 3:
transferred_images = tf.tile(source_images, [1, 1, 1, 3])
if source_channels == 3 and target_channels == 1:
transferred_images = tf.image.rgb_to_grayscale(source_images)
end_points = {'transferred_images': transferred_images}
else:
raise ValueError('Unknown architecture: %s' % hparams.arch)
#####################
# Domain Classifier #
#####################
if hparams.arch in [
'dcgan', 'resnet', 'residual_interpretation', 'simple', 'identity',
]:
# Add a discriminator for these architectures
end_points['transferred_domain_logits'] = predict_domain(
end_points['transferred_images'],
hparams,
is_training=is_training,
reuse=False)
end_points['target_domain_logits'] = predict_domain(
target_images,
hparams,
is_training=is_training,
reuse=True)
###################
# Task Classifier #
###################
if hparams.task_tower != 'none' and hparams.arch in [
'resnet', 'residual_interpretation', 'simple', 'identity',
]:
with tf.variable_scope('discriminator'):
with tf.variable_scope('task_tower'):
end_points['source_task_logits'], end_points[
'source_quaternion'] = pixelda_task_towers.add_task_specific_model(
source_images,
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=False,
private_scope='source_task_classifier',
reuse_shared=False)
end_points['transferred_task_logits'], end_points[
'transferred_quaternion'] = (
pixelda_task_towers.add_task_specific_model(
end_points['transferred_images'],
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=False,
private_scope='transferred_task_classifier',
reuse_shared=True))
end_points['target_task_logits'], end_points[
'target_quaternion'] = pixelda_task_towers.add_task_specific_model(
target_images,
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=True,
private_scope='transferred_task_classifier',
reuse_shared=True)
# Remove any endpoints with None values
return dict((k, v) for k, v in end_points.iteritems() if v is not None)
def batch_norm_params(is_training, batch_norm_decay):
return {
'is_training': is_training,
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
}
def lrelu(x, leakiness=0.2):
"""Relu, with optional leaky support."""
return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
def upsample(net, num_filters, scale=2, method='resize_conv', scope=None):
"""Performs spatial upsampling of the given features.
Args:
net: A `Tensor` of shape [batch_size, height, width, filters].
num_filters: The number of output filters.
scale: The scale of the upsampling. Must be a positive integer greater or
equal to two.
method: The method by which the features are upsampled. Valid options
include 'resize_conv' and 'conv2d_transpose'.
scope: An optional variable scope.
Returns:
A new set of features of shape
[batch_size, height*scale, width*scale, num_filters].
Raises:
ValueError: if `method` is not valid or
"""
if scale < 2:
raise ValueError('scale must be greater or equal to two.')
with tf.variable_scope(scope, 'upsample', [net]):
if method == 'resize_conv':
net = tf.image.resize_nearest_neighbor(
net, [net.shape.as_list()[1] * scale,
net.shape.as_list()[2] * scale],
align_corners=True,
name='resize')
return slim.conv2d(net, num_filters, stride=1, scope='conv')
elif method == 'conv2d_transpose':
return slim.conv2d_transpose(net, num_filters, scope='deconv')
else:
raise ValueError('Upsample method [%s] was not recognized.' % method)
def project_latent_vars(hparams, proj_shape, latent_vars, combine_method='sum'):
"""Generate noise and project to input volume size.
Args:
hparams: The hyperparameter HParams struct.
proj_shape: Shape to project noise (not including batch size).
latent_vars: dictionary of `'key': Tensor of shape [batch_size, N]`
combine_method: How to combine the projected values.
sum = project to volume then sum
concat = concatenate along last dimension (i.e. channel)
Returns:
If combine_method=sum, a `Tensor` of size `hparams.projection_shape`
If combine_method=concat and there are N latent vars, a `Tensor` of size
`hparams.projection_shape`, with the last channel multiplied by N
Raises:
ValueError: combine_method is not one of sum/concat
"""
values = []
for var in latent_vars:
with tf.variable_scope(var):
# Project & reshape noise to a HxWxC input
projected = slim.fully_connected(
latent_vars[var],
np.prod(proj_shape),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm)
values.append(tf.reshape(projected, [hparams.batch_size] + proj_shape))
if combine_method == 'sum':
result = values[0]
for value in values[1:]:
result += value
elif combine_method == 'concat':
# Concatenate along last axis
result = tf.concat(values, len(proj_shape))
else:
raise ValueError('Unknown combine_method %s' % combine_method)
tf.logging.info('Latent variables projected to size %s volume', result.shape)
return result
def resnet_block(net, hparams):
"""Create a resnet block."""
net_in = net
net = slim.conv2d(
net,
hparams.resnet_filters,
stride=1,
normalizer_fn=slim.batch_norm,
activation_fn=tf.nn.relu)
net = slim.conv2d(
net,
hparams.resnet_filters,
stride=1,
normalizer_fn=slim.batch_norm,
activation_fn=None)
if hparams.resnet_residuals:
net += net_in
return net
def resnet_stack(images, output_shape, hparams, scope=None):
"""Create a resnet style transfer block.
Args:
images: [batch-size, height, width, channels] image tensor to feed as input
output_shape: output image shape in form [height, width, channels]
hparams: hparams objects
scope: Variable scope
Returns:
Images after processing with resnet blocks.
"""
end_points = {}
if hparams.noise_channel:
# separate the noise for visualization
end_points['noise'] = images[:, :, :, -1]
assert images.shape.as_list()[1:3] == output_shape[0:2]
with tf.variable_scope(scope, 'resnet_style_transfer', [images]):
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=slim.batch_norm,
kernel_size=[hparams.generator_kernel_size] * 2,
stride=1):
net = slim.conv2d(
images,
hparams.resnet_filters,
normalizer_fn=None,
activation_fn=tf.nn.relu)
for block in range(hparams.resnet_blocks):
net = resnet_block(net, hparams)
end_points['resnet_block_{}'.format(block)] = net
net = slim.conv2d(
net,
output_shape[-1],
kernel_size=[1, 1],
normalizer_fn=None,
activation_fn=tf.nn.tanh,
scope='conv_out')
end_points['transferred_images'] = net
return net, end_points
def predict_domain(images,
hparams,
is_training=False,
reuse=False,
scope='discriminator'):
"""Creates a discriminator for a GAN.
Args:
images: A `Tensor` of size [batch_size, height, width, channels]. It is
assumed that the images are centered between -1 and 1.
hparams: hparam object with params for discriminator
is_training: Specifies whether or not we're training or testing.
reuse: Whether to reuse variable scope
scope: An optional variable_scope.
Returns:
[batch size, 1] - logit output of discriminator.
"""
with tf.variable_scope(scope, 'discriminator', [images], reuse=reuse):
lrelu_partial = functools.partial(lrelu, leakiness=hparams.lrelu_leakiness)
with slim.arg_scope(
[slim.conv2d],
kernel_size=[hparams.discriminator_kernel_size] * 2,
activation_fn=lrelu_partial,
stride=2,
normalizer_fn=slim.batch_norm):
def add_noise(hidden, scope_num=None):
if scope_num:
hidden = slim.dropout(
hidden,
hparams.discriminator_dropout_keep_prob,
is_training=is_training,
scope='dropout_%s' % scope_num)
if hparams.discriminator_noise_stddev == 0:
return hidden
return hidden + tf.random_normal(
hidden.shape.as_list(),
mean=0.0,
stddev=hparams.discriminator_noise_stddev)
# As per the recommendation of the DCGAN paper, we don't use batch norm
# on the discriminator input (https://arxiv.org/pdf/1511.06434v2.pdf).
if hparams.discriminator_image_noise:
images = add_noise(images)
net = slim.conv2d(
images,
hparams.num_discriminator_filters,
normalizer_fn=None,
stride=hparams.discriminator_first_stride,
scope='conv1_stride%s' % hparams.discriminator_first_stride)
net = add_noise(net, 1)
block_id = 2
# Repeatedly stack
# discriminator_conv_block_size-1 conv layers with stride 1
# followed by a stride 2 layer
# Add (optional) noise at every point
while net.shape.as_list()[1] > hparams.projection_shape_size:
num_filters = int(hparams.num_discriminator_filters *
(hparams.discriminator_filter_factor**(block_id - 1)))
for conv_id in range(1, hparams.discriminator_conv_block_size):
net = slim.conv2d(
net,
num_filters,
stride=1,
scope='conv_%s_%s' % (block_id, conv_id))
if hparams.discriminator_do_pooling:
net = slim.conv2d(
net, num_filters, scope='conv_%s_prepool' % block_id)
net = slim.avg_pool2d(
net, kernel_size=[2, 2], stride=2, scope='pool_%s' % block_id)
else:
net = slim.conv2d(
net, num_filters, scope='conv_%s_stride2' % block_id)
net = add_noise(net, block_id)
block_id += 1
net = slim.flatten(net)
net = slim.fully_connected(
net,
1,
# Models with BN here generally produce noise
normalizer_fn=None,
activation_fn=None,
scope='fc_logit_out') # Returns logits!
return net
def dcgan_generator(images, output_shape, hparams, scope=None):
"""Transforms the visual style of the input images.
Args:
images: A `Tensor` of shape [batch_size, height, width, channels].
output_shape: A list or tuple of 3 elements: the output height, width and
number of channels.
hparams: hparams object with generator parameters
scope: Scope to place generator inside
Returns:
A `Tensor` of shape [batch_size, height, width, output_channels] which
represents the result of style transfer.
Raises:
ValueError: If `output_shape` is not a list or tuple or if it doesn't have
three elements or if `output_shape` or `images` arent square.
"""
if not isinstance(output_shape, (tuple, list)):
raise ValueError('output_shape must be a tuple or list.')
elif len(output_shape) != 3:
raise ValueError('output_shape must have three elements.')
if output_shape[0] != output_shape[1]:
raise ValueError('output_shape must be square')
if images.shape.as_list()[1] != images.shape.as_list()[2]:
raise ValueError('images height and width must match.')
outdim = output_shape[0]
indim = images.shape.as_list()[1]
num_iterations = int(math.ceil(math.log(float(outdim) / float(indim), 2.0)))
with slim.arg_scope(
[slim.conv2d, slim.conv2d_transpose],
kernel_size=[hparams.generator_kernel_size] * 2,
stride=2):
with tf.variable_scope(scope or 'generator'):
net = images
# Repeatedly halve # filters until = hparams.decode_filters in last layer
for i in range(num_iterations):
num_filters = hparams.num_decoder_filters * 2**(num_iterations - i - 1)
net = slim.conv2d_transpose(net, num_filters, scope='deconv_%s' % i)
# Crop down to desired size (e.g. 32x32 -> 28x28)
dif = net.shape.as_list()[1] - outdim
low = dif / 2
high = net.shape.as_list()[1] - low
net = net[:, low:high, low:high, :]
# No batch norm on generator output
net = slim.conv2d(
net,
output_shape[2],
kernel_size=[1, 1],
stride=1,
normalizer_fn=None,
activation_fn=tf.tanh,
scope='conv_out')
return net
def dcgan(target_images, latent_vars, hparams, scope='dcgan'):
"""Creates the PixelDA model.
Args:
target_images: A `Tensor` of shape [batch_size, height, width, 3]
sampled from the image domain to which we want to transfer.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
hparams: The hyperparameter map.
scope: Surround generator component with this scope
Returns:
A dictionary of model outputs.
"""
proj_shape = [
hparams.projection_shape_size, hparams.projection_shape_size,
hparams.projection_shape_channels
]
source_volume = project_latent_vars(
hparams, proj_shape, latent_vars, combine_method='concat')
###################################################
# Transfer the source images to the target style. #
###################################################
with tf.variable_scope(scope, 'generator', [target_images]):
transferred_images = dcgan_generator(
source_volume,
output_shape=target_images.shape.as_list()[1:4],
hparams=hparams)
assert transferred_images.shape.as_list() == target_images.shape.as_list()
return {'transferred_images': transferred_images}
def resnet_generator(images, output_shape, hparams, latent_vars=None):
"""Creates a ResNet-based generator.
Args:
images: A `Tensor` of shape [batch_size, height, width, num_channels]
sampled from the image domain from which we want to transfer
output_shape: A length-3 array indicating the height, width and channels of
the output.
hparams: The hyperparameter map.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
Returns:
A dictionary of model outputs.
"""
with tf.variable_scope('generator'):
if latent_vars:
noise_channel = project_latent_vars(
hparams,
proj_shape=images.shape.as_list()[1:3] + [1],
latent_vars=latent_vars,
combine_method='concat')
images = tf.concat([images, noise_channel], 3)
transferred_images, end_points = resnet_stack(
images,
output_shape=output_shape,
hparams=hparams,
scope='resnet_stack')
end_points['transferred_images'] = transferred_images
return end_points
def residual_interpretation_block(images, hparams, scope):
"""Learns a residual image which is added to the incoming image.
Args:
images: A `Tensor` of size [batch_size, height, width, 3]
hparams: The hyperparameters struct.
scope: The name of the variable op scope.
Returns:
The updated images.
"""
with tf.variable_scope(scope):
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=None,
kernel_size=[hparams.generator_kernel_size] * 2):
net = images
for _ in range(hparams.res_int_convs):
net = slim.conv2d(
net, hparams.res_int_filters, activation_fn=tf.nn.relu)
net = slim.conv2d(net, 3, activation_fn=tf.nn.tanh)
# Add the residual
images += net
# Clip the output
images = tf.maximum(images, -1.0)
images = tf.minimum(images, 1.0)
return images
def residual_interpretation_generator(images,
is_training,
hparams,
latent_vars=None):
"""Creates a generator producing purely residual transformations.
A residual generator differs from the resnet generator in that each 'block' of
the residual generator produces a residual image. Consequently, the 'progress'
of the model generation process can be directly observed at inference time,
making it easier to diagnose and understand.
Args:
images: A `Tensor` of shape [batch_size, height, width, num_channels]
sampled from the image domain from which we want to transfer. It is
assumed that the images are centered between -1 and 1.
is_training: whether or not the model is training.
hparams: The hyperparameter map.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
Returns:
A dictionary of model outputs.
"""
end_points = {}
with tf.variable_scope('generator'):
if latent_vars:
projected_latent = project_latent_vars(
hparams,
proj_shape=images.shape.as_list()[1:3] + [images.shape.as_list()[-1]],
latent_vars=latent_vars,
combine_method='sum')
images += projected_latent
with tf.variable_scope(None, 'residual_style_transfer', [images]):
for i in range(hparams.res_int_blocks):
images = residual_interpretation_block(images, hparams,
'residual_%d' % i)
end_points['transferred_images_%d' % i] = images
end_points['transferred_images'] = images
return end_points
def simple_generator(source_images, target_images, is_training, hparams,
latent_vars):
"""Simple generator architecture (stack of convs) for trying small models."""
end_points = {}
with tf.variable_scope('generator'):
feed_source_images = source_images
if latent_vars:
projected_latent = project_latent_vars(
hparams,
proj_shape=source_images.shape.as_list()[1:3] + [1],
latent_vars=latent_vars,
combine_method='concat')
feed_source_images = tf.concat([source_images, projected_latent], 3)
end_points = {}
###################################################
# Transfer the source images to the target style. #
###################################################
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=slim.batch_norm,
stride=1,
kernel_size=[hparams.generator_kernel_size] * 2):
net = feed_source_images
# N convolutions
for i in range(1, hparams.simple_num_conv_layers):
normalizer_fn = None
if i != 0:
normalizer_fn = slim.batch_norm
net = slim.conv2d(
net,
hparams.simple_conv_filters,
normalizer_fn=normalizer_fn,
activation_fn=tf.nn.relu)
# Project back to right # image channels
net = slim.conv2d(
net,
target_images.shape.as_list()[-1],
kernel_size=[1, 1],
stride=1,
normalizer_fn=None,
activation_fn=tf.tanh,
scope='conv_out')
transferred_images = net
assert transferred_images.shape.as_list() == target_images.shape.as_list()
end_points['transferred_images'] = transferred_images
return end_points
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains functions for preprocessing the inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
def preprocess_classification(image, labels, is_training=False):
"""Preprocesses the image and labels for classification purposes.
Preprocessing includes shifting the images to be 0-centered between -1 and 1.
This is not only a popular method of preprocessing (inception) but is also
the mechanism used by DSNs.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
is_training: Whether or not we're training the model.
Returns:
The preprocessed image and labels.
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
image -= 0.5
image *= 2
return image, labels
def preprocess_style_transfer(image,
labels,
augment=False,
size=None,
is_training=False):
"""Preprocesses the image and labels for style transfer purposes.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
augment: Whether to apply data augmentation to inputs
size: The height and width to which images should be resized. If left as
`None`, then no resizing is performed
is_training: Whether or not we're training the model
Returns:
The preprocessed image and labels. Scaled to [-1, 1]
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
if augment and is_training:
image = image_augmentation(image)
if size:
image = resize_image(image, size)
image -= 0.5
image *= 2
return image, labels
def image_augmentation(image):
"""Performs data augmentation by randomly permuting the inputs.
Args:
image: A float `Tensor` of size [height, width, channels] with values
in range[0,1].
Returns:
The mutated batch of images
"""
# Apply photometric data augmentation (contrast etc.)
num_channels = image.shape_as_list()[-1]
if num_channels == 4:
# Only augment image part
image, depth = image[:, :, 0:3], image[:, :, 3:4]
elif num_channels == 1:
image = tf.image.grayscale_to_rgb(image)
image = tf.image.random_brightness(image, max_delta=0.1)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.032)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.clip_by_value(image, 0, 1.0)
if num_channels == 4:
image = tf.concat(2, [image, depth])
elif num_channels == 1:
image = tf.image.rgb_to_grayscale(image)
return image
def resize_image(image, size=None):
"""Resize image to target size.
Args:
image: A `Tensor` of size [height, width, 3].
size: (height, width) to resize image to.
Returns:
resized image
"""
if size is None:
raise ValueError('Must specify size')
if image.shape_as_list()[:2] == size:
# Don't resize if not necessary
return image
image = tf.expand_dims(image, 0)
image = tf.image.resize_images(image, size)
image = tf.squeeze(image, 0)
return image
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for domain_adaptation.pixel_domain_adaptation.pixelda_preprocess."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
class PixelDAPreprocessTest(tf.test.TestCase):
def assert_preprocess_classification_is_centered(self, dtype, is_training):
tf.set_random_seed(0)
if dtype == tf.uint8:
image = tf.random_uniform((100, 200, 3), maxval=255, dtype=tf.int64)
image = tf.cast(image, tf.uint8)
else:
image = tf.random_uniform((100, 200, 3), maxval=1.0, dtype=dtype)
labels = {}
image, labels = pixelda_preprocess.preprocess_classification(
image, labels, is_training=is_training)
with self.test_session() as sess:
np_image = sess.run(image)
self.assertTrue(np_image.min() <= -0.95)
self.assertTrue(np_image.min() >= -1.0)
self.assertTrue(np_image.max() >= 0.95)
self.assertTrue(np_image.max() <= 1.0)
def testPreprocessClassificationZeroCentersUint8DuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=True)
def testPreprocessClassificationZeroCentersUint8DuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=False)
def testPreprocessClassificationZeroCentersFloatDuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=True)
def testPreprocessClassificationZeroCentersFloatDuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=False)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task towers for PixelDA model."""
import tensorflow as tf
slim = tf.contrib.slim
def add_task_specific_model(images,
hparams,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope=None):
"""Create a classifier for the given images.
The classifier is composed of a few 'private' layers followed by a few
'shared' layers. This lets us account for different image 'style', while
sharing the last few layers as 'content' layers.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
hparams: model hparams
num_classes: The number of output classes.
is_training: whether model is training
reuse_private: Whether or not to reuse the private weights, which are the
first few layers in the classifier
private_scope: The name of the variable_scope for the private (unshared)
components of the classifier.
reuse_shared: Whether or not to reuse the shared weights, which are the last
few layers in the classifier
shared_scope: The name of the variable_scope for the shared components of
the classifier.
Returns:
The logits, a `Tensor` of shape [batch_size, num_classes].
Raises:
ValueError: If hparams.task_classifier is an unknown value
"""
model = hparams.task_tower
# Make sure the classifier name shows up in graph
shared_scope = shared_scope or (model + '_shared')
kwargs = {
'num_classes': num_classes,
'is_training': is_training,
'reuse_private': reuse_private,
'reuse_shared': reuse_shared,
}
if private_scope:
kwargs['private_scope'] = private_scope
if shared_scope:
kwargs['shared_scope'] = shared_scope
quaternion_pred = None
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay_task_classifier)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if model == 'doubling_pose_estimator':
logits, quaternion_pred = doubling_cnn_class_and_quaternion(
images, num_private_layers=hparams.num_private_layers, **kwargs)
elif model == 'mnist':
logits, _ = mnist_classifier(images, **kwargs)
elif model == 'svhn':
logits, _ = svhn_classifier(images, **kwargs)
elif model == 'gtsrb':
logits, _ = gtsrb_classifier(images, **kwargs)
elif model == 'pose_mini':
logits, quaternion_pred = pose_mini_tower(images, **kwargs)
else:
raise ValueError('Unknown task classifier %s' % model)
return logits, quaternion_pred
#####################################
# Classifiers used in the DSN paper #
#####################################
def mnist_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope='mnist',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional MNIST model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits, endpoints = conv_mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 48, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool2']), 100, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 100, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def svhn_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional SVHN model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [3, 3], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 64, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [3, 3], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 128, [5, 5], scope='conv3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['conv3']), 3072, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 2048, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def gtsrb_classifier(images,
is_training=False,
num_classes=43,
reuse_private=False,
private_scope='gtsrb',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional GTSRB model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
reuse_private: Whether or not to reuse the private components of the model.
private_scope: The name of the private scope.
reuse_shared: Whether or not to reuse the shared components of the model.
shared_scope: The name of the shared scope.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 144, [3, 3], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 256, [5, 5], scope='conv3')
net['pool3'] = slim.max_pool2d(net['conv3'], [2, 2], 2, scope='pool3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool3']), 512, scope='fc3')
logits = slim.fully_connected(
net['fc3'], num_classes, activation_fn=None, scope='fc4')
return logits, net
#########################
# pose_mini task towers #
#########################
def pose_mini_tower(images,
num_classes=11,
is_training=False,
reuse_private=False,
private_scope='pose_mini',
reuse_shared=False,
shared_scope='task_model'):
"""Task tower for the pose_mini dataset."""
with tf.variable_scope(private_scope, reuse=reuse_private):
net = slim.conv2d(images, 32, [5, 5], scope='conv1')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net = slim.conv2d(net, 64, [5, 5], scope='conv2')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
net = slim.flatten(net)
net = slim.fully_connected(net, 128, scope='fc3')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
with tf.variable_scope('quaternion_prediction'):
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc4')
return logits, quaternion_pred
def doubling_cnn_class_and_quaternion(images,
num_private_layers=1,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope='doubling_cnn',
reuse_shared=False,
shared_scope='task_model'):
"""Alternate conv, pool while doubling filter count."""
net = images
depth = 32
layer_id = 1
with tf.variable_scope(private_scope, reuse=reuse_private):
while num_private_layers > 0 and net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
num_private_layers -= 1
with tf.variable_scope(shared_scope, reuse=reuse_shared):
while net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
net = slim.flatten(net)
net = slim.fully_connected(net, 100, scope='fc1')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc_logits')
return logits, quaternion_pred
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the PixelDA model."""
from functools import partial
import os
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_string('train_log_dir', '/tmp/pixelda/',
'Directory where to write event logs.')
flags.DEFINE_integer(
'save_summaries_steps', 500,
'The frequency with which summaries are saved, in seconds.')
flags.DEFINE_integer('save_interval_secs', 300,
'The frequency with which the model is saved, in seconds.')
flags.DEFINE_boolean('summarize_gradients', False,
'Whether to summarize model gradients')
flags.DEFINE_integer(
'print_loss_steps', 100,
'The frequency with which the losses are printed, in steps.')
flags.DEFINE_string('source_dataset', 'mnist', 'The name of the source dataset.'
' If hparams="arch=dcgan", this flag is ignored.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string('source_split_name', 'train',
'Name of the train split for the source.')
flags.DEFINE_string('target_split_name', 'train',
'Name of the train split for the target.')
flags.DEFINE_string('dataset_dir', '',
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
def _get_vars_and_update_ops(hparams, scope):
"""Returns the variables and update ops for a particular variable scope.
Args:
hparams: The hyperparameters struct.
scope: The variable scope.
Returns:
A tuple consisting of trainable variables and update ops.
"""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = filter(is_trainable, slim.get_model_variables(scope))
global_step = slim.get_or_create_global_step()
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
tf.logging.info('All variables for scope: %s',
slim.get_model_variables(scope))
tf.logging.info('Trainable variables for scope: %s', var_list)
return var_list, update_ops
def _train(discriminator_train_op,
generator_train_op,
logdir,
master='',
is_chief=True,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=600,
save_summaries_steps=100,
hparams=None):
"""Runs the training loop.
Args:
discriminator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the discriminator.
generator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the generator.
logdir: The directory where the graph and checkpoints are saved.
master: The URL of the master.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
scaffold: An tf.train.Scaffold instance.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
training loop.
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
inside the training loop for the chief trainer only.
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
using a default checkpoint saver. If `save_checkpoint_secs` is set to
`None`, then the default checkpoint saver isn't used.
save_summaries_steps: The frequency, in number of global steps, that the
summaries are written to disk using a default summary saver. If
`save_summaries_steps` is set to `None`, then the default summary saver
isn't used.
hparams: The hparams struct.
Returns:
the value of the loss function after training.
Raises:
ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
`save_summaries_steps` are `None.
"""
global_step = slim.get_or_create_global_step()
scaffold = scaffold or tf.train.Scaffold()
hooks = hooks or []
if is_chief:
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold, checkpoint_dir=logdir, master=master)
if chief_only_hooks:
hooks.extend(chief_only_hooks)
hooks.append(tf.train.StepCounterHook(output_dir=logdir))
if save_summaries_steps:
if logdir is None:
raise ValueError(
'logdir cannot be None when save_summaries_steps is None')
hooks.append(
tf.train.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
output_dir=logdir))
if save_checkpoint_secs:
if logdir is None:
raise ValueError(
'logdir cannot be None when save_checkpoint_secs is None')
hooks.append(
tf.train.CheckpointSaverHook(
logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
else:
session_creator = tf.train.WorkerSessionCreator(
scaffold=scaffold, master=master)
with tf.train.MonitoredSession(
session_creator=session_creator, hooks=hooks) as session:
loss = None
while not session.should_stop():
# Run the domain classifier op X times.
for _ in range(hparams.discriminator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run(
[discriminator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Discriminator Loss = %.2f', np_global_step,
loss)
# Run the generator op X times.
for _ in range(hparams.generator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run([generator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Generator Loss = %.2f', np_global_step,
loss)
return loss
def run_training(run_dir, checkpoint_dir, hparams):
"""Runs the training loop.
Args:
run_dir: The directory where training specific logs are placed
checkpoint_dir: The directory where the checkpoints and log files are
stored.
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for path in [run_dir, checkpoint_dir]:
if not tf.gfile.Exists(path):
tf.gfile.MakeDirs(path)
# Serialize hparams to log dir
hparams_filename = os.path.join(checkpoint_dir, 'hparams.json')
with tf.gfile.FastGFile(hparams_filename, 'w') as f:
f.write(hparams.to_json())
with tf.Graph().as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
target_images, _ = dataset_factory.provide_batch(
FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
# Data provider provides 1 hot labels, but we expect categorical.
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
'Source and Target datasets must have same number of classes. '
'Are %d and %d' % (num_source_classes, num_target_classes))
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=True,
num_classes=num_target_classes)
#################################
# Get the variables to optimize #
#################################
generator_vars, generator_update_ops = _get_vars_and_update_ops(
hparams, 'generator')
discriminator_vars, discriminator_update_ops = _get_vars_and_update_ops(
hparams, 'discriminator')
########################
# Configure the losses #
########################
generator_loss = pixelda_losses.g_step_loss(
source_images,
source_labels,
end_points,
hparams,
num_classes=num_target_classes)
discriminator_loss = pixelda_losses.d_step_loss(
end_points, source_labels, num_target_classes, hparams)
###########################
# Create the training ops #
###########################
learning_rate = hparams.learning_rate
if hparams.lr_decay_steps:
learning_rate = tf.train.exponential_decay(
learning_rate,
slim.get_or_create_global_step(),
decay_steps=hparams.lr_decay_steps,
decay_rate=hparams.lr_decay_rate,
staircase=True)
tf.summary.scalar('Learning_rate', learning_rate)
if hparams.discriminator_steps == 0:
discriminator_train_op = tf.no_op()
else:
discriminator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
discriminator_train_op = slim.learning.create_train_op(
discriminator_loss,
discriminator_optimizer,
update_ops=discriminator_update_ops,
variables_to_train=discriminator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
if hparams.generator_steps == 0:
generator_train_op = tf.no_op()
else:
generator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
generator_train_op = slim.learning.create_train_op(
generator_loss,
generator_optimizer,
update_ops=generator_update_ops,
variables_to_train=generator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
#############
# Summaries #
#############
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summaries_color_distributions(end_points['transferred_images'],
'Transferred')
pixelda_utils.summaries_color_distributions(target_images, 'Target')
if source_images is not None:
pixelda_utils.summarize_transferred(source_images,
end_points['transferred_images'])
pixelda_utils.summaries_color_distributions(source_images, 'Source')
pixelda_utils.summaries_color_distributions(
tf.abs(source_images - end_points['transferred_images']),
'Abs(Source_minus_Transferred)')
number_of_steps = None
if hparams.num_training_examples:
# Want to control by amount of data seen, not # steps
number_of_steps = hparams.num_training_examples / hparams.batch_size
hooks = [tf.train.StepCounterHook(),]
chief_only_hooks = [
tf.train.CheckpointSaverHook(
saver=tf.train.Saver(),
checkpoint_dir=run_dir,
save_secs=FLAGS.save_interval_secs)
]
if number_of_steps:
hooks.append(tf.train.StopAtStepHook(last_step=number_of_steps))
_train(
discriminator_train_op,
generator_train_op,
logdir=run_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=None,
save_summaries_steps=FLAGS.save_summaries_steps,
hparams=hparams)
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_training(
run_dir=FLAGS.train_log_dir,
checkpoint_dir=FLAGS.train_log_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for PixelDA model."""
import math
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
def remove_depth(images):
"""Takes a batch of images and remove depth channel if present."""
if images.shape.as_list()[-1] == 4:
return images[:, :, :, 0:3]
return images
def image_grid(images, max_grid_size=4):
"""Given images and N, return first N^2 images as an NxN image grid.
Args:
images: a `Tensor` of size [batch_size, height, width, channels]
max_grid_size: Maximum image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
images = remove_depth(images)
batch_size = images.shape.as_list()[0]
grid_size = min(int(math.sqrt(batch_size)), max_grid_size)
assert images.shape.as_list()[0] >= grid_size * grid_size
# If we have a depth channel
if images.shape.as_list()[-1] == 4:
images = images[:grid_size * grid_size, :, :, 0:3]
depth = tf.image.grayscale_to_rgb(images[:grid_size * grid_size, :, :, 3:4])
images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
split = tf.split(0, grid_size, images)
depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
depth_split = tf.split(0, grid_size, depth)
grid = tf.concat(split + depth_split, 1)
return tf.expand_dims(grid, 0)
else:
images = images[:grid_size * grid_size, :, :, :]
images = tf.reshape(
images, [-1, images.shape.as_list()[2],
images.shape.as_list()[3]])
split = tf.split(images, grid_size, 0)
grid = tf.concat(split, 1)
return tf.expand_dims(grid, 0)
def source_and_output_image_grid(output_images,
source_images=None,
max_grid_size=4):
"""Create NxN image grid for output, concatenate source grid if given.
Makes grid out of output_images and, if provided, source_images, and
concatenates them.
Args:
output_images: [batch_size, h, w, c] tensor of images
source_images: optional[batch_size, h, w, c] tensor of images
max_grid_size: Image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
output_grid = image_grid(output_images, max_grid_size=max_grid_size)
if source_images is not None:
source_grid = image_grid(source_images, max_grid_size=max_grid_size)
# Make sure they have the same # of channels before concat
# Assumes either 1 or 3 channels
if output_grid.shape.as_list()[-1] != source_grid.shape.as_list()[-1]:
if output_grid.shape.as_list()[-1] == 1:
output_grid = tf.tile(output_grid, [1, 1, 1, 3])
if source_grid.shape.as_list()[-1] == 1:
source_grid = tf.tile(source_grid, [1, 1, 1, 3])
output_grid = tf.concat([output_grid, source_grid], 1)
return output_grid
def summarize_model(end_points):
"""Summarizes the given model via its end_points.
Args:
end_points: A dictionary of end_point names to `Tensor`.
"""
tf.summary.histogram('domain_logits_transferred',
tf.sigmoid(end_points['transferred_domain_logits']))
tf.summary.histogram('domain_logits_target',
tf.sigmoid(end_points['target_domain_logits']))
def summarize_transferred_grid(transferred_images,
source_images=None,
name='Transferred'):
"""Produces a visual grid summarization of the image transferrence.
Args:
transferred_images: A `Tensor` of size [batch_size, height, width, c].
source_images: A `Tensor` of size [batch_size, height, width, c].
name: Name to use in summary name
"""
if source_images is not None:
grid = source_and_output_image_grid(transferred_images, source_images)
else:
grid = image_grid(transferred_images)
tf.summary.image('%s_Images_Grid' % name, grid, max_outputs=1)
def summarize_transferred(source_images,
transferred_images,
max_images=20,
name='Transferred'):
"""Produces a visual summary of the image transferrence.
This summary displays the source image, transferred image, and a grayscale
difference image which highlights the differences between input and output.
Args:
source_images: A `Tensor` of size [batch_size, height, width, channels].
transferred_images: A `Tensor` of size [batch_size, height, width, channels]
max_images: The number of images to show.
name: Name to use in summary name
Raises:
ValueError: If number of channels in source and target are incompatible
"""
source_channels = source_images.shape.as_list()[-1]
transferred_channels = transferred_images.shape.as_list()[-1]
if source_channels < transferred_channels:
if source_channels != 1:
raise ValueError(
'Source must be 1 channel or same # of channels as target')
source_images = tf.tile(source_images, [1, 1, 1, transferred_channels])
if transferred_channels < source_channels:
if transferred_channels != 1:
raise ValueError(
'Target must be 1 channel or same # of channels as source')
transferred_images = tf.tile(transferred_images, [1, 1, 1, source_channels])
diffs = tf.abs(source_images - transferred_images)
diffs = tf.reduce_max(diffs, reduction_indices=[3], keep_dims=True)
diffs = tf.tile(diffs, [1, 1, 1, max(source_channels, transferred_channels)])
transition_images = tf.concat([
source_images,
transferred_images,
diffs,
], 2)
tf.summary.image(
'%s_difference' % name, transition_images, max_outputs=max_images)
def summaries_color_distributions(images, name):
"""Produces a histogram of the color distributions of the images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
tf.summary.histogram('color_values/%s' % name, images)
def summarize_images(images, name):
"""Produces a visual summary of the given images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
grid = image_grid(images)
tf.summary.image('%s_Images' % name, grid, max_outputs=1)
......@@ -322,7 +322,7 @@ bazel-bin/im2txt/run_inference \
Example output:
```shell
```
Captions for image COCO_val2014_000000224477.jpg:
0) a man riding a wave on top of a surfboard . (p=0.040413)
1) a person riding a surf board on a wave (p=0.017452)
......
......@@ -76,7 +76,7 @@ if __name__ == '__main__':
basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1)
original_filename = os.path.join(data_dir, basename)
if not os.path.exists(original_filename):
print('Failed to find: ' % original_filename)
print('Failed to find: %s' % original_filename)
sys.exit(-1)
new_filename = os.path.join(data_dir, labels[i], basename)
os.rename(original_filename, new_filename)
# LFADS - Latent Factor Analysis via Dynamical Systems
This code implements the model from the paper "[LFADS - Latent Factor Analysis via Dynamical Systems](http://biorxiv.org/content/early/2017/06/20/152884)". It is a sequential variational auto-encoder designed specifically for investigating neuroscience data, but can be applied widely to any time series data. In an unsupervised setting, LFADS is able to decompose time series data into various factors, such as an initial condition, a generative dynamical system, control inputs to that generator, and a low dimensional description of the observed data, called the factors. Additionally, the observation model is a loss on a probability distribution, so when LFADS processes a dataset, a denoised version of the dataset is also created. For example, if the dataset is raw spike counts, then under the negative log-likeihood loss under a Poisson distribution, the denoised data would be the inferred Poisson rates.
## Prerequisites
The code is written in Python 2.7.6. You will also need:
* **TensorFlow** version 1.2.1 ([install](https://www.tensorflow.org/install/)) -
* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
* **h5py** ([install](https://pypi.python.org/pypi/h5py))
## Getting started
Before starting, run the following:
<pre>
$ export PYTHONPATH=$PYTHONPATH:/<b>path/to/your/directory</b>/lfads/
</pre>
where "path/to/your/directory" is replaced with the path to the LFADS repository (you can get this path by using the `pwd` command). This allows the nested directories to access modules from their parent directory.
## Generate synthetic data
In order to generate the synthetic datasets first, from the top-level lfads directory, run:
```sh
$ cd synth_data
$ ./run_generate_synth_data.sh
$ cd ..
```
These synthetic datasets are provided 1. to gain insight into how the LFADS algorithm operates, and 2. to give reasonable starting points for analyses you might be interested for your own data.
## Train an LFADS model
Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 9) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters.
```sh
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with spiking noise
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
--co_dim=0 \
--factors_dim=20 \
--ext_input_dim=0 \
--controller_input_lag=1 \
--output_dist=poisson \
--do_causal_controller=false \
--batch_size=128 \
--learning_rate_init=0.01 \
--learning_rate_stop=1e-05 \
--learning_rate_decay_factor=0.95 \
--learning_rate_n_to_compare=6 \
--do_reset_learning_rate=false \
--keep_prob=0.95 \
--con_dim=128 \
--gen_dim=200 \
--ci_enc_dim=128 \
--ic_dim=64 \
--ic_enc_dim=128 \
--ic_prior_var_min=0.1 \
--gen_cell_input_weight_scale=1.0 \
--cell_weight_scale=1.0 \
--do_feed_factors_to_controller=true \
--kl_start_step=0 \
--kl_increase_steps=2000 \
--kl_ic_weight=1.0 \
--l2_con_scale=0.0 \
--l2_gen_scale=2000.0 \
--l2_start_step=0 \
--l2_increase_steps=2000 \
--ic_prior_var_scale=0.1 \
--ic_post_var_min=0.0001 \
--kl_co_weight=1.0 \
--prior_ar_nvar=0.1 \
--cell_clip_value=5.0 \
--max_ckpt_to_keep_lve=5 \
--do_train_prior_ar_atau=true \
--co_prior_var_scale=0.1 \
--csv_log=fitlog \
--feedback_factors_or_rates=factors \
--do_train_prior_ar_nvar=true \
--max_grad_norm=200.0 \
--device=gpu:0 \
--num_steps_for_gen_ic=100000000 \
--ps_nexamples_to_process=100000000 \
--checkpoint_name=lfads_vae \
--temporal_spike_jitter_width=0 \
--checkpoint_pb_load_name=checkpoint \
--inject_ext_input_to_gen=false \
--co_mean_corr_scale=0.0 \
--gen_cell_rec_weight_scale=1.0 \
--max_ckpt_to_keep=5 \
--output_filename_stem="" \
--ic_prior_var_max=0.1 \
--prior_ar_atau=10.0 \
--do_train_io_only=false
# Run LFADS on chaotic rnn data with input pulses (g = 2.5)
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--output_dist=poisson
# Run LFADS on multi-session RNN data
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_multisession \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_multisession \
--factors_dim=10 \
--output_dist=poisson
# Run LFADS on integration to bound model data
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=itb_rnn \
--lfads_save_dir=/tmp/lfads_itb_rnn \
--co_dim=1 \
--factors_dim=20 \
--controller_input_lag=0 \
--output_dist=poisson
# Run LFADS on chaotic RNN data with labels
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnns_labeled \
--lfads_save_dir=/tmp/lfads_chaotic_rnns_labeled \
--co_dim=0 \
--factors_dim=20 \
--controller_input_lag=0 \
--ext_input_dim=1 \
--output_dist=poisson
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
--co_dim=0 \
--factors_dim=20 \
--ext_input_dim=0 \
--controller_input_lag=1 \
--output_dist=gaussian \
```
**Tip**: If you are running LFADS on GPU and would like to run more than one model concurrently, set the `--allow_gpu_growth=True` flag on each job, otherwise one model will take up the entire GPU for performance purposes. Also, one needs to install the TensorFlow libraries with GPU support.
## Visualize a training model
To visualize training curves and various other metrics while training and LFADS model, run the following command on your model directory. To launch a tensorboard on the chaotic RNN data with input pulses, for example:
```sh
tensorboard --logdir=/tmp/lfads_chaotic_rnn_inputs_g2p5
```
## Evaluate a trained model
Once your model is finished training, there are multiple ways you can evaluate
it. Below are some sample commands to evaluate an LFADS model trained on the
chaotic rnn data with input pulses (g = 2.5). The key differences here are
setting the `--kind` flag to the appropriate mode, as well as the
`--checkpoint_pb_load_name` flag to `checkpoint_lve` and the `--batch_size` flag
(if you'd like to make it larger or smaller). All other flags should be the
same as used in training, so that the same model architecture is built.
```sh
# Take samples from posterior then average (denoising operation)
$ python run_lfads.py --kind=posterior_sample_and_average \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--batch_size=1024 \
--checkpoint_pb_load_name=checkpoint_lve
# Sample from prior (generation of completely new samples)
$ python run_lfads.py --kind=prior_sample \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--batch_size=50 \
--checkpoint_pb_load_name=checkpoint_lve
# Write down model parameters
$ python run_lfads.py --kind=write_model_params \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--checkpoint_pb_load_name=checkpoint_lve
```
## Contact
File any issues with the [issue tracker](https://github.com/tensorflow/models/issues). For any questions or problems, this code is maintained by [@sussillo](https://github.com/sussillo) and [@jazcollins](https://github.com/jazcollins).
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import tensorflow as tf
from utils import linear, log_sum_exp
class Poisson(object):
"""Poisson distributon
Computes the log probability under the model.
"""
def __init__(self, log_rates):
""" Create Poisson distributions with log_rates parameters.
Args:
log_rates: a tensor-like list of log rates underlying the Poisson dist.
"""
self.logr = log_rates
def logp(self, bin_counts):
"""Compute the log probability for the counts in the bin, under the model.
Args:
bin_counts: array-like integer counts
Returns:
The log-probability under the Poisson models for each element of
bin_counts.
"""
k = tf.to_float(bin_counts)
# log poisson(k, r) = log(r^k * e^(-r) / k!) = k log(r) - r - log k!
# log poisson(k, r=exp(x)) = k * x - exp(x) - lgamma(k + 1)
return k * self.logr - tf.exp(self.logr) - tf.lgamma(k + 1)
def diag_gaussian_log_likelihood(z, mu=0.0, logvar=0.0):
"""Log-likelihood under a Gaussian distribution with diagonal covariance.
Returns the log-likelihood for each dimension. One should sum the
results for the log-likelihood under the full multidimensional model.
Args:
z: The value to compute the log-likelihood.
mu: The mean of the Gaussian
logvar: The log variance of the Gaussian.
Returns:
The log-likelihood under the Gaussian model.
"""
return -0.5 * (logvar + np.log(2*np.pi) + \
tf.square((z-mu)/tf.exp(0.5*logvar)))
def gaussian_pos_log_likelihood(unused_mean, logvar, noise):
"""Gaussian log-likelihood function for a posterior in VAE
Note: This function is specialized for a posterior distribution, that has the
form of z = mean + sigma * noise.
Args:
unused_mean: ignore
logvar: The log variance of the distribution
noise: The noise used in the sampling of the posterior.
Returns:
The log-likelihood under the Gaussian model.
"""
# ln N(z; mean, sigma) = - ln(sigma) - 0.5 ln 2pi - noise^2 / 2
return - 0.5 * (logvar + np.log(2 * np.pi) + tf.square(noise))
class Gaussian(object):
"""Base class for Gaussian distribution classes."""
pass
class DiagonalGaussian(Gaussian):
"""Diagonal Gaussian with different constant mean and variances in each
dimension.
"""
def __init__(self, batch_size, z_size, mean, logvar):
"""Create a diagonal gaussian distribution.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
mean: The N-D mean of the distribution.
logvar: The N-D log variance of the diagonal distribution.
"""
size__xz = [None, z_size]
self.mean = mean # bxn already
self.logvar = logvar # bxn already
self.noise = noise = tf.random_normal(tf.shape(logvar))
self.sample = mean + tf.exp(0.5 * logvar) * noise
mean.set_shape(size__xz)
logvar.set_shape(size__xz)
self.sample.set_shape(size__xz)
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample:
return gaussian_pos_log_likelihood(self.mean, self.logvar, self.noise)
return diag_gaussian_log_likelihood(z, self.mean, self.logvar)
class LearnableDiagonalGaussian(Gaussian):
"""Diagonal Gaussian whose mean and variance are learned parameters."""
def __init__(self, batch_size, z_size, name, mean_init=0.0,
var_init=1.0, var_min=0.0, var_max=1000000.0):
"""Create a learnable diagonal gaussian distribution.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
name: prefix name for the mean and log TF variables.
mean_init (optional): The N-D mean initialization of the distribution.
var_init (optional): The N-D variance initialization of the diagonal
distribution.
var_min (optional): The minimum value the learned variance can take in any
dimension.
var_max (optional): The maximum value the learned variance can take in any
dimension.
"""
size_1xn = [1, z_size]
size__xn = [None, z_size]
size_bx1 = tf.stack([batch_size, 1])
assert var_init > 0.0, "Problems"
assert var_max >= var_min, "Problems"
assert var_init >= var_min, "Problems"
assert var_max >= var_init, "Problems"
z_mean_1xn = tf.get_variable(name=name+"/mean", shape=size_1xn,
initializer=tf.constant_initializer(mean_init))
self.mean_bxn = mean_bxn = tf.tile(z_mean_1xn, size_bx1)
mean_bxn.set_shape(size__xn) # tile loses shape
log_var_init = np.log(var_init)
if var_max > var_min:
var_is_trainable = True
else:
var_is_trainable = False
z_logvar_1xn = \
tf.get_variable(name=(name+"/logvar"), shape=size_1xn,
initializer=tf.constant_initializer(log_var_init),
trainable=var_is_trainable)
if var_is_trainable:
z_logit_var_1xn = tf.exp(z_logvar_1xn)
z_var_1xn = tf.nn.sigmoid(z_logit_var_1xn)*(var_max-var_min) + var_min
z_logvar_1xn = tf.log(z_var_1xn)
logvar_bxn = tf.tile(z_logvar_1xn, size_bx1)
self.logvar_bxn = logvar_bxn
self.noise_bxn = noise_bxn = tf.random_normal(tf.shape(logvar_bxn))
self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample_bxn:
return gaussian_pos_log_likelihood(self.mean_bxn, self.logvar_bxn,
self.noise_bxn)
return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
@property
def mean(self):
return self.mean_bxn
@property
def logvar(self):
return self.logvar_bxn
@property
def sample(self):
return self.sample_bxn
class DiagonalGaussianFromInput(Gaussian):
"""Diagonal Gaussian whose mean and variance are conditioned on other
variables.
Note: the parameters to convert from input to the learned mean and log
variance are held in this class.
"""
def __init__(self, x_bxu, z_size, name, var_min=0.0):
"""Create an input dependent diagonal Gaussian distribution.
Args:
x: The input tensor from which the mean and variance are computed,
via a linear transformation of x. I.e.
mu = Wx + b, log(var) = Mx + c
z_size: The size of the distribution.
name: The name to prefix to learned variables.
var_min (optional): Minimal variance allowed. This is an additional
way to control the amount of information getting through the stochastic
layer.
"""
size_bxn = tf.stack([tf.shape(x_bxu)[0], z_size])
self.mean_bxn = mean_bxn = linear(x_bxu, z_size, name=(name+"/mean"))
logvar_bxn = linear(x_bxu, z_size, name=(name+"/logvar"))
if var_min > 0.0:
logvar_bxn = tf.log(tf.exp(logvar_bxn) + var_min)
self.logvar_bxn = logvar_bxn
self.noise_bxn = noise_bxn = tf.random_normal(size_bxn)
self.noise_bxn.set_shape([None, z_size])
self.sample_bxn = mean_bxn + tf.exp(0.5 * logvar_bxn) * noise_bxn
def logp(self, z=None):
"""Compute the log-likelihood under the distribution.
Args:
z (optional): value to compute likelihood for, if None, use sample.
Returns:
The likelihood of z under the model.
"""
if z is None:
z = self.sample
# This is needed to make sure that the gradients are simple.
# The value of the function shouldn't change.
if z == self.sample_bxn:
return gaussian_pos_log_likelihood(self.mean_bxn,
self.logvar_bxn, self.noise_bxn)
return diag_gaussian_log_likelihood(z, self.mean_bxn, self.logvar_bxn)
@property
def mean(self):
return self.mean_bxn
@property
def logvar(self):
return self.logvar_bxn
@property
def sample(self):
return self.sample_bxn
class GaussianProcess:
"""Base class for Gaussian processes."""
pass
class LearnableAutoRegressive1Prior(GaussianProcess):
"""AR(1) model where autocorrelation and process variance are learned
parameters. Assumed zero mean.
"""
def __init__(self, batch_size, z_size,
autocorrelation_taus, noise_variances,
do_train_prior_ar_atau, do_train_prior_ar_nvar,
num_steps, name):
"""Create a learnable autoregressive (1) process.
Args:
batch_size: The size of the batch, i.e. 0th dim in 2D tensor of samples.
z_size: The dimension of the distribution, i.e. 1st dim in 2D tensor.
autocorrelation_taus: The auto correlation time constant of the AR(1)
process.
A value of 0 is uncorrelated gaussian noise.
noise_variances: The variance of the additive noise, *not* the process
variance.
do_train_prior_ar_atau: Train or leave as constant, the autocorrelation?
do_train_prior_ar_nvar: Train or leave as constant, the noise variance?
num_steps: Number of steps to run the process.
name: The name to prefix to learned TF variables.
"""
# Note the use of the plural in all of these quantities. This is intended
# to mark that even though a sample z_t from the posterior is thought of a
# single sample of a multidimensional gaussian, the prior is actually
# thought of as U AR(1) processes, where U is the dimension of the inferred
# input.
size_bx1 = tf.stack([batch_size, 1])
size__xu = [None, z_size]
# process variance, the variance at time t over all instantiations of AR(1)
# with these parameters.
log_evar_inits_1xu = tf.expand_dims(tf.log(noise_variances), 0)
self.logevars_1xu = logevars_1xu = \
tf.Variable(log_evar_inits_1xu, name=name+"/logevars", dtype=tf.float32,
trainable=do_train_prior_ar_nvar)
self.logevars_bxu = logevars_bxu = tf.tile(logevars_1xu, size_bx1)
logevars_bxu.set_shape(size__xu) # tile loses shape
# \tau, which is the autocorrelation time constant of the AR(1) process
log_atau_inits_1xu = tf.expand_dims(tf.log(autocorrelation_taus), 0)
self.logataus_1xu = logataus_1xu = \
tf.Variable(log_atau_inits_1xu, name=name+"/logatau", dtype=tf.float32,
trainable=do_train_prior_ar_atau)
# phi in x_t = \mu + phi x_tm1 + \eps
# phi = exp(-1/tau)
# phi = exp(-1/exp(logtau))
# phi = exp(-exp(-logtau))
phis_1xu = tf.exp(-tf.exp(-logataus_1xu))
self.phis_bxu = phis_bxu = tf.tile(phis_1xu, size_bx1)
phis_bxu.set_shape(size__xu)
# process noise
# pvar = evar / (1- phi^2)
# logpvar = log ( exp(logevar) / (1 - phi^2) )
# logpvar = logevar - log(1-phi^2)
# logpvar = logevar - (log(1-phi) + log(1+phi))
self.logpvars_1xu = \
logevars_1xu - tf.log(1.0-phis_1xu) - tf.log(1.0+phis_1xu)
self.logpvars_bxu = logpvars_bxu = tf.tile(self.logpvars_1xu, size_bx1)
logpvars_bxu.set_shape(size__xu)
# process mean (zero but included in for completeness)
self.pmeans_bxu = pmeans_bxu = tf.zeros_like(phis_bxu)
# For sampling from the prior during de-novo generation.
self.means_t = means_t = [None] * num_steps
self.logvars_t = logvars_t = [None] * num_steps
self.samples_t = samples_t = [None] * num_steps
self.gaussians_t = gaussians_t = [None] * num_steps
sample_bxu = tf.zeros_like(phis_bxu)
for t in range(num_steps):
# process variance used here to make process completely stationary
if t == 0:
logvar_pt_bxu = self.logpvars_bxu
else:
logvar_pt_bxu = self.logevars_bxu
z_mean_pt_bxu = pmeans_bxu + phis_bxu * sample_bxu
gaussians_t[t] = DiagonalGaussian(batch_size, z_size,
mean=z_mean_pt_bxu,
logvar=logvar_pt_bxu)
sample_bxu = gaussians_t[t].sample
samples_t[t] = sample_bxu
logvars_t[t] = logvar_pt_bxu
means_t[t] = z_mean_pt_bxu
def logp_t(self, z_t_bxu, z_tm1_bxu=None):
"""Compute the log-likelihood under the distribution for a given time t,
not the whole sequence.
Args:
z_t_bxu: sample to compute likelihood for at time t.
z_tm1_bxu (optional): sample condition probability of z_t upon.
Returns:
The likelihood of p_t under the model at time t. i.e.
p(z_t|z_tm1) = N(z_tm1 * phis, eps^2)
"""
if z_tm1_bxu is None:
return diag_gaussian_log_likelihood(z_t_bxu, self.pmeans_bxu,
self.logpvars_bxu)
else:
means_t_bxu = self.pmeans_bxu + self.phis_bxu * z_tm1_bxu
logp_tgtm1_bxu = diag_gaussian_log_likelihood(z_t_bxu,
means_t_bxu,
self.logevars_bxu)
return logp_tgtm1_bxu
class KLCost_GaussianGaussian(object):
"""log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian prior. See
eqn 10 and Appendix B in VAE for latter term,
http://arxiv.org/abs/1312.6114
The log p(x|z) term is the reconstruction error under the model.
The KL term represents the penalty for passing information from the encoder
to the decoder.
To sample KL(q||p), we simply sample
ln q - ln p
by drawing samples from q and averaging.
"""
def __init__(self, zs, prior_zs):
"""Create a lower bound in three parts, normalized reconstruction
cost, normalized KL divergence cost, and their sum.
E_q[ln p(z_i | z_{i+1}) / q(z_i | x)
\int q(z) ln p(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_p^2) + \
sigma_q^2 / sigma_p^2 + (mean_p - mean_q)^2 / sigma_p^2)
\int q(z) ln q(z) dz = - 0.5 ln(2pi) - 0.5 \sum (ln(sigma_q^2) + 1)
Args:
zs: posterior z ~ q(z|x)
prior_zs: prior zs
"""
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'KL cost' is postive KL divergence
kl_b = 0.0
for z, prior_z in zip(zs, prior_zs):
assert isinstance(z, Gaussian)
assert isinstance(prior_z, Gaussian)
# ln(2pi) terms cancel
kl_b += 0.5 * tf.reduce_sum(
prior_z.logvar - z.logvar
+ tf.exp(z.logvar - prior_z.logvar)
+ tf.square((z.mean - prior_z.mean) / tf.exp(0.5 * prior_z.logvar))
- 1.0, [1])
self.kl_cost_b = kl_b
self.kl_cost = tf.reduce_mean(kl_b)
class KLCost_GaussianGaussianProcessSampled(object):
""" log p(x|z) + KL(q||p) terms for Gaussian posterior and Gaussian process
prior via sampling.
The log p(x|z) term is the reconstruction error under the model.
The KL term represents the penalty for passing information from the encoder
to the decoder.
To sample KL(q||p), we simply sample
ln q - ln p
by drawing samples from q and averaging.
"""
def __init__(self, post_zs, prior_z_process):
"""Create a lower bound in three parts, normalized reconstruction
cost, normalized KL divergence cost, and their sum.
Args:
post_zs: posterior z ~ q(z|x)
prior_z_process: prior AR(1) process
"""
assert len(post_zs) > 1, "GP is for time, need more than 1 time step."
assert isinstance(prior_z_process, GaussianProcess), "Must use GP."
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'KL cost' is postive KL divergence
z0_bxu = post_zs[0].sample
logq_bxu = post_zs[0].logp(z0_bxu)
logp_bxu = prior_z_process.logp_t(z0_bxu)
z_tm1_bxu = z0_bxu
for z_t in post_zs[1:]:
# posterior is independent in time, prior is not
z_t_bxu = z_t.sample
logq_bxu += z_t.logp(z_t_bxu)
logp_bxu += prior_z_process.logp_t(z_t_bxu, z_tm1_bxu)
z_tm1 = z_t_bxu
kl_bxu = logq_bxu - logp_bxu
kl_b = tf.reduce_sum(kl_bxu, [1])
self.kl_cost_b = kl_b
self.kl_cost = tf.reduce_mean(kl_b)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""
LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, control inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
The main data structure being passed around is a dataset. This is a dictionary
of data dictionaries.
DATASET: The top level dictionary is simply name (string -> dictionary).
The nested dictionary is the DATA DICTIONARY, which has the following keys:
'train_data' and 'valid_data', whose values are the corresponding training
and validation data with shape
ExTxD, E - # examples, T - # time steps, D - # dimensions in data.
The data dictionary also has a few more keys:
'train_ext_input' and 'valid_ext_input', if there are know external inputs
to the system being modeled, these take on dimensions:
ExTxI, E - # examples, T - # time steps, I = # dimensions in input.
'alignment_matrix_cxf' - If you are using multiple days data, it's possible
that one can align the channels (see manuscript). If so each dataset will
contain this matrix, which will be used for both the input adapter and the
output adapter for each dataset. These matrices, if provided, must be of
size [data_dim x factors] where data_dim is the number of neurons recorded
on that day, and factors is chosen and set through the '--factors' flag.
'alignment_bias_c' - See alignment_matrix_cxf. This bias will used to
the offset for the alignment transformation. It will *subtract* off the
bias from the data, so pca style inits can align factors across sessions.
If one runs LFADS on data where the true rates are known for some trials,
(say simulated, testing data, as in the example shipped with the paper), then
one can add three more fields for plotting purposes. These are 'train_truth'
and 'valid_truth', and 'conversion_factor'. These have the same dimensions as
'train_data', and 'valid_data' but represent the underlying rates of the
observations. Finally, if one needs to convert scale for plotting the true
underlying firing rates, there is the 'conversion_factor' key.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import tensorflow as tf
from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput
from distributions import diag_gaussian_log_likelihood
from distributions import KLCost_GaussianGaussian, Poisson
from distributions import LearnableAutoRegressive1Prior
from distributions import KLCost_GaussianGaussianProcessSampled
from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data
from utils import log_sum_exp, flatten
from plot_lfads import plot_lfads
class GRU(object):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
"""
def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0,
clip_value=np.inf, collections=None):
"""Create a GRU object.
Args:
num_units: Number of units in the GRU
forget_bias (optional): Hack to help learning.
weight_scale (optional): weights are scaled by ws/sqrt(#inputs), with
ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value,
clip them.
collections (optional): List of additonal collections variables should
belong to.
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._weight_scale = weight_scale
self._clip_value = clip_value
self._collections = collections
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_multiplier(self):
return 1
def output_from_state(self, state):
"""Return the output portion of the state."""
return state
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) function.
Args:
inputs: A 2D batch x input_dim tensor of inputs.
state: The previous state from the last time step.
scope (optional): TF variable scope for defined GRU variables.
Returns:
A tuple (state, state), where state is the newly computed state at time t.
It is returned twice to respect an interface that works for LSTMs.
"""
x = inputs
h = state
if inputs is not None:
xh = tf.concat(axis=1, values=[x, h])
else:
xh = h
with tf.variable_scope(scope or type(self).__name__): # "GRU"
with tf.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh,
2 * self._num_units,
alpha=self._weight_scale,
name="xh_2_ru",
collections=self._collections))
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
with tf.variable_scope("Candidate"):
xrh = tf.concat(axis=1, values=[x, r * h])
c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c",
collections=self._collections))
new_h = u * h + (1 - u) * c
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
return new_h, new_h
class GenGRU(object):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
This version is specialized for the generator, but isn't as fast, so
we have two. Note this allows for l2 regularization on the recurrent
weights, but also implicitly rescales the inputs via the 1/sqrt(input)
scaling in the linear helper routine to be large magnitude, if there are
fewer inputs than recurrent state.
"""
def __init__(self, num_units, forget_bias=1.0,
input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf,
input_collections=None, recurrent_collections=None):
"""Create a GRU object.
Args:
num_units: Number of units in the GRU
forget_bias (optional): Hack to help learning.
input_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with
ws being the weight scale.
rec_weight_scale (optional): weights are scaled ws/sqrt(#inputs),
with ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value,
clip them.
input_collections (optional): List of additonal collections variables
that input->rec weights should belong to.
recurrent_collections (optional): List of additonal collections variables
that rec->rec weights should belong to.
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._input_weight_scale = input_weight_scale
self._rec_weight_scale = rec_weight_scale
self._clip_value = clip_value
self._input_collections = input_collections
self._rec_collections = recurrent_collections
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_multiplier(self):
return 1
def output_from_state(self, state):
"""Return the output portion of the state."""
return state
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) function.
Args:
inputs: A 2D batch x input_dim tensor of inputs.
state: The previous state from the last time step.
scope (optional): TF variable scope for defined GRU variables.
Returns:
A tuple (state, state), where state is the newly computed state at time t.
It is returned twice to respect an interface that works for LSTMs.
"""
x = inputs
h = state
with tf.variable_scope(scope or type(self).__name__): # "GRU"
with tf.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r_x = u_x = 0.0
if x is not None:
r_x, u_x = tf.split(axis=1, num_or_size_splits=2, value=linear(x,
2 * self._num_units,
alpha=self._input_weight_scale,
do_bias=False,
name="x_2_ru",
normalized=False,
collections=self._input_collections))
r_h, u_h = tf.split(axis=1, num_or_size_splits=2, value=linear(h,
2 * self._num_units,
do_bias=True,
alpha=self._rec_weight_scale,
name="h_2_ru",
collections=self._rec_collections))
r = r_x + r_h
u = u_x + u_h
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
with tf.variable_scope("Candidate"):
c_x = 0.0
if x is not None:
c_x = linear(x, self._num_units, name="x_2_c", do_bias=False,
alpha=self._input_weight_scale,
normalized=False,
collections=self._input_collections)
c_rh = linear(r*h, self._num_units, name="rh_2_c", do_bias=True,
alpha=self._rec_weight_scale,
collections=self._rec_collections)
c = tf.tanh(c_x + c_rh)
new_h = u * h + (1 - u) * c
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
return new_h, new_h
class LFADS(object):
"""LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, inferred inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additoinally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
"""
def __init__(self, hps, kind="train", datasets=None):
"""Create an LFADS model.
train - a model for training, sampling of posteriors is used
posterior_sample_and_average - sample from the posterior, this is used
for evaluating the expected value of the outputs of LFADS, given a
specific input, by averaging over multiple samples from the approx
posterior. Also used for the lower bound on the negative
log-likelihood using IWAE error (Importance Weighed Auto-encoder).
This is the denoising operation.
prior_sample - a model for generation - sampling from priors is used
Args:
hps: The dictionary of hyper parameters.
kind: the type of model to build (see above).
datasets: a dictionary of named data_dictionaries, see top of lfads.py
"""
print("Building graph...")
all_kinds = ['train', 'posterior_sample_and_average', 'prior_sample']
assert kind in all_kinds, 'Wrong kind'
if hps.feedback_factors_or_rates == "rates":
assert len(hps.dataset_names) == 1, \
"Multiple datasets not supported for rate feedback."
num_steps = hps.num_steps
ic_dim = hps.ic_dim
co_dim = hps.co_dim
ext_input_dim = hps.ext_input_dim
cell_class = GRU
gen_cell_class = GenGRU
def makelambda(v): # Used with tf.case
return lambda: v
# Define the data placeholder, and deal with all parts of the graph
# that are dataset dependent.
self.dataName = tf.placeholder(tf.string, shape=())
# The batch_size to be inferred from data, as normal.
# Additionally, the data_dim will be inferred as well, allowing for a
# single placeholder for all datasets, regardless of data dimension.
if hps.output_dist == 'poisson':
# Enforce correct dtype
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, int), \
"Data dtype must be int for poisson output distribution"
data_dtype = tf.int32
elif hps.output_dist == 'gaussian':
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
"Data dtype must be float for gaussian output dsitribution"
data_dtype = tf.float32
else:
assert False, "NIY"
self.dataset_ph = dataset_ph = tf.placeholder(data_dtype,
[None, num_steps, None],
name="data")
self.train_step = tf.get_variable("global_step", [], tf.int64,
tf.zeros_initializer(),
trainable=False)
self.hps = hps
ndatasets = hps.ndatasets
factors_dim = hps.factors_dim
self.preds = preds = [None] * ndatasets
self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets
self.fns_in_fatcor_bs = fns_in_fac_bs = [None] * ndatasets
self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets
self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets
self.datasetNames = dataset_names = hps.dataset_names
self.ext_inputs = ext_inputs = None
if len(dataset_names) == 1: # single session
if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys():
used_in_factors_dim = factors_dim
in_identity_if_poss = False
else:
used_in_factors_dim = hps.dataset_dims[dataset_names[0]]
in_identity_if_poss = True
else: # multisession
used_in_factors_dim = factors_dim
in_identity_if_poss = False
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
in_bias_1xf = None
align_bias_1xc = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
print("Using alignment matrix provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d
(data_dim x factors_dim), but currently has %d x %d."""%
(data_dim, factors_dim, in_mat_cxf.shape[0],
in_mat_cxf.shape[1]))
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
print("Using alignment bias provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim:
raise ValueError("""Alignment bias must have dimensions %d
(data_dim), but currently has %d."""%
(data_dim, in_mat_cxf.shape[0]))
if in_mat_cxf is not None and align_bias_1xc is not None:
# (data - alignment_bias) * W_in
# data * W_in - alignment_bias * W_in
# So b = -alignment_bias * W_in to accommodate PCA style offset.
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True,
mat_init_value=in_mat_cxf,
bias_init_value=in_bias_1xf,
identity_if_possible=in_identity_if_poss,
normalized=False, name="x_2_infac_"+name,
collections=['IO_transformations'])
in_fac_W, in_fac_b = in_fac_lin
fns_in_fac_Ws[d] = makelambda(in_fac_W)
fns_in_fac_bs[d] = makelambda(in_fac_b)
with tf.variable_scope("glm"):
out_identity_if_poss = False
if len(dataset_names) == 1 and \
factors_dim == hps.dataset_dims[dataset_names[0]]:
out_identity_if_poss = True
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
out_mat_fxc = None
out_bias_1xc = None
if in_mat_cxf is not None:
out_mat_fxc = np.linalg.pinv(in_mat_cxf)
if align_bias_1xc is not None:
out_bias_1xc = align_bias_1xc
if hps.output_dist == 'poisson':
out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
identity_if_possible=out_identity_if_poss,
normalized=False,
name="fac_2_logrates_"+name,
collections=['IO_transformations'])
out_fac_W, out_fac_b = out_fac_lin
elif hps.output_dist == 'gaussian':
out_fac_lin_mean = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
normalized=False,
name="fac_2_means_"+name,
collections=['IO_transformations'])
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
bias_init_value = np.ones([1, data_dim]).astype(np.float32)
out_fac_lin_logvar = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=mat_init_value,
bias_init_value=bias_init_value,
normalized=False,
name="fac_2_logvars_"+name,
collections=['IO_transformations'])
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar
out_fac_W = tf.concat(
axis=1, values=[out_fac_W_mean, out_fac_W_logvar])
out_fac_b = tf.concat(
axis=1, values=[out_fac_b_mean, out_fac_b_logvar])
else:
assert False, "NIY"
preds[d] = tf.equal(tf.constant(name), self.dataName)
data_dim = hps.dataset_dims[name]
fns_out_fac_Ws[d] = makelambda(out_fac_W)
fns_out_fac_bs[d] = makelambda(out_fac_b)
pf_pairs_in_fac_Ws = zip(preds, fns_in_fac_Ws)
pf_pairs_in_fac_bs = zip(preds, fns_in_fac_bs)
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
def _case_with_no_default(pairs):
def _default_value_fn():
with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
return tf.identity(pairs[0][1]())
return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W = _case_with_no_default(pf_pairs_in_fac_Ws)
this_in_fac_b = _case_with_no_default(pf_pairs_in_fac_bs)
this_out_fac_W = _case_with_no_default(pf_pairs_out_fac_Ws)
this_out_fac_b = _case_with_no_default(pf_pairs_out_fac_bs)
# External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0:
self.ext_input = tf.placeholder(tf.float32,
[None, num_steps, ext_input_dim],
name="ext_input")
else:
self.ext_input = None
ext_input_bxtxi = self.ext_input
self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob")
self.batch_size = batch_size = int(hps.batch_size)
self.learning_rate = tf.Variable(float(hps.learning_rate_init),
trainable=False, name="learning_rate")
self.learning_rate_decay_op = self.learning_rate.assign(
self.learning_rate * hps.learning_rate_decay_factor)
# Dropout the data.
dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob)
if hps.ext_input_dim > 0:
ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob)
else:
ext_input_do_bxtxi = None
# ENCODERS
def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse,
num_steps_to_encode):
"""Encode data for LFADS
Args:
dataset_bxtxd - the data to encode, as a 3 tensor, with dims
time x batch x data dims.
enc_cell: encoder cell
name: name of encoder
forward_or_reverse: string, encode in forward or reverse direction
num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode
Returns:
encoded data as a list with num_steps_to_encode items, in order
"""
if forward_or_reverse == "forward":
dstr = "_fwd"
time_fwd_or_rev = range(num_steps_to_encode)
else:
dstr = "_rev"
time_fwd_or_rev = reversed(range(num_steps_to_encode))
with tf.variable_scope(name+"_enc"+dstr, reuse=False):
enc_state = tf.tile(
tf.Variable(tf.zeros([1, enc_cell.state_size]),
name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1]))
enc_state.set_shape([None, enc_cell.state_size]) # tile loses shape
enc_outs = [None] * num_steps_to_encode
for i, t in enumerate(time_fwd_or_rev):
with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None):
dataset_t_bxd = dataset_bxtxd[:,t,:]
in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b
in_fac_t_bxf.set_shape([None, used_in_factors_dim])
if ext_input_dim > 0 and not hps.inject_ext_input_to_gen:
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
enc_input_t_bxfpe = tf.concat(
axis=1, values=[in_fac_t_bxf, ext_input_t_bxi])
else:
enc_input_t_bxfpe = in_fac_t_bxf
enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state)
enc_outs[t] = enc_out
return enc_outs
# Encode initial condition means and variances
# ([x_T, x_T-1, ... x_0] and [x_0, x_1, ... x_T] -> g0/c0)
self.ic_enc_fwd = [None] * num_steps
self.ic_enc_rev = [None] * num_steps
if ic_dim > 0:
enc_ic_cell = cell_class(hps.ic_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "forward",
hps.num_steps_for_gen_ic)
ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "reverse",
hps.num_steps_for_gen_ic)
self.ic_enc_fwd = ic_enc_fwd
self.ic_enc_rev = ic_enc_rev
# Encoder control input means and variances, bi-directional encoding so:
# ([x_T, x_T-1, ..., x_0] and [x_0, x_1 ... x_T] -> u_t)
self.ci_enc_fwd = [None] * num_steps
self.ci_enc_rev = [None] * num_steps
if co_dim > 0:
enc_ci_cell = cell_class(hps.ci_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "forward",
hps.num_steps)
if hps.do_causal_controller:
ci_enc_rev = None
else:
ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "reverse",
hps.num_steps)
self.ci_enc_fwd = ci_enc_fwd
self.ci_enc_rev = ci_enc_rev
# STOCHASTIC LATENT VARIABLES, priors and posteriors
# (initial conditions g0, and control inputs, u_t)
# Note that zs represent all the stochastic latent variables.
with tf.variable_scope("z", reuse=False):
self.prior_zs_g0 = None
self.posterior_zs_g0 = None
self.g0s_val = None
if ic_dim > 0:
self.prior_zs_g0 = \
LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0",
mean_init=0.0,
var_min=hps.ic_prior_var_min,
var_init=hps.ic_prior_var_scale,
var_max=hps.ic_prior_var_max)
ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]])
ic_enc = tf.nn.dropout(ic_enc, keep_prob)
self.posterior_zs_g0 = \
DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
var_min=hps.ic_post_var_min)
if kind in ["train", "posterior_sample_and_average"]:
zs_g0 = self.posterior_zs_g0
else:
zs_g0 = self.prior_zs_g0
if kind in ["train", "posterior_sample_and_average", "prior_sample"]:
self.g0s_val = zs_g0.sample
else:
self.g0s_val = zs_g0.mean
# Priors for controller, 'co' for controller output
self.prior_zs_co = prior_zs_co = [None] * num_steps
self.posterior_zs_co = posterior_zs_co = [None] * num_steps
self.zs_co = zs_co = [None] * num_steps
self.prior_zs_ar_con = None
if co_dim > 0:
# Controller outputs
autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)]
noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)]
self.prior_zs_ar_con = prior_zs_ar_con = \
LearnableAutoRegressive1Prior(batch_size, hps.co_dim,
autocorrelation_taus,
noise_variances,
hps.do_train_prior_ar_atau,
hps.do_train_prior_ar_nvar,
num_steps, "u_prior_ar1")
# CONTROLLER -> GENERATOR -> RATES
# (u(t) -> gen(t) -> factors(t) -> rates(t) -> p(x_t|z_t) )
self.controller_outputs = u_t = [None] * num_steps
self.con_ics = con_state = None
self.con_states = con_states = [None] * num_steps
self.con_outs = con_outs = [None] * num_steps
self.gen_inputs = gen_inputs = [None] * num_steps
if co_dim > 0:
# gen_cell_class here for l2 penalty recurrent weights
# didn't split the cell_weight scale here, because I doubt it matters
con_cell = gen_cell_class(hps.con_dim,
input_weight_scale=hps.cell_weight_scale,
rec_weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_con_reg'])
with tf.variable_scope("con", reuse=False):
self.con_ics = tf.tile(
tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]), \
name="c0"),
tf.stack([batch_size, 1]))
self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape
con_states[-1] = self.con_ics
gen_cell = gen_cell_class(hps.gen_dim,
input_weight_scale=hps.gen_cell_input_weight_scale,
rec_weight_scale=hps.gen_cell_rec_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_gen_reg'])
with tf.variable_scope("gen", reuse=False):
if ic_dim == 0:
self.gen_ics = tf.tile(
tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"),
tf.stack([batch_size, 1]))
else:
self.gen_ics = linear(self.g0s_val, gen_cell.state_size,
identity_if_possible=True,
name="g0_2_gen_ic")
self.gen_states = gen_states = [None] * num_steps
self.gen_outs = gen_outs = [None] * num_steps
gen_states[-1] = self.gen_ics
gen_outs[-1] = gen_cell.output_from_state(gen_states[-1])
self.factors = factors = [None] * num_steps
factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
self.rates = rates = [None] * num_steps
# rates[-1] is collected to potentially feed back to controller
with tf.variable_scope("glm", reuse=False):
if hps.output_dist == 'poisson':
log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
log_rates_t0.set_shape([None, None])
rates[-1] = tf.exp(log_rates_t0) # rate
rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[-1] = means_t_bxd
else:
assert False, "NIY"
# We support mulitple output distributions, for example Poisson, and also
# Gaussian. In these two cases respectively, there are one and two
# parameters (rates vs. mean and variance). So the output_dist_params
# tensor will variable sizes via tf.concat and tf.split, along the 1st
# dimension. So in the case of gaussian, for example, it'll be
# batch x (D+D), where each D dims is the mean, and then variances,
# respectively. For a distribution with 3 parameters, it would be
# batch x (D+D+D).
self.output_dist_params = dist_params = [None] * num_steps
self.log_p_xgz_b = log_p_xgz_b = 0.0 # log P(x|z)
for t in range(num_steps):
# Controller
if co_dim > 0:
# Build inputs for controller
tlag = t - hps.controller_input_lag
if tlag < 0:
con_in_f_t = tf.zeros_like(ci_enc_fwd[0])
else:
con_in_f_t = ci_enc_fwd[tlag]
if hps.do_causal_controller:
# If controller is causal (wrt to data generation process), then it
# cannot see future data. Thus, excluding ci_enc_rev[t] is obvious.
# Less obvious is the need to exclude factors[t-1]. This arises
# because information flows from g0 through factors to the controller
# input. The g0 encoding is backwards, so we must necessarily exclude
# the factors in order to keep the controller input purely from a
# forward encoding (however unlikely it is that
# g0->factors->controller channel might actually be used in this way).
con_in_list_t = [con_in_f_t]
else:
tlag_rev = t + hps.controller_input_lag
if tlag_rev >= num_steps:
# better than zeros
con_in_r_t = tf.zeros_like(ci_enc_rev[0])
else:
con_in_r_t = ci_enc_rev[tlag_rev]
con_in_list_t = [con_in_f_t, con_in_r_t]
if hps.do_feed_factors_to_controller:
if hps.feedback_factors_or_rates == "factors":
con_in_list_t.append(factors[t-1])
elif hps.feedback_factors_or_rates == "rates":
con_in_list_t.append(rates[t-1])
else:
assert False, "NIY"
con_in_t = tf.concat(axis=1, values=con_in_list_t)
con_in_t = tf.nn.dropout(con_in_t, keep_prob)
with tf.variable_scope("con", reuse=True if t > 0 else None):
con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1])
posterior_zs_co[t] = \
DiagonalGaussianFromInput(con_outs[t], co_dim,
name="con_to_post_co")
if kind == "train":
u_t[t] = posterior_zs_co[t].sample
elif kind == "posterior_sample_and_average":
u_t[t] = posterior_zs_co[t].sample
else:
u_t[t] = prior_zs_ar_con.samples_t[t]
# Inputs to the generator (controller output + external input)
if ext_input_dim > 0 and hps.inject_ext_input_to_gen:
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
if co_dim > 0:
gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi])
else:
gen_inputs[t] = ext_input_t_bxi
else:
gen_inputs[t] = u_t[t]
# Generator
data_t_bxd = dataset_ph[:,t,:]
with tf.variable_scope("gen", reuse=True if t > 0 else None):
gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1])
gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob)
with tf.variable_scope("gen", reuse=True): # ic defined it above
factors[t] = linear(gen_outs[t], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
with tf.variable_scope("glm", reuse=True if t > 0 else None):
if hps.output_dist == 'poisson':
log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
log_rates_t.set_shape([None, None])
rates[t] = dist_params[t] = tf.exp(log_rates_t) # rates feed back
rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[t] = means_t_bxd # rates feed back to controller
dist_params[t] = tf.concat(
axis=1, values=[means_t_bxd, tf.exp(logvars_t_bxd)])
loglikelihood_t = \
diag_gaussian_log_likelihood(data_t_bxd,
means_t_bxd, logvars_t_bxd)
else:
assert False, "NIY"
log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1])
# Correlation of inferred inputs cost.
self.corr_cost = tf.constant(0.0)
if hps.co_mean_corr_scale > 0.0:
all_sum_corr = []
for i in range(hps.co_dim):
for j in range(i+1, hps.co_dim):
sum_corr_ij = tf.constant(0.0)
for t in range(num_steps):
u_mean_t = posterior_zs_co[t].mean
sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j]
all_sum_corr.append(0.5 * tf.square(sum_corr_ij))
self.corr_cost = tf.reduce_mean(all_sum_corr) # div by batch and by n*(n-1)/2 pairs
# Variational Lower Bound on posterior, p(z|x), plus reconstruction cost.
# KL and reconstruction costs are normalized only by batch size, not by
# dimension, or by time steps.
kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32)
kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32)
self.kl_cost = tf.constant(0.0) # VAE KL cost
self.recon_cost = tf.constant(0.0) # VAE reconstruction cost
self.nll_bound_vae = tf.constant(0.0)
self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost.
if kind in ["train", "posterior_sample_and_average"]:
kl_cost_g0_b = 0.0
kl_cost_co_b = 0.0
if ic_dim > 0:
g0_priors = [self.prior_zs_g0]
g0_posts = [self.posterior_zs_g0]
kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b
kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b
if co_dim > 0:
kl_cost_co_b = \
KLCost_GaussianGaussianProcessSampled(
posterior_zs_co, prior_zs_ar_con).kl_cost_b
kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'reconstruction cost' is negative log likelihood
self.recon_cost = - tf.reduce_mean(log_p_xgz_b)
self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b)
lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b
# VAE error averages outside the log
self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b)
# IWAE error averages inside the log
k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32)
iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b)
self.nll_bound_iwae = -iwae_lb_on_ll
# L2 regularization on the generator, normalized by number of parameters.
self.l2_cost = tf.constant(0.0)
if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0:
l2_costs = []
l2_numels = []
l2_reg_var_lists = [tf.get_collection('l2_gen_reg'),
tf.get_collection('l2_con_reg')]
l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale]
for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales):
for v in l2_reg_vars:
numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v)))
numel_f = tf.cast(numel, tf.float32)
l2_numels.append(numel_f)
v_l2 = tf.reduce_sum(v*v)
l2_costs.append(0.5 * l2_scale * v_l2)
self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels)
# Compute the cost for training, part of the graph regardless.
# The KL cost can be problematic at the beginning of optimization,
# so we allow an exponential increase in weighting the KL from 0
# to 1.
self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0)
self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0)
kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32)
l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32)
kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32)
l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32)
self.kl_weight = kl_weight = \
tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0)
self.l2_weight = l2_weight = \
tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0)
self.timed_kl_cost = kl_weight * self.kl_cost
self.timed_l2_cost = l2_weight * self.l2_cost
self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost
self.cost = self.recon_cost + self.timed_kl_cost + \
self.timed_l2_cost + self.weight_corr_cost
if kind != "train":
# save every so often
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep_lve)
return
# OPTIMIZATION
if not self.hps.do_train_io_only:
self.train_vars = tvars = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=tf.get_variable_scope().name)
else:
self.train_vars = tvars = \
tf.get_collection('IO_transformations',
scope=tf.get_variable_scope().name)
print("done.")
print("Model Variables (to be optimized): ")
total_params = 0
for i in range(len(tvars)):
shape = tvars[i].get_shape().as_list()
print(" ", i, tvars[i].name, shape)
total_params += np.prod(shape)
print("Total model parameters: ", total_params)
grads = tf.gradients(self.cost, tvars)
grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm)
opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999,
epsilon=1e-01)
self.grads = grads
self.grad_global_norm = grad_global_norm
self.train_op = opt.apply_gradients(
zip(grads, tvars), global_step=self.train_step)
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# SUMMARIES, used only during training.
# example summary
self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3],
name='image_tensor')
self.example_summ = tf.summary.image("LFADS example", self.example_image,
collections=["example_summaries"])
# general training summaries
self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight)
self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight)
self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost)
self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm",
self.grad_global_norm)
if hps.co_dim > 0:
self.atau_summ = [None] * hps.co_dim
self.pvar_summ = [None] * hps.co_dim
for c in range(hps.co_dim):
self.atau_summ[c] = \
tf.summary.scalar("AR Autocorrelation taus " + str(c),
tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c]))
self.pvar_summ[c] = \
tf.summary.scalar("AR Variances " + str(c),
tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c]))
# cost summaries, separated into different collections for
# training vs validation. We make placeholders for these, because
# even though the graph computes these costs on a per-batch basis,
# we want to report the more reliable metric of per-epoch cost.
kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph')
self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph,
collections=["train_summaries"])
self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph,
collections=["valid_summaries"])
l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph')
self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph,
collections=["train_summaries"])
recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph')
self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)",
recon_cost_ph,
collections=["train_summaries"])
self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)",
recon_cost_ph,
collections=["valid_summaries"])
total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph')
self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph,
collections=["train_summaries"])
self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph,
collections=["valid_summaries"])
self.kl_cost_ph = kl_cost_ph
self.l2_cost_ph = l2_cost_ph
self.recon_cost_ph = recon_cost_ph
self.total_cost_ph = total_cost_ph
# Merged summaries, for easy coding later.
self.merged_examples = tf.summary.merge_all(key="example_summaries")
self.merged_generic = tf.summary.merge_all() # default key is 'summaries'
self.merged_train = tf.summary.merge_all(key="train_summaries")
self.merged_valid = tf.summary.merge_all(key="valid_summaries")
session = tf.get_default_session()
self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
self.writer = tf.summary.FileWriter(self.logfile)
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
keep_prob=None):
"""Build the feed dictionary, handles cases where there is no value defined.
Args:
train_name: The key into the datasets, to set the tf.case statement for
the proper readin / readout matrices.
data_bxtxd: The data tensor
ext_input_bxtxi (optional): The external input tensor
keep_prob: The drop out keep probability.
Returns:
The feed dictionary with TF tensors as keys and data as values, for use
with tf.Session.run()
"""
feed_dict = {}
B, T, _ = data_bxtxd.shape
feed_dict[self.dataName] = train_name
feed_dict[self.dataset_ph] = data_bxtxd
if self.ext_input is not None and ext_input_bxtxi is not None:
feed_dict[self.ext_input] = ext_input_bxtxi
if keep_prob is None:
feed_dict[self.keep_prob] = self.hps.keep_prob
else:
feed_dict[self.keep_prob] = keep_prob
return feed_dict
@staticmethod
def get_batch(data_extxd, ext_input_extxi=None, batch_size=None,
example_idxs=None):
"""Get a batch of data, either randomly chosen, or specified directly.
Args:
data_extxd: The data to model, numpy tensors with shape:
# examples x # time steps x # dimensions
ext_input_extxi (optional): The external inputs, numpy tensor with shape:
# examples x # time steps x # external input dimensions
batch_size: The size of the batch to return
example_idxs (optional): The example indices used to select examples.
Returns:
A tuple with two parts:
1. Batched data numpy tensor with shape:
batch_size x # time steps x # dimensions
2. Batched external input numpy tensor with shape:
batch_size x # time steps x # external input dims
"""
assert batch_size is not None or example_idxs is not None, "Problems"
E, T, D = data_extxd.shape
if example_idxs is None:
example_idxs = np.random.choice(E, batch_size)
ext_input_bxtxi = None
if ext_input_extxi is not None:
ext_input_bxtxi = ext_input_extxi[example_idxs,:,:]
return data_extxd[example_idxs,:,:], ext_input_bxtxi
@staticmethod
def example_idxs_mod_batch_size(nexamples, batch_size):
"""Given a number of examples, E, and a batch_size, B, generate indices
[0, 1, 2, ... B-1;
[B, B+1, ... 2*B-1;
...
]
returning those indices as a 2-dim tensor shaped like E/B x B. Note that
shape is only correct if E % B == 0. If not, then an extra row is generated
so that the remainder of examples is included. The extra examples are
explicitly to to the zero index (see randomize_example_idxs_mod_batch_size)
for randomized behavior.
Args:
nexamples: The number of examples to batch up.
batch_size: The size of the batch.
Returns:
2-dim tensor as described above.
"""
bmrem = batch_size - (nexamples % batch_size)
bmrem_examples = []
if bmrem < batch_size:
#bmrem_examples = np.zeros(bmrem, dtype=np.int32)
ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32)
bmrem_examples = np.sort(ridxs)
example_idxs = range(nexamples) + list(bmrem_examples)
example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size])
return example_idxs_e_x_edivb, bmrem
@staticmethod
def randomize_example_idxs_mod_batch_size(nexamples, batch_size):
"""Indices 1:nexamples, randomized, in 2D form of
shape = (nexamples / batch_size) x batch_size. The remainder
is managed by drawing randomly from 1:nexamples.
Args:
nexamples: number of examples to randomize
batch_size: number of elements in batch
Returns:
The randomized, properly shaped indicies.
"""
assert nexamples > batch_size, "Problems"
bmrem = batch_size - nexamples % batch_size
bmrem_examples = []
if bmrem < batch_size:
bmrem_examples = np.random.choice(range(nexamples),
size=bmrem, replace=False)
example_idxs = range(nexamples) + list(bmrem_examples)
mixed_example_idxs = np.random.permutation(example_idxs)
example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size])
return example_idxs_e_x_edivb, bmrem
def shuffle_spikes_in_time(self, data_bxtxd):
"""Shuffle the spikes in the temporal dimension. This is useful to
help the LFADS system avoid overfitting to individual spikes or fast
oscillations found in the data that are irrelevant to behavior. A
pure 'tabula rasa' approach would avoid this, but LFADS is sensitive
enough to pick up dynamics that you may not want.
Args:
data_bxtxd: numpy array of spike count data to be shuffled.
Returns:
S_bxtxd, a numpy array with the same dimensions and contents as
data_bxtxd, but shuffled appropriately.
"""
B, T, N = data_bxtxd.shape
w = self.hps.temporal_spike_jitter_width
if w == 0:
return data_bxtxd
max_counts = np.max(data_bxtxd)
S_bxtxd = np.zeros([B,T,N])
# Intuitively, shuffle spike occurances, 0 or 1, but since we have counts,
# Do it over and over again up to the max count.
for mc in range(1,max_counts+1):
idxs = np.nonzero(data_bxtxd >= mc)
data_ones = np.zeros_like(data_bxtxd)
data_ones[data_bxtxd >= mc] = 1
nfound = len(idxs[0])
shuffles_incrs_in_time = np.random.randint(-w, w, size=nfound)
shuffle_tidxs = idxs[1].copy()
shuffle_tidxs += shuffles_incrs_in_time
# Reflect on the boundaries to not lose mass.
shuffle_tidxs[shuffle_tidxs < 0] = -shuffle_tidxs[shuffle_tidxs < 0]
shuffle_tidxs[shuffle_tidxs > T-1] = \
(T-1)-(shuffle_tidxs[shuffle_tidxs > T-1] -(T-1))
for iii in zip(idxs[0], shuffle_tidxs, idxs[2]):
S_bxtxd[iii] += 1
return S_bxtxd
def shuffle_and_flatten_datasets(self, datasets, kind='train'):
"""Since LFADS supports multiple datasets in the same dynamical model,
we have to be careful to use all the data in a single training epoch. But
since the datasets my have different data dimensionality, we cannot batch
examples from data dictionaries together. Instead, we generate random
batches within each data dictionary, and then randomize these batches
while holding onto the dataname, so that when it's time to feed
the graph, the correct in/out matrices can be selected, per batch.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
kind: 'train' or 'valid'
Returns:
A flat list, in which each element is a pair ('name', indices).
"""
batch_size = self.hps.batch_size
ndatasets = len(datasets)
random_example_idxs = {}
epoch_idxs = {}
all_name_example_idx_pairs = []
kind_data = kind + '_data'
for name, data_dict in datasets.items():
nexamples, ntime, data_dim = data_dict[kind_data].shape
epoch_idxs[name] = 0
random_example_idxs, _ = \
self.randomize_example_idxs_mod_batch_size(nexamples, batch_size)
epoch_size = random_example_idxs.shape[0]
names = [name] * epoch_size
all_name_example_idx_pairs += zip(names, random_example_idxs)
np.random.shuffle(all_name_example_idx_pairs) # shuffle in place
return all_name_example_idx_pairs
def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True):
"""Train the model through the entire dataset once.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
batch_size (optional): The batch_size to use
do_save_ckpt (optional): Should the routine save a checkpoint on this
training epoch?
Returns:
A tuple with 6 float values:
(total cost of the epoch, epoch reconstruction cost,
epoch kl cost, KL weight used this training epoch,
total l2 cost on generator, and the corresponding weight).
"""
ops_to_eval = [self.cost, self.recon_cost,
self.kl_cost, self.kl_weight,
self.l2_cost, self.l2_weight,
self.train_op]
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind="train")
total_cost = total_recon_cost = total_kl_cost = 0.0
# normalizing by batch done in distributions.py
epoch_size = len(collected_op_values)
for op_values in collected_op_values:
total_cost += op_values[0]
total_recon_cost += op_values[1]
total_kl_cost += op_values[2]
kl_weight = collected_op_values[-1][3]
l2_cost = collected_op_values[-1][4]
l2_weight = collected_op_values[-1][5]
epoch_total_cost = total_cost / epoch_size
epoch_recon_cost = total_recon_cost / epoch_size
epoch_kl_cost = total_kl_cost / epoch_size
if do_save_ckpt:
session = tf.get_default_session()
checkpoint_path = os.path.join(self.hps.lfads_save_dir,
self.hps.checkpoint_name + '.ckpt')
self.seso_saver.save(session, checkpoint_path,
global_step=self.train_step)
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost, \
kl_weight, l2_cost, l2_weight
def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None,
do_collect=True, keep_prob=None):
"""Run the model through the entire dataset once.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
ops_to_eval: A list of tensorflow operations that will be evaluated in
the tf.session.run() call.
batch_size (optional): The batch_size to use
do_collect (optional): Should the routine collect all session.run
output as a list, and return it?
keep_prob (optional): The dropout keep probability.
Returns:
A list of lists, the internal list is the return for the ops for each
session.run() call. The outer list collects over the epoch.
"""
hps = self.hps
all_name_example_idx_pairs = \
self.shuffle_and_flatten_datasets(datasets, kind)
kind_data = kind + '_data'
kind_ext_input = kind + '_ext_input'
total_cost = total_recon_cost = total_kl_cost = 0.0
session = tf.get_default_session()
epoch_size = len(all_name_example_idx_pairs)
evaled_ops_list = []
for name, example_idxs in all_name_example_idx_pairs:
data_dict = datasets[name]
data_extxd = data_dict[kind_data]
if hps.output_dist == 'poisson' and hps.temporal_spike_jitter_width > 0:
data_extxd = self.shuffle_spikes_in_time(data_extxd)
ext_input_extxi = data_dict[kind_ext_input]
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi,
example_idxs=example_idxs)
feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi,
keep_prob=keep_prob)
evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict)
if do_collect:
evaled_ops_list.append(evaled_ops_np)
return evaled_ops_list
def summarize_all(self, datasets, summary_values):
"""Plot and summarize stuff in tensorboard.
Note that everything done in the current function is otherwise done on
a single, randomly selected dataset (except for summary_values, which are
passed in.)
Args:
datasets, the dictionary of datasets used in the study.
summary_values: These summary values are created from the training loop,
and so summarize the entire set of datasets.
"""
hps = self.hps
tr_kl_cost = summary_values['tr_kl_cost']
tr_recon_cost = summary_values['tr_recon_cost']
tr_total_cost = summary_values['tr_total_cost']
kl_weight = summary_values['kl_weight']
l2_weight = summary_values['l2_weight']
l2_cost = summary_values['l2_cost']
has_any_valid_set = summary_values['has_any_valid_set']
i = summary_values['nepochs']
session = tf.get_default_session()
train_summ, train_step = session.run([self.merged_train,
self.train_step],
feed_dict={self.l2_cost_ph:l2_cost,
self.kl_cost_ph:tr_kl_cost,
self.recon_cost_ph:tr_recon_cost,
self.total_cost_ph:tr_total_cost})
self.writer.add_summary(train_summ, train_step)
if has_any_valid_set:
ev_kl_cost = summary_values['ev_kl_cost']
ev_recon_cost = summary_values['ev_recon_cost']
ev_total_cost = summary_values['ev_total_cost']
eval_summ = session.run(self.merged_valid,
feed_dict={self.kl_cost_ph:ev_kl_cost,
self.recon_cost_ph:ev_recon_cost,
self.total_cost_ph:ev_total_cost})
self.writer.add_summary(eval_summ, train_step)
print("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f\
recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f,\
kl weight: %.2f, l2 weight: %.2f" % \
(i, train_step, tr_total_cost, ev_total_cost,
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
l2_cost, kl_weight, l2_weight))
csv_outstr = "epoch,%d, step,%d, total,%.2f,%.2f, \
recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, \
klweight,%.2f, l2weight,%.2f\n"% \
(i, train_step, tr_total_cost, ev_total_cost,
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
l2_cost, kl_weight, l2_weight)
else:
print("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f,\
l2: %.5f, kl weight: %.2f, l2 weight: %.2f" % \
(i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost,
l2_cost, kl_weight, l2_weight))
csv_outstr = "epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, \
l2,%.5f, klweight,%.2f, l2weight,%.2f\n"% \
(i, train_step, tr_total_cost, tr_recon_cost,
tr_kl_cost, l2_cost, kl_weight, l2_weight)
if self.hps.csv_log:
csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv')
with open(csv_file, "a") as myfile:
myfile.write(csv_outstr)
def plot_single_example(self, datasets):
"""Plot an image relating to a randomly chosen, specific example. We use
posterior sample and average by taking one example, and filling a whole
batch with that example, sample from the posterior, and then average the
quantities.
"""
hps = self.hps
all_data_names = datasets.keys()
data_name = np.random.permutation(all_data_names)[0]
data_dict = datasets[data_name]
has_valid_set = True if data_dict['valid_data'] is not None else False
cf = 1.0 # plotting concern
# posterior sample and average here
E, _, _ = data_dict['train_data'].shape
eidx = np.random.choice(E)
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
train_data_bxtxd, train_ext_input_bxtxi = \
self.get_batch(data_dict['train_data'], data_dict['train_ext_input'],
example_idxs=example_idxs)
truth_train_data_bxtxd = None
if 'train_truth' in data_dict and data_dict['train_truth'] is not None:
truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'],
example_idxs=example_idxs)
cf = data_dict['conversion_factor']
# plotter does averaging
train_model_values = self.eval_model_runs_batch(data_name,
train_data_bxtxd,
train_ext_input_bxtxi,
do_average_batch=False)
train_step = train_model_values['train_steps']
feed_dict = self.build_feed_dict(data_name, train_data_bxtxd,
train_ext_input_bxtxi, keep_prob=1.0)
session = tf.get_default_session()
generic_summ = session.run(self.merged_generic, feed_dict=feed_dict)
self.writer.add_summary(generic_summ, train_step)
valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None
truth_valid_data_bxtxd = None
if has_valid_set:
E, _, _ = data_dict['valid_data'].shape
eidx = np.random.choice(E)
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
valid_data_bxtxd, valid_ext_input_bxtxi = \
self.get_batch(data_dict['valid_data'],
data_dict['valid_ext_input'],
example_idxs=example_idxs)
if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None:
truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'],
example_idxs=example_idxs)
else:
truth_valid_data_bxtxd = None
# plotter does averaging
valid_model_values = self.eval_model_runs_batch(data_name,
valid_data_bxtxd,
valid_ext_input_bxtxi,
do_average_batch=False)
example_image = plot_lfads(train_bxtxd=train_data_bxtxd,
train_model_vals=train_model_values,
train_ext_input_bxtxi=train_ext_input_bxtxi,
train_truth_bxtxd=truth_train_data_bxtxd,
valid_bxtxd=valid_data_bxtxd,
valid_model_vals=valid_model_values,
valid_ext_input_bxtxi=valid_ext_input_bxtxi,
valid_truth_bxtxd=truth_valid_data_bxtxd,
bidx=None, cf=cf, output_dist=hps.output_dist)
example_image = np.expand_dims(example_image, axis=0)
example_summ = session.run(self.merged_examples,
feed_dict={self.example_image : example_image})
self.writer.add_summary(example_summ)
def train_model(self, datasets):
"""Train the model, print per-epoch information, and save checkpoints.
Loop over training epochs. The function that actually does the
training is train_epoch. This function iterates over the training
data, one epoch at a time. The learning rate schedule is such
that it will stay the same until the cost goes up in comparison to
the last few values, then it will drop.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
"""
hps = self.hps
has_any_valid_set = False
for data_dict in datasets.values():
if data_dict['valid_data'] is not None:
has_any_valid_set = True
break
session = tf.get_default_session()
lr = session.run(self.learning_rate)
lr_stop = hps.learning_rate_stop
i = -1
train_costs = []
valid_costs = []
ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0
lowest_ev_cost = np.Inf
while True:
i += 1
do_save_ckpt = True if i % 10 ==0 else False
tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \
self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)
# Evaluate the validation cost, and potentially save. Note that this
# routine will not save a validation checkpoint until the kl weight and
# l2 weights are equal to 1.0.
if has_any_valid_set:
ev_total_cost, ev_recon_cost, ev_kl_cost = \
self.eval_cost_epoch(datasets, kind='valid')
valid_costs.append(ev_total_cost)
# > 1 may give more consistent results, but not the actual lowest vae.
# == 1 gives the lowest vae seen so far.
n_lve = 1
run_avg_lve = np.mean(valid_costs[-n_lve:])
# conditions for saving checkpoints:
# KL weight must have finished stepping (>=1.0), AND
# L2 weight must have finished stepping OR L2 is not being used, AND
# the current run has a lower LVE than previous runs AND
# len(valid_costs > n_lve) (not sure what that does)
if kl_weight >= 1.0 and \
(l2_weight >= 1.0 or \
(self.hps.l2_gen_scale == 0.0 and self.hps.l2_con_scale == 0.0)) \
and (len(valid_costs) > n_lve and run_avg_lve < lowest_ev_cost):
lowest_ev_cost = run_avg_lve
checkpoint_path = os.path.join(self.hps.lfads_save_dir,
self.hps.checkpoint_name + '_lve.ckpt')
self.lve_saver.save(session, checkpoint_path,
global_step=self.train_step,
latest_filename='checkpoint_lve')
# Plot and summarize.
values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set,
'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost,
'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost,
'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost,
'l2_weight':l2_weight, 'kl_weight':kl_weight,
'l2_cost':l2_cost}
self.summarize_all(datasets, values)
self.plot_single_example(datasets)
# Manage learning rate.
train_res = tr_total_cost
n_lr = hps.learning_rate_n_to_compare
if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]):
_ = session.run(self.learning_rate_decay_op)
lr = session.run(self.learning_rate)
print(" Decreasing learning rate to %f." % lr)
# Force the system to run n_lr times while at this lr.
train_costs.append(np.inf)
else:
train_costs.append(train_res)
if lr < lr_stop:
print("Stopping optimization based on learning rate criteria.")
break
def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None,
batch_size=None):
"""Evaluate the cost of the epoch.
Args:
data_dict: The dictionary of data (training and validation) used for
training and evaluation of the model, respectively.
Returns:
a 3 tuple of costs:
(epoch total cost, epoch reconstruction cost, epoch KL cost)
"""
ops_to_eval = [self.cost, self.recon_cost, self.kl_cost]
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind=kind,
keep_prob=1.0)
total_cost = total_recon_cost = total_kl_cost = 0.0
# normalizing by batch done in distributions.py
epoch_size = len(collected_op_values)
for op_values in collected_op_values:
total_cost += op_values[0]
total_recon_cost += op_values[1]
total_kl_cost += op_values[2]
epoch_total_cost = total_cost / epoch_size
epoch_recon_cost = total_recon_cost / epoch_size
epoch_kl_cost = total_kl_cost / epoch_size
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost
def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
do_eval_cost=False, do_average_batch=False):
"""Returns all the goodies for the entire model, per batch.
Args:
data_name: The name of the data dict, to select which in/out matrices
to use.
data_bxtxd: Numpy array training data with shape:
batch_size x # time steps x # dimensions
ext_input_bxtxi: Numpy array training external input with shape:
batch_size x # time steps x # external input dims
do_eval_cost (optional): If true, the IWAE (Importance Weighted
Autoencoder) log likeihood bound, instead of the VAE version.
do_average_batch (optional): average over the batch, useful for getting
good IWAE costs, and model outputs for a single data point.
Returns:
A dictionary with the outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx
posterior mean, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the rates.
"""
session = tf.get_default_session()
feed_dict = self.build_feed_dict(data_name, data_bxtxd,
ext_input_bxtxi, keep_prob=1.0)
# Non-temporal signals will be batch x dim.
# Temporal signals are list length T with elements batch x dim.
tf_vals = [self.gen_ics, self.gen_states, self.factors,
self.output_dist_params]
tf_vals.append(self.cost)
tf_vals.append(self.nll_bound_vae)
tf_vals.append(self.nll_bound_iwae)
tf_vals.append(self.train_step) # not train_op!
if self.hps.ic_dim > 0:
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
if self.hps.co_dim > 0:
tf_vals.append(self.controller_outputs)
tf_vals_flat, fidxs = flatten(tf_vals)
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
ff = 0
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
if self.hps.ic_dim > 0:
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if self.hps.co_dim > 0:
controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
# [0] are to take out the non-temporal items from lists
gen_ics = gen_ics[0]
costs = costs[0]
nll_bound_vaes = nll_bound_vaes[0]
nll_bound_iwaes = nll_bound_iwaes[0]
train_steps = train_steps[0]
# Convert to full tensors, not lists of tensors in time dim.
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
factors = list_t_bxn_to_tensor_bxtxn(factors)
out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
if self.hps.ic_dim > 0:
prior_g0_mean = prior_g0_mean[0]
prior_g0_logvar = prior_g0_logvar[0]
post_g0_mean = post_g0_mean[0]
post_g0_logvar = post_g0_logvar[0]
if self.hps.co_dim > 0:
controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)
if do_average_batch:
gen_ics = np.mean(gen_ics, axis=0)
gen_states = np.mean(gen_states, axis=0)
factors = np.mean(factors, axis=0)
out_dist_params = np.mean(out_dist_params, axis=0)
if self.hps.ic_dim > 0:
prior_g0_mean = np.mean(prior_g0_mean, axis=0)
prior_g0_logvar = np.mean(prior_g0_logvar, axis=0)
post_g0_mean = np.mean(post_g0_mean, axis=0)
post_g0_logvar = np.mean(post_g0_logvar, axis=0)
if self.hps.co_dim > 0:
controller_outputs = np.mean(controller_outputs, axis=0)
model_vals = {}
model_vals['gen_ics'] = gen_ics
model_vals['gen_states'] = gen_states
model_vals['factors'] = factors
model_vals['output_dist_params'] = out_dist_params
model_vals['costs'] = costs
model_vals['nll_bound_vaes'] = nll_bound_vaes
model_vals['nll_bound_iwaes'] = nll_bound_iwaes
model_vals['train_steps'] = train_steps
if self.hps.ic_dim > 0:
model_vals['prior_g0_mean'] = prior_g0_mean
model_vals['prior_g0_logvar'] = prior_g0_logvar
model_vals['post_g0_mean'] = post_g0_mean
model_vals['post_g0_logvar'] = post_g0_logvar
if self.hps.co_dim > 0:
model_vals['controller_outputs'] = controller_outputs
return model_vals
def eval_model_runs_avg_epoch(self, data_name, data_extxd,
ext_input_extxi=None):
"""Returns all the expected value for goodies for the entire model.
The expected value is taken over hidden (z) variables, namely the initial
conditions and the control inputs. The expected value is approximate, and
accomplished via sampling (batch_size) samples for every examples.
Args:
data_name: The name of the data dict, to select which in/out matrices
to use.
data_extxd: Numpy array training data with shape:
# examples x # time steps x # dimensions
ext_input_extxi (optional): Numpy array training external input with
shape: # examples x # time steps x # external input dims
Returns:
A dictionary with the averaged outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx
posterior mean, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the output
distribution parameters, e.g. (rates or mean and variances).
"""
hps = self.hps
batch_size = hps.batch_size
E, T, D = data_extxd.shape
E_to_process = hps.ps_nexamples_to_process
if E_to_process > E:
print("Setting number of posterior samples to process to : ", E)
E_to_process = E
if hps.ic_dim > 0:
prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
if hps.co_dim > 0:
controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
gen_ics = np.zeros([E_to_process, hps.gen_dim])
gen_states = np.zeros([E_to_process, T, hps.gen_dim])
factors = np.zeros([E_to_process, T, hps.factors_dim])
if hps.output_dist == 'poisson':
out_dist_params = np.zeros([E_to_process, T, D])
elif hps.output_dist == 'gaussian':
out_dist_params = np.zeros([E_to_process, T, D+D])
else:
assert False, "NIY"
costs = np.zeros(E_to_process)
nll_bound_vaes = np.zeros(E_to_process)
nll_bound_iwaes = np.zeros(E_to_process)
train_steps = np.zeros(E_to_process)
for es_idx in range(E_to_process):
print("Running %d of %d." % (es_idx+1, E_to_process))
example_idxs = es_idx * np.ones(batch_size, dtype=np.int32)
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
ext_input_extxi,
batch_size=batch_size,
example_idxs=example_idxs)
model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
ext_input_bxtxi,
do_eval_cost=True,
do_average_batch=True)
if self.hps.ic_dim > 0:
prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
post_g0_mean[es_idx,:] = model_values['post_g0_mean']
post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
gen_ics[es_idx,:] = model_values['gen_ics']
if self.hps.co_dim > 0:
controller_outputs[es_idx,:,:] = model_values['controller_outputs']
gen_states[es_idx,:,:] = model_values['gen_states']
factors[es_idx,:,:] = model_values['factors']
out_dist_params[es_idx,:,:] = model_values['output_dist_params']
costs[es_idx] = model_values['costs']
nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
train_steps[es_idx] = model_values['train_steps']
print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \
% (nll_bound_vaes[es_idx], nll_bound_iwaes[es_idx]))
model_runs = {}
if self.hps.ic_dim > 0:
model_runs['prior_g0_mean'] = prior_g0_mean
model_runs['prior_g0_logvar'] = prior_g0_logvar
model_runs['post_g0_mean'] = post_g0_mean
model_runs['post_g0_logvar'] = post_g0_logvar
model_runs['gen_ics'] = gen_ics
if self.hps.co_dim > 0:
model_runs['controller_outputs'] = controller_outputs
model_runs['gen_states'] = gen_states
model_runs['factors'] = factors
model_runs['output_dist_params'] = out_dist_params
model_runs['costs'] = costs
model_runs['nll_bound_vaes'] = nll_bound_vaes
model_runs['nll_bound_iwaes'] = nll_bound_iwaes
model_runs['train_steps'] = train_steps
return model_runs
def write_model_runs(self, datasets, output_fname=None):
"""Run the model on the data in data_dict, and save the computed values.
LFADS generates a number of outputs for each examples, and these are all
saved. They are:
The mean and variance of the prior of g0.
The mean and variance of approximate posterior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
datasets: a dictionary of named data_dictionaries, see top of lfads.py
output_fname: a file name stem for the output files.
"""
hps = self.hps
kind = hps.kind
for data_name, data_dict in datasets.items():
data_tuple = [('train', data_dict['train_data'],
data_dict['train_ext_input']),
('valid', data_dict['valid_data'],
data_dict['valid_ext_input'])]
for data_kind, data_extxd, ext_input_extxi in data_tuple:
if not output_fname:
fname = "model_runs_" + data_name + '_' + data_kind + '_' + kind
else:
fname = output_fname + data_name + '_' + data_kind + '_' + kind
print("Writing data for %s data and kind %s." % (data_name, data_kind))
model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
ext_input_extxi)
full_fname = os.path.join(hps.lfads_save_dir, fname)
write_data(full_fname, model_runs, compression='gzip')
print("Done.")
def write_model_samples(self, dataset_name, output_fname=None):
"""Use the prior distribution to generate batch_size number of samples
from the model.
LFADS generates a number of outputs for each sample, and these are all
saved. They are:
The mean and variance of the prior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
dataset_name: The name of the dataset to grab the factors -> rates
alignment matrices from.
output_fname: The name of the file in which to save the generated
samples.
"""
hps = self.hps
batch_size = hps.batch_size
print("Generating %d samples" % (batch_size))
tf_vals = [self.factors, self.gen_states, self.gen_ics,
self.cost, self.output_dist_params]
if hps.ic_dim > 0:
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar]
if hps.co_dim > 0:
tf_vals += [self.prior_zs_ar_con.samples_t]
tf_vals_flat, fidxs = flatten(tf_vals)
session = tf.get_default_session()
feed_dict = {}
feed_dict[self.dataName] = dataset_name
feed_dict[self.keep_prob] = 1.0
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
ff = 0
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
output_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if hps.ic_dim > 0:
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if hps.co_dim > 0:
prior_zs_ar_con = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
# [0] are to take out the non-temporal items from lists
gen_ics = gen_ics[0]
costs = costs[0]
# Convert to full tensors, not lists of tensors in time dim.
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
factors = list_t_bxn_to_tensor_bxtxn(factors)
output_dist_params = list_t_bxn_to_tensor_bxtxn(output_dist_params)
if hps.ic_dim > 0:
prior_g0_mean = prior_g0_mean[0]
prior_g0_logvar = prior_g0_logvar[0]
if hps.co_dim > 0:
prior_zs_ar_con = list_t_bxn_to_tensor_bxtxn(prior_zs_ar_con)
model_vals = {}
model_vals['gen_ics'] = gen_ics
model_vals['gen_states'] = gen_states
model_vals['factors'] = factors
model_vals['output_dist_params'] = output_dist_params
model_vals['costs'] = costs.reshape(1)
if hps.ic_dim > 0:
model_vals['prior_g0_mean'] = prior_g0_mean
model_vals['prior_g0_logvar'] = prior_g0_logvar
if hps.co_dim > 0:
model_vals['prior_zs_ar_con'] = prior_zs_ar_con
full_fname = os.path.join(hps.lfads_save_dir, output_fname)
write_data(full_fname, model_vals, compression='gzip')
print("Done.")
@staticmethod
def eval_model_parameters(use_nested=True, include_strs=None):
"""Evaluate and return all of the TF variables in the model.
Args:
use_nested (optional): For returning values, use a nested dictoinary, based
on variable scoping, or return all variables in a flat dictionary.
include_strs (optional): A list of strings to use as a filter, to reduce the
number of variables returned. A variable name must contain at least one
string in include_strs as a sub-string in order to be returned.
Returns:
The parameters of the model. This can be in a flat
dictionary, or a nested dictionary, where the nesting is by variable
scope.
"""
all_tf_vars = tf.global_variables()
session = tf.get_default_session()
all_tf_vars_eval = session.run(all_tf_vars)
vars_dict = {}
strs = ["LFADS"]
if include_strs:
strs += include_strs
for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
if any(s in include_strs for s in var.name):
if not isinstance(var_eval, np.ndarray): # for H5PY
print(var.name, """ is not numpy array, saving as numpy array
with value: """, var_eval, type(var_eval))
e = np.array(var_eval)
print(e, type(e))
else:
e = var_eval
vars_dict[var.name] = e
if not use_nested:
return vars_dict
var_names = vars_dict.keys()
nested_vars_dict = {}
current_dict = nested_vars_dict
for v, var_name in enumerate(var_names):
var_split_name_list = var_name.split('/')
split_name_list_len = len(var_split_name_list)
current_dict = nested_vars_dict
for p, part in enumerate(var_split_name_list):
if p < split_name_list_len - 1:
if part in current_dict:
current_dict = current_dict[part]
else:
current_dict[part] = {}
current_dict = current_dict[part]
else:
current_dict[part] = vars_dict[var_name]
return nested_vars_dict
@staticmethod
def spikify_rates(rates_bxtxd):
"""Randomly spikify underlying rates according a Poisson distribution
Args:
rates_bxtxd: a numpy tensor with shape:
Returns:
A numpy array with the same shape as rates_bxtxd, but with the event
counts.
"""
B,T,N = rates_bxtxd.shape
assert all([B > 0, N > 0]), "problems"
# Because the rates are changing, there is nesting
spikes_bxtxd = np.zeros([B,T,N], dtype=np.int32)
for b in range(B):
for t in range(T):
for n in range(N):
rate = rates_bxtxd[b,t,n]
count = np.random.poisson(rate)
spikes_bxtxd[b,t,n] = count
return spikes_bxtxd
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
def _plot_item(W, name, full_name, nspaces):
plt.figure()
if W.shape == ():
print(name, ": ", W)
elif W.shape[0] == 1:
plt.stem(W.T)
plt.title(full_name)
elif W.shape[1] == 1:
plt.stem(W)
plt.title(full_name)
else:
plt.imshow(np.abs(W), interpolation='nearest', cmap='jet');
plt.colorbar()
plt.title(full_name)
def all_plot(d, full_name="", exclude="", nspaces=0):
"""Recursively plot all the LFADS model parameters in the nested
dictionary."""
for k, v in d.iteritems():
this_name = full_name+"/"+k
if isinstance(v, dict):
all_plot(v, full_name=this_name, exclude=exclude, nspaces=nspaces+4)
else:
if exclude == "" or exclude not in this_name:
_plot_item(v, name=k, full_name=full_name+"/"+k, nspaces=nspaces+4)
def plot_priors():
g0s_prior_mean_bxn = train_modelvals['prior_g0_mean']
g0s_prior_var_bxn = train_modelvals['prior_g0_var']
g0s_post_mean_bxn = train_modelvals['posterior_g0_mean']
g0s_post_var_bxn = train_modelvals['posterior_g0_var']
plt.figure(figsize=(10,4), tight_layout=True);
plt.subplot(1,2,1)
plt.hist(g0s_post_mean_bxn.flatten(), bins=20, color='b');
plt.hist(g0s_prior_mean_bxn.flatten(), bins=20, color='g');
plt.title('Histogram of Prior/Posterior Mean Values')
plt.subplot(1,2,2)
plt.hist((g0s_post_var_bxn.flatten()), bins=20, color='b');
plt.hist((g0s_prior_var_bxn.flatten()), bins=20, color='g');
plt.title('Histogram of Prior/Posterior Log Variance Values')
plt.figure(figsize=(10,10), tight_layout=True)
plt.subplot(2,2,1)
plt.imshow(g0s_prior_mean_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Prior g0 means')
plt.subplot(2,2,2)
plt.imshow(g0s_post_mean_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Posterior g0 means');
plt.subplot(2,2,3)
plt.imshow(g0s_prior_var_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Prior g0 variance Values')
plt.subplot(2,2,4)
plt.imshow(g0s_post_var_bxn.T, interpolation='nearest', cmap='jet')
plt.colorbar(fraction=0.025, pad=0.04)
plt.title('Posterior g0 variance Values')
plt.figure(figsize=(10,5))
plt.stem(np.sort(np.log(g0s_post_mean_bxn.std(axis=0))));
plt.title('Log standard deviation of h0 means');
def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
color='r', title=None):
if bidx is None:
vals_txn = np.mean(vals_bxtxn, axis=0)
else:
vals_txn = vals_bxtxn[bidx,:,:]
T, N = vals_txn.shape
if n_to_plot > N:
n_to_plot = N
plt.plot(vals_txn[:,0:n_to_plot] + scale*np.array(range(n_to_plot)),
color=color, lw=1.0)
plt.axis('tight')
if title:
plt.title(title)
def plot_lfads_timeseries(data_bxtxn, model_vals, ext_input_bxtxi=None,
truth_bxtxn=None, bidx=None, output_dist="poisson",
conversion_factor=1.0, subplot_cidx=0,
col_title=None):
n_to_plot = 10
scale = 1.0
nrows = 7
plt.subplot(nrows,2,1+subplot_cidx)
if output_dist == 'poisson':
rates = means = conversion_factor * model_vals['output_dist_params']
plot_time_series(rates, bidx, n_to_plot=n_to_plot, scale=scale,
title=col_title + " rates (LFADS - red, Truth - black)")
elif output_dist == 'gaussian':
means_vars = model_vals['output_dist_params']
means, vars = np.split(means_vars,2, axis=2) # bxtxn
stds = np.sqrt(vars)
plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
title=col_title + " means (LFADS - red, Truth - black)")
plot_time_series(means+stds, bidx, n_to_plot=n_to_plot, scale=scale,
color='c')
plot_time_series(means-stds, bidx, n_to_plot=n_to_plot, scale=scale,
color='c')
else:
assert 'NIY'
if truth_bxtxn is not None:
plot_time_series(truth_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
scale=scale)
input_title = ""
if "controller_outputs" in model_vals.keys():
input_title += " Controller Output"
plt.subplot(nrows,2,3+subplot_cidx)
u_t = model_vals['controller_outputs'][0:-1]
plot_time_series(u_t, bidx, n_to_plot=n_to_plot, color='c', scale=1.0,
title=col_title + input_title)
if ext_input_bxtxi is not None:
input_title += " External Input"
plot_time_series(ext_input_bxtxi, n_to_plot=n_to_plot, color='b',
scale=scale, title=col_title + input_title)
plt.subplot(nrows,2,5+subplot_cidx)
plot_time_series(means, bidx,
n_to_plot=n_to_plot, scale=1.0,
title=col_title + " Spikes (LFADS - red, Spikes - black)")
plot_time_series(data_bxtxn, bidx, n_to_plot=n_to_plot, color='k', scale=1.0)
plt.subplot(nrows,2,7+subplot_cidx)
plot_time_series(model_vals['factors'], bidx, n_to_plot=n_to_plot, color='b',
scale=2.0, title=col_title + " Factors")
plt.subplot(nrows,2,9+subplot_cidx)
plot_time_series(model_vals['gen_states'], bidx, n_to_plot=n_to_plot,
color='g', scale=1.0, title=col_title + " Generator State")
if bidx is not None:
data_nxt = data_bxtxn[bidx,:,:].T
params_nxt = model_vals['output_dist_params'][bidx,:,:].T
else:
data_nxt = np.mean(data_bxtxn, axis=0).T
params_nxt = np.mean(model_vals['output_dist_params'], axis=0).T
if output_dist == 'poisson':
means_nxt = params_nxt
elif output_dist == 'gaussian': # (means+vars) x time
means_nxt = np.vsplit(params_nxt,2)[0] # get means
else:
assert "NIY"
plt.subplot(nrows,2,11+subplot_cidx)
plt.imshow(data_nxt, aspect='auto', interpolation='nearest')
plt.title(col_title + ' Data')
plt.subplot(nrows,2,13+subplot_cidx)
plt.imshow(means_nxt, aspect='auto', interpolation='nearest')
plt.title(col_title + ' Means')
def plot_lfads(train_bxtxd, train_model_vals,
train_ext_input_bxtxi=None, train_truth_bxtxd=None,
valid_bxtxd=None, valid_model_vals=None,
valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
bidx=None, cf=1.0, output_dist='poisson'):
# Plotting
f = plt.figure(figsize=(18,20), tight_layout=True)
plot_lfads_timeseries(train_bxtxd, train_model_vals,
train_ext_input_bxtxi,
truth_bxtxn=train_truth_bxtxd,
conversion_factor=cf, bidx=bidx,
output_dist=output_dist, col_title='Train')
plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
valid_ext_input_bxtxi,
truth_bxtxn=valid_truth_bxtxd,
conversion_factor=cf, bidx=bidx,
output_dist=output_dist,
subplot_cidx=1, col_title='Valid')
# Convert from figure to an numpy array width x height x 3 (last for RGB)
f.canvas.draw()
data = np.fromstring(f.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data_wxhx3 = data.reshape(f.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data_wxhx3
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from lfads import LFADS
import numpy as np
import os
import tensorflow as tf
import re
import utils
# Lots of hyperparameters, but most are pretty insensitive. The
# explanation of these hyperparameters is found below, in the flags
# session.
CHECKPOINT_PB_LOAD_NAME = "checkpoint"
CHECKPOINT_NAME = "lfads_vae"
CSV_LOG = "fitlog"
OUTPUT_FILENAME_STEM = ""
DEVICE = "gpu:0" # "cpu:0", or other gpus, e.g. "gpu:1"
MAX_CKPT_TO_KEEP = 5
MAX_CKPT_TO_KEEP_LVE = 5
PS_NEXAMPLES_TO_PROCESS = 1e8 # if larger than number of examples, process all
EXT_INPUT_DIM = 0
IC_DIM = 64
FACTORS_DIM = 50
IC_ENC_DIM = 128
GEN_DIM = 200
GEN_CELL_INPUT_WEIGHT_SCALE = 1.0
GEN_CELL_REC_WEIGHT_SCALE = 1.0
CELL_WEIGHT_SCALE = 1.0
BATCH_SIZE = 128
LEARNING_RATE_INIT = 0.01
LEARNING_RATE_DECAY_FACTOR = 0.95
LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE = 6
INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False
DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors"
# Calibrated just above the average value for the rnn synthetic data.
MAX_GRAD_NORM = 200.0
CELL_CLIP_VALUE = 5.0
KEEP_PROB = 0.95
TEMPORAL_SPIKE_JITTER_WIDTH = 0
OUTPUT_DISTRIBUTION = 'poisson' # 'poisson' or 'gaussian'
NUM_STEPS_FOR_GEN_IC = np.inf # set to num_steps if greater than num_steps
DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
LFADS_SAVE_DIR = "/tmp/lfads_chaotic_rnn_inputs_g1p5/"
CO_DIM = 1
DO_CAUSAL_CONTROLLER = False
DO_FEED_FACTORS_TO_CONTROLLER = True
CONTROLLER_INPUT_LAG = 1
PRIOR_AR_AUTOCORRELATION = 10.0
PRIOR_AR_PROCESS_VAR = 0.1
DO_TRAIN_PRIOR_AR_ATAU = True
DO_TRAIN_PRIOR_AR_NVAR = True
CI_ENC_DIM = 128
CON_DIM = 128
CO_PRIOR_VAR_SCALE = 0.1
KL_INCREASE_STEPS = 2000
L2_INCREASE_STEPS = 2000
L2_GEN_SCALE = 2000.0
L2_CON_SCALE = 0.0
# scale of regularizer on time correlation of inferred inputs
CO_MEAN_CORR_SCALE = 0.0
KL_IC_WEIGHT = 1.0
KL_CO_WEIGHT = 1.0
KL_START_STEP = 0
L2_START_STEP = 0
IC_PRIOR_VAR_MIN = 0.1
IC_PRIOR_VAR_SCALE = 0.1
IC_PRIOR_VAR_MAX = 0.1
IC_POST_VAR_MIN = 0.0001 # protection from KL blowing up
flags = tf.app.flags
flags.DEFINE_string("kind", "train",
"Type of model to build {train, \
posterior_sample_and_average, \
prior_sample, write_model_params")
flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION,
"Type of output distribution, 'poisson' or 'gaussian'")
flags.DEFINE_boolean("allow_gpu_growth", False,
"If true, only allocate amount of memory needed for \
Session. Otherwise, use full GPU memory.")
# DATA
flags.DEFINE_string("data_dir", DATA_DIR, "Data for training")
flags.DEFINE_string("data_filename_stem", DATA_FILENAME_STEM,
"Filename stem for data dictionaries.")
flags.DEFINE_string("lfads_save_dir", LFADS_SAVE_DIR, "model save dir")
flags.DEFINE_string("checkpoint_pb_load_name", CHECKPOINT_PB_LOAD_NAME,
"Name of checkpoint files, use 'checkpoint_lve' for best \
error")
flags.DEFINE_string("checkpoint_name", CHECKPOINT_NAME,
"Name of checkpoint files (.ckpt appended)")
flags.DEFINE_string("output_filename_stem", OUTPUT_FILENAME_STEM,
"Name of output file (postfix will be added)")
flags.DEFINE_string("device", DEVICE,
"Which device to use (default: \"gpu:0\", can also be \
\"cpu:0\", \"gpu:1\", etc)")
flags.DEFINE_string("csv_log", CSV_LOG,
"Name of file to keep running log of fit likelihoods, \
etc (.csv appended)")
flags.DEFINE_integer("max_ckpt_to_keep", MAX_CKPT_TO_KEEP,
"Max # of checkpoints to keep (rolling)")
flags.DEFINE_integer("ps_nexamples_to_process", PS_NEXAMPLES_TO_PROCESS,
"Number of examples to process for posterior sample and \
average (not number of samples to average over).")
flags.DEFINE_integer("max_ckpt_to_keep_lve", MAX_CKPT_TO_KEEP_LVE,
"Max # of checkpoints to keep for lowest validation error \
models (rolling)")
flags.DEFINE_integer("ext_input_dim", EXT_INPUT_DIM, "Dimension of external \
inputs")
flags.DEFINE_integer("num_steps_for_gen_ic", NUM_STEPS_FOR_GEN_IC,
"Number of steps to train the generator initial conditon.")
# If there are observed inputs, there are two ways to add that observed
# input to the model. The first is by treating as something to be
# inferred, and thus encoding the observed input via the encoders, and then
# input to the generator via the "inferred inputs" channel. Second, one
# can input the input directly into the generator. This has the downside
# of making the generation process strictly dependent on knowing the
# observed input for any generated trial.
flags.DEFINE_boolean("inject_ext_input_to_gen",
INJECT_EXT_INPUT_TO_GEN,
"Should observed inputs be input to model via encoders, \
or injected directly into generator?")
# CELL
# The combined recurrent and input weights of the encoder and
# controller cells are by default set to scale at ws/sqrt(#inputs),
# with ws=1.0. You can change this scaling with this parameter.
flags.DEFINE_float("cell_weight_scale", CELL_WEIGHT_SCALE,
"Input scaling for input weights in generator.")
# GENERATION
# Note that the dimension of the initial conditions is separated from the
# dimensions of the generator initial conditions (and a linear matrix will
# adapt the shapes if necessary). This is just another way to control
# complexity. In all likelihood, setting the ic dims to the size of the
# generator hidden state is just fine.
flags.DEFINE_integer("ic_dim", IC_DIM, "Dimension of h0")
# Setting the dimensions of the factors to something smaller than the data
# dimension is a way to get a reduced dimensionality representation of your
# data.
flags.DEFINE_integer("factors_dim", FACTORS_DIM,
"Number of factors from generator")
flags.DEFINE_integer("ic_enc_dim", IC_ENC_DIM,
"Cell hidden size, encoder of h0")
# Controlling the size of the generator is one way to control complexity of
# the dynamics (there is also l2, which will squeeze out unnecessary
# dynamics also). The modern deep learning approach is to make these cells
# as large as tolerable (from a waiting perspective), and then regularize
# them to death with drop out or whatever. I don't know if this is correct
# for the LFADS application or not.
flags.DEFINE_integer("gen_dim", GEN_DIM,
"Cell hidden size, generator.")
# The weights of the generator cell by default set to scale at
# ws/sqrt(#inputs), with ws=1.0. You can change ws for
# the input weights or the recurrent weights with these hyperparameters.
flags.DEFINE_float("gen_cell_input_weight_scale", GEN_CELL_INPUT_WEIGHT_SCALE,
"Input scaling for input weights in generator.")
flags.DEFINE_float("gen_cell_rec_weight_scale", GEN_CELL_REC_WEIGHT_SCALE,
"Input scaling for rec weights in generator.")
# KL DISTRIBUTIONS
# If you don't know what you are donig here, please leave alone, the
# defaults should be fine for most cases, irregardless of other parameters.
#
# If you don't want the prior variance to be learned, set the
# following values to the same thing: ic_prior_var_min,
# ic_prior_var_scale, ic_prior_var_max. The prior mean will be
# learned regardless.
flags.DEFINE_float("ic_prior_var_min", IC_PRIOR_VAR_MIN,
"Minimum variance in posterior h0 codes.")
flags.DEFINE_float("ic_prior_var_scale", IC_PRIOR_VAR_SCALE,
"Variance of ic prior distribution")
flags.DEFINE_float("ic_prior_var_max", IC_PRIOR_VAR_MAX,
"Maximum variance of IC prior distribution.")
# If you really want to limit the information from encoder to decoder,
# Increase ic_post_var_min above 0.0.
flags.DEFINE_float("ic_post_var_min", IC_POST_VAR_MIN,
"Minimum variance of IC posterior distribution.")
flags.DEFINE_float("co_prior_var_scale", CO_PRIOR_VAR_SCALE,
"Variance of control input prior distribution.")
flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION,
"Initial autocorrelation of AR(1) priors.")
flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR,
"Initial noise variance for AR(1) priors.")
flags.DEFINE_boolean("do_train_prior_ar_atau", DO_TRAIN_PRIOR_AR_ATAU,
"Is the value for atau an init, or the constant value?")
flags.DEFINE_boolean("do_train_prior_ar_nvar", DO_TRAIN_PRIOR_AR_NVAR,
"Is the value for noise variance an init, or the constant \
value?")
# CONTROLLER
# This parameter critically controls whether or not there is a controller
# (along with controller encoders placed into the LFADS graph. If CO_DIM >
# 1, that means there is a 1 dimensional controller outputs, if equal to 0,
# then no controller.
flags.DEFINE_integer("co_dim", CO_DIM,
"Number of control net outputs (>0 builds that graph).")
# The controller will be more powerful if it can see the encoding of the entire
# trial. However, this allows the controller to create inferred inputs that are
# acausal with respect to the actual data generation process. E.g. the data
# generator could have an input at time t, but the controller, after seeing the
# entirety of the trial could infer that the input is coming a little before
# time t, because there are no restrictions on the data the controller sees.
# One can force the controller to be causal (with respect to perturbations in
# the data generator) so that it only sees forward encodings of the data at time
# t that originate at times before or at time t. One can also control the data
# the controller sees by using an input lag (forward encoding at time [t-tlag]
# for controller input at time t. The same can be done in the reverse direction
# (controller input at time t from reverse encoding at time [t+tlag], in the
# case of an acausal controller). Setting this lag > 0 (even lag=1) can be a
# powerful way of avoiding very spiky decodes. Finally, one can manually control
# whether the factors at time t-1 are fed to the controller at time t.
#
# If you don't care about any of this, and just want to smooth your data, set
# do_causal_controller = False
# do_feed_factors_to_controller = True
# causal_input_lag = 0
flags.DEFINE_boolean("do_causal_controller",
DO_CAUSAL_CONTROLLER,
"Restrict the controller create only causal inferred \
inputs?")
# Strictly speaking, feeding either the factors or the rates to the controller
# violates causality, since the g0 gets to see all the data. This may or may not
# be only a theoretical concern.
flags.DEFINE_boolean("do_feed_factors_to_controller",
DO_FEED_FACTORS_TO_CONTROLLER,
"Should factors[t-1] be input to controller at time t?")
flags.DEFINE_string("feedback_factors_or_rates", FEEDBACK_FACTORS_OR_RATES,
"Feedback the factors or the rates to the controller? \
Acceptable values: 'factors' or 'rates'.")
flags.DEFINE_integer("controller_input_lag", CONTROLLER_INPUT_LAG,
"Time lag on the encoding to controller t-lag for \
forward, t+lag for reverse.")
flags.DEFINE_integer("ci_enc_dim", CI_ENC_DIM,
"Cell hidden size, encoder of control inputs")
flags.DEFINE_integer("con_dim", CON_DIM,
"Cell hidden size, controller")
# OPTIMIZATION
flags.DEFINE_integer("batch_size", BATCH_SIZE,
"Batch size to use during training.")
flags.DEFINE_float("learning_rate_init", LEARNING_RATE_INIT,
"Learning rate initial value")
flags.DEFINE_float("learning_rate_decay_factor", LEARNING_RATE_DECAY_FACTOR,
"Learning rate decay, decay by this fraction every so \
often.")
flags.DEFINE_float("learning_rate_stop", LEARNING_RATE_STOP,
"The lr is adaptively reduced, stop training at this value.")
# Rather put the learning rate on an exponentially decreasiong schedule,
# the current algorithm pays attention to the learning rate, and if it
# isn't regularly decreasing, it will decrease the learning rate. So far,
# it works fine, though it is not perfect.
flags.DEFINE_integer("learning_rate_n_to_compare", LEARNING_RATE_N_TO_COMPARE,
"Number of previous costs current cost has to be worse \
than, to lower learning rate.")
# This sets a value, above which, the gradients will be clipped. This hp
# is extremely useful to avoid an infrequent, but highly pathological
# problem whereby the gradient is so large that it destroys the
# optimziation by setting parameters too large, leading to a vicious cycle
# that ends in NaNs. If it's too large, it's useless, if it's too small,
# it essentially becomes the learning rate. It's pretty insensitive, though.
flags.DEFINE_float("max_grad_norm", MAX_GRAD_NORM,
"Max norm of gradient before clipping.")
# If your optimizations start "NaN-ing out", reduce this value so that
# the values of the network don't grow out of control. Typically, once
# this parameter is set to a reasonable value, one stops having numerical
# problems.
flags.DEFINE_float("cell_clip_value", CELL_CLIP_VALUE,
"Max value recurrent cell can take before being clipped.")
# This flag is used for an experiment where one sees if training a model with
# many days data can be used to learn the dynamics from a held-out days data.
# If you don't care about that particular experiment, this flag should always be
# false.
flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
"Train only the input (readin) and output (readout) \
affine functions.")
flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
"Reset the learning rate to initial value.")
# OVERFITTING
# Dropout is done on the input data, on controller inputs (from
# encoder), on outputs from generator to factors.
flags.DEFINE_float("keep_prob", KEEP_PROB, "Dropout keep probability.")
# It appears that the system will happily fit spikes (blessing or
# curse, depending). You may not want this. Jittering the spikes a
# bit will help (-/+ bin size, as specified here).
flags.DEFINE_integer("temporal_spike_jitter_width",
TEMPORAL_SPIKE_JITTER_WIDTH,
"Shuffle spikes around this window.")
# General note about helping ascribe controller inputs vs dynamics:
#
# If controller is heavily penalized, then it won't have any output.
# If dynamics are heavily penalized, then generator won't make
# dynamics. Note this l2 penalty is only on the recurrent portion of
# the RNNs, as dropout is also available, penalizing the feed-forward
# connections.
flags.DEFINE_float("l2_gen_scale", L2_GEN_SCALE,
"L2 regularization cost for the generator only.")
flags.DEFINE_float("l2_con_scale", L2_CON_SCALE,
"L2 regularization cost for the controller only.")
flags.DEFINE_float("co_mean_corr_scale", CO_MEAN_CORR_SCALE,
"Cost of correlation (thru time)in the means of \
controller output.")
# UNDERFITTING
# If the primary task of LFADS is "filtering" of data and not
# generation, then it is possible that the KL penalty is too strong.
# Empirically, we have found this to be the case. So we add a
# hyperparameter in front of the the two KL terms (one for the initial
# conditions to the generator, the other for the controller outputs).
# You should always think of the the default values as 1.0, and that
# leads to a standard VAE formulation whereby the numbers that are
# optimized are a lower-bound on the log-likelihood of the data. When
# these 2 HPs deviate from 1.0, one cannot make any statement about
# what those LL lower bounds mean anymore, and they cannot be compared
# (AFAIK).
flags.DEFINE_float("kl_ic_weight", KL_IC_WEIGHT,
"Strength of KL weight on initial conditions KL penatly.")
flags.DEFINE_float("kl_co_weight", KL_CO_WEIGHT,
"Strength of KL weight on controller output KL penalty.")
# Sometimes the task can be sufficiently hard to learn that the
# optimizer takes the 'easy route', and simply minimizes the KL
# divergence, setting it to near zero, and the optimization gets
# stuck. These two parameters will help avoid that by by getting the
# optimization to 'latch' on to the main optimization, and only
# turning in the regularizers later.
flags.DEFINE_integer("kl_start_step", KL_START_STEP,
"Start increasing weight after this many steps.")
# training passes, not epochs, increase by 0.5 every kl_increase_steps
flags.DEFINE_integer("kl_increase_steps", KL_INCREASE_STEPS,
"Increase weight of kl cost to avoid local minimum.")
# Same story for l2 regularizer. One wants a simple generator, for scientific
# reasons, but not at the expense of hosing the optimization.
flags.DEFINE_integer("l2_start_step", L2_START_STEP,
"Start increasing l2 weight after this many steps.")
flags.DEFINE_integer("l2_increase_steps", L2_INCREASE_STEPS,
"Increase weight of l2 cost to avoid local minimum.")
FLAGS = flags.FLAGS
def build_model(hps, kind="train", datasets=None):
"""Builds a model from either random initialization, or saved parameters.
Args:
hps: The hyper parameters for the model.
kind: (optional) The kind of model to build. Training vs inference require
different graphs.
datasets: The datasets structure (see top of lfads.py).
Returns:
an LFADS model.
"""
build_kind = kind
if build_kind == "write_model_params":
build_kind = "train"
with tf.variable_scope("LFADS", reuse=None):
model = LFADS(hps, kind=build_kind, datasets=datasets)
if not os.path.exists(hps.lfads_save_dir):
print("Save directory %s does not exist, creating it." % hps.lfads_save_dir)
os.makedirs(hps.lfads_save_dir)
cp_pb_ln = hps.checkpoint_pb_load_name
cp_pb_ln = 'checkpoint' if cp_pb_ln == "" else cp_pb_ln
if cp_pb_ln == 'checkpoint':
print("Loading latest training checkpoint in: ", hps.lfads_save_dir)
saver = model.seso_saver
elif cp_pb_ln == 'checkpoint_lve':
print("Loading lowest validation checkpoint in: ", hps.lfads_save_dir)
saver = model.lve_saver
else:
print("Loading checkpoint: ", cp_pb_ln, ", in: ", hps.lfads_save_dir)
saver = model.seso_saver
ckpt = tf.train.get_checkpoint_state(hps.lfads_save_dir,
latest_filename=cp_pb_ln)
session = tf.get_default_session()
print("ckpt: ", ckpt)
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
saver.restore(session, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
if kind in ["posterior_sample_and_average", "prior_sample",
"write_model_params"]:
print("Possible error!!! You are running ", kind, " on a newly \
initialized model!")
print("Are you sure you sure ", ckpt.model_checkpoint_path, " exists?")
tf.global_variables_initializer().run()
if ckpt:
train_step_str = re.search('-[0-9]+$', ckpt.model_checkpoint_path).group()
else:
train_step_str = '-0'
fname = 'hyperparameters' + train_step_str + '.txt'
hp_fname = os.path.join(hps.lfads_save_dir, fname)
hps_for_saving = jsonify_dict(hps)
utils.write_data(hp_fname, hps_for_saving, use_json=True)
return model
def jsonify_dict(d):
"""Turns python booleans into strings so hps dict can be written in json.
Creates a shallow-copied dictionary first, then accomplishes string
conversion.
Args:
d: hyperparameter dictionary
Returns: hyperparameter dictionary with bool's as strings
"""
d2 = d.copy() # shallow copy is fine by assumption of d being shallow
def jsonify_bool(boolean_value):
if boolean_value:
return "true"
else:
return "false"
for key in d2.keys():
if isinstance(d2[key], bool):
d2[key] = jsonify_bool(d2[key])
return d2
def build_hyperparameter_dict(flags):
"""Simple script for saving hyper parameters. Under the hood the
flags structure isn't a dictionary, so it has to be simplified since we
want to be able to view file as text.
Args:
flags: From tf.app.flags
Returns:
dictionary of hyper parameters (ignoring other flag types).
"""
d = {}
# Data
d['output_dist'] = flags.output_dist
d['data_dir'] = flags.data_dir
d['lfads_save_dir'] = flags.lfads_save_dir
d['checkpoint_pb_load_name'] = flags.checkpoint_pb_load_name
d['checkpoint_name'] = flags.checkpoint_name
d['output_filename_stem'] = flags.output_filename_stem
d['max_ckpt_to_keep'] = flags.max_ckpt_to_keep
d['max_ckpt_to_keep_lve'] = flags.max_ckpt_to_keep_lve
d['ps_nexamples_to_process'] = flags.ps_nexamples_to_process
d['ext_input_dim'] = flags.ext_input_dim
d['data_filename_stem'] = flags.data_filename_stem
d['device'] = flags.device
d['csv_log'] = flags.csv_log
d['num_steps_for_gen_ic'] = flags.num_steps_for_gen_ic
d['inject_ext_input_to_gen'] = flags.inject_ext_input_to_gen
# Cell
d['cell_weight_scale'] = flags.cell_weight_scale
# Generation
d['ic_dim'] = flags.ic_dim
d['factors_dim'] = flags.factors_dim
d['ic_enc_dim'] = flags.ic_enc_dim
d['gen_dim'] = flags.gen_dim
d['gen_cell_input_weight_scale'] = flags.gen_cell_input_weight_scale
d['gen_cell_rec_weight_scale'] = flags.gen_cell_rec_weight_scale
# KL distributions
d['ic_prior_var_min'] = flags.ic_prior_var_min
d['ic_prior_var_scale'] = flags.ic_prior_var_scale
d['ic_prior_var_max'] = flags.ic_prior_var_max
d['ic_post_var_min'] = flags.ic_post_var_min
d['co_prior_var_scale'] = flags.co_prior_var_scale
d['prior_ar_atau'] = flags.prior_ar_atau
d['prior_ar_nvar'] = flags.prior_ar_nvar
d['do_train_prior_ar_atau'] = flags.do_train_prior_ar_atau
d['do_train_prior_ar_nvar'] = flags.do_train_prior_ar_nvar
# Controller
d['do_causal_controller'] = flags.do_causal_controller
d['controller_input_lag'] = flags.controller_input_lag
d['do_feed_factors_to_controller'] = flags.do_feed_factors_to_controller
d['feedback_factors_or_rates'] = flags.feedback_factors_or_rates
d['co_dim'] = flags.co_dim
d['ci_enc_dim'] = flags.ci_enc_dim
d['con_dim'] = flags.con_dim
d['co_mean_corr_scale'] = flags.co_mean_corr_scale
# Optimization
d['batch_size'] = flags.batch_size
d['learning_rate_init'] = flags.learning_rate_init
d['learning_rate_decay_factor'] = flags.learning_rate_decay_factor
d['learning_rate_stop'] = flags.learning_rate_stop
d['learning_rate_n_to_compare'] = flags.learning_rate_n_to_compare
d['max_grad_norm'] = flags.max_grad_norm
d['cell_clip_value'] = flags.cell_clip_value
d['do_train_io_only'] = flags.do_train_io_only
d['do_reset_learning_rate'] = flags.do_reset_learning_rate
# Overfitting
d['keep_prob'] = flags.keep_prob
d['temporal_spike_jitter_width'] = flags.temporal_spike_jitter_width
d['l2_gen_scale'] = flags.l2_gen_scale
d['l2_con_scale'] = flags.l2_con_scale
# Underfitting
d['kl_ic_weight'] = flags.kl_ic_weight
d['kl_co_weight'] = flags.kl_co_weight
d['kl_start_step'] = flags.kl_start_step
d['kl_increase_steps'] = flags.kl_increase_steps
d['l2_start_step'] = flags.l2_start_step
d['l2_increase_steps'] = flags.l2_increase_steps
return d
class hps_dict_to_obj(dict):
"""Helper class allowing us to access hps dictionary more easily."""
def __getattr__(self, key):
if key in self:
return self[key]
else:
assert False, ("%s does not exist." % key)
def __setattr__(self, key, value):
self[key] = value
def train(hps, datasets):
"""Train the LFADS model.
Args:
hps: The dictionary of hyperparameters.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
"""
model = build_model(hps, kind="train", datasets=datasets)
if hps.do_reset_learning_rate:
sess = tf.get_default_session()
sess.run(model.learning_rate.initializer)
model.train_model(datasets)
def write_model_runs(hps, datasets, output_fname=None):
"""Run the model on the data in data_dict, and save the computed values.
LFADS generates a number of outputs for each examples, and these are all
saved. They are:
The mean and variance of the prior of g0.
The mean and variance of approximate posterior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The rates for all time.
Args:
hps: The dictionary of hyperparameters.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
output_fname (optional): output filename stem to write the model runs.
"""
model = build_model(hps, kind=hps.kind, datasets=datasets)
model.write_model_runs(datasets, output_fname)
def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
"""Use the prior distribution to generate samples from the model.
Generates batch_size number of samples (set through FLAGS).
LFADS generates a number of outputs for each examples, and these are all
saved. They are:
The mean and variance of the prior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
hps: The dictionary of hyperparameters.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
dataset_name: The name of the dataset to grab the factors -> rates
alignment matrices from. Only a concern with models trained on
multi-session data. By default, uses the first dataset in the data dict.
output_fname: The name prefix of the file in which to save the generated
samples.
"""
if not output_fname:
output_fname = "model_runs_" + hps.kind
else:
output_fname = output_fname + "model_runs_" + hps.kind
if not dataset_name:
dataset_name = datasets.keys()[0]
else:
if dataset_name not in datasets.keys():
raise ValueError("Invalid dataset name '%s'."%(dataset_name))
model = build_model(hps, kind=hps.kind, datasets=datasets)
model.write_model_samples(dataset_name, output_fname)
def write_model_parameters(hps, output_fname=None, datasets=None):
"""Save all the model parameters
Save all the parameters to hps.lfads_save_dir.
Args:
hps: The dictionary of hyperparameters.
output_fname: The prefix of the file in which to save the generated
samples.
datasets: A dictionary of data dictionaries. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
"""
if not output_fname:
output_fname = "model_params"
else:
output_fname = output_fname + "_model_params"
fname = os.path.join(hps.lfads_save_dir, output_fname)
print("Writing model parameters to: ", fname)
# save the optimizer params as well
model = build_model(hps, kind="write_model_params", datasets=datasets)
model_params = model.eval_model_parameters(use_nested=False,
include_strs="LFADS")
utils.write_data(fname, model_params, compression=None)
print("Done.")
def clean_data_dict(data_dict):
"""Add some key/value pairs to the data dict, if they are missing.
Args:
data_dict - dictionary containing data for LFADS
Returns:
data_dict with some keys filled in, if they are absent.
"""
keys = ['train_truth', 'train_ext_input', 'valid_data',
'valid_truth', 'valid_ext_input', 'valid_train']
for k in keys:
if k not in data_dict:
data_dict[k] = None
return data_dict
def load_datasets(data_dir, data_filename_stem):
"""Load the datasets from a specified directory.
Example files look like
>data_dir/my_dataset_first_day
>data_dir/my_dataset_second_day
If my_dataset (filename) stem is in the directory, the read routine will try
and load it. The datasets dictionary will then look like
dataset['first_day'] -> (first day data dictionary)
dataset['second_day'] -> (first day data dictionary)
Args:
data_dir: The directory from which to load the datasets.
data_filename_stem: The stem of the filename for the datasets.
Returns:
datasets: a dataset dictionary, with one name->data dictionary pair for
each dataset file.
"""
print("Reading data from ", data_dir)
datasets = utils.read_datasets(data_dir, data_filename_stem)
for k, data_dict in datasets.items():
datasets[k] = clean_data_dict(data_dict)
train_total_size = len(data_dict['train_data'])
if train_total_size == 0:
print("Did not load training set.")
else:
print("Found training set with number examples: ", train_total_size)
valid_total_size = len(data_dict['valid_data'])
if valid_total_size == 0:
print("Did not load validation set.")
else:
print("Found validation set with number examples: ", valid_total_size)
return datasets
def main(_):
"""Get this whole shindig off the ground."""
d = build_hyperparameter_dict(FLAGS)
hps = hps_dict_to_obj(d) # hyper parameters
kind = FLAGS.kind
# Read the data, if necessary.
train_set = valid_set = None
if kind in ["train", "posterior_sample_and_average", "prior_sample",
"write_model_params"]:
datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
else:
raise ValueError('Kind {} is not supported.'.format(kind))
# infer the dataset names and dataset dimensions from the loaded files
hps.kind = kind # needs to be added here, cuz not saved as hyperparam
hps.dataset_names = []
hps.dataset_dims = {}
for key in datasets:
hps.dataset_names.append(key)
hps.dataset_dims[key] = datasets[key]['data_dim']
# also store down the dimensionality of the data
# - just pull from one set, required to be same for all sets
hps.num_steps = datasets.values()[0]['num_steps']
hps.ndatasets = len(hps.dataset_names)
if hps.num_steps_for_gen_ic > hps.num_steps:
hps.num_steps_for_gen_ic = hps.num_steps
# Build and run the model, for varying purposes.
config = tf.ConfigProto(allow_soft_placement=True,
log_device_placement=False)
if FLAGS.allow_gpu_growth:
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
with sess.as_default():
with tf.device(hps.device):
if kind == "train":
train(hps, datasets)
elif kind == "posterior_sample_and_average":
write_model_runs(hps, datasets, hps.output_filename_stem)
elif kind == "prior_sample":
write_model_samples(hps, datasets, hps.output_filename_stem)
elif kind == "write_model_params":
write_model_parameters(hps, hps.output_filename_stem, datasets)
else:
assert False, ("Kind %s is not implemented. " % kind)
if __name__ == "__main__":
tf.app.run()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf # used for flags here
from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
matplotlib.rcParams['image.interpolation'] = 'nearest'
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data",
"Name of data file for input case.")
flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nreplications", 40,
"Number of noise replications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
"Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("input_magnitude", 20.0,
"For the input case, what is the value of the input?")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
# Note that with N small, (as it is 25 above), the finite size effects
# will have pretty dramatic effects on the dynamics of the random RNN.
# If you want more complex dynamics, you'll have to run the script a
# lot, or increase N (or g).
# Getting hard vs. easy data can be a little stochastic, so we set the seed.
# Pull out some commonly used parameters.
# These are user parameters (configuration)
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
S = FLAGS.S
input_magnitude = FLAGS.input_magnitude
nreplications = FLAGS.nreplications
E = nreplications * C # total number of trials
# S is the number of measurements in each datasets, w/ each
# dataset having a different set of observations.
ndatasets = N/S # ok if rounded down
train_percentage = FLAGS.train_percentage
ntime_steps = int(T / FLAGS.dt)
# End of user parameters
rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
# Check to make sure the RNN is the one we used in the paper.
if N == 50:
assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
rem_check = nreplications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \
'Train percentage * nreplications should be integral number.'
# Initial condition generation, and condition label generation. This
# happens outside of the dataset loop, so that all datasets have the
# same conditions, which is similar to a neurophys setup.
condition_number = 0
x0s = []
condition_labels = []
for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nreplications)) # replicate x0 nreplications times
# replicate the condition label nreplications times
for ns in range(nreplications):
condition_labels.append(condition_number)
condition_number += 1
x0s = np.concatenate(x0s, axis=1)
# Containers for storing data across data.
datasets = {}
for n in range(ndatasets):
print(n+1, " of ", ndatasets)
# First generate all firing rates. in the next loop, generate all
# replications this allows the random state for rate generation to be
# independent of n_replications.
dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
if S < N:
dataset_name += '_n' + str(n+1)
# Sample neuron subsets. The assumption is the PC axes of the RNN
# are not unit aligned, so sampling units is adequate to sample all
# the high-variance PCs.
P_sxn = np.eye(S,N)
for m in range(n):
P_sxn = np.roll(P_sxn, S, axis=1)
if input_magnitude > 0.0:
# time of "hits" randomly chosen between [1/4 and 3/4] of total time
input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
else:
input_times = None
rates, x0s, inputs = \
generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
input_magnitude=input_magnitude,
input_times=input_times)
if FLAGS.noise_type == "poisson":
noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
elif FLAGS.noise_type == "gaussian":
noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
else:
raise ValueError("Only noise types supported are poisson or gaussian")
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nreplications)
# Split the data, inputs, labels and times into train vs. validation.
rates_train, rates_valid = \
split_list_by_inds(rates, train_inds, valid_inds)
noisy_data_train, noisy_data_valid = \
split_list_by_inds(noisy_data, train_inds, valid_inds)
input_train, inputs_valid = \
split_list_by_inds(inputs, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = \
split_list_by_inds(condition_labels, train_inds, valid_inds)
input_times_train, input_times_valid = \
split_list_by_inds(input_times, train_inds, valid_inds)
# Turn rates, noisy_data, and input into numpy arrays.
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
noisy_data_train = nparray_and_transpose(noisy_data_train)
noisy_data_valid = nparray_and_transpose(noisy_data_valid)
input_train = nparray_and_transpose(input_train)
inputs_valid = nparray_and_transpose(inputs_valid)
# Note that we put these 'truth' rates and input into this
# structure, the only data that is used in LFADS are the noisy
# data e.g. spike trains. The rest is either for printing or posterity.
data = {'train_truth': rates_train,
'valid_truth': rates_valid,
'input_train_truth' : input_train,
'input_valid_truth' : inputs_valid,
'train_data' : noisy_data_train,
'valid_data' : noisy_data_valid,
'train_percentage' : train_percentage,
'nreplications' : nreplications,
'dt' : rnn['dt'],
'input_magnitude' : input_magnitude,
'input_times_train' : input_times_train,
'input_times_valid' : input_times_valid,
'P_sxn' : P_sxn,
'condition_labels_train' : condition_labels_train,
'condition_labels_valid' : condition_labels_valid,
'conversion_factor': 1.0 / rnn['conversion_factor']}
datasets[dataset_name] = data
if S < N:
# Note that this isn't necessary for this synthetic example, but
# it's useful to see how the input factor matrices were initialized
# for actual neurophysiology data.
datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)
# Write out the datasets.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf
from utils import write_datasets
from synthetic_data_utils import normalize_rates
from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 5,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0,
"Map 1.0 of RNN to a spikes per second")
flags.DEFINE_float("u_std", 0.25,
"Std dev of input to integration to bound model")
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
"""Path to directory with checkpoints of model
trained on integration to bound task. Currently this
is a placeholder which tells the code to grab the
checkpoint that is provided with the code
(in /trained_itb/..). If you have your own checkpoint
you would like to restore, you would point it to
that path.""")
FLAGS = flags.FLAGS
class IntegrationToBoundModel:
def __init__(self, N):
scale = 0.8 / float(N**0.5)
self.N = N
self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
self.b_1xn = tf.Variable(tf.zeros([1, N]))
self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
self.bro_o = tf.Variable(tf.zeros([1]))
def call(self, h_tm1_bxn, u_bx1):
act_t_bxn = tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn + u_bx1 * self.Bu_1xn
h_t_bxn = tf.nn.tanh(act_t_bxn)
z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
return z_t, h_t_bxn
def get_data_batch(batch_size, T, rng, u_std):
u_bxt = rng.randn(batch_size, T) * u_std
running_sum_b = np.zeros([batch_size])
labels_bxt = np.zeros([batch_size, T])
for t in xrange(T):
running_sum_b += u_bxt[:, t]
labels_bxt[:, t] += running_sum_b
labels_bxt = np.clip(labels_bxt, -1, 1)
return u_bxt, labels_bxt
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N # must be same N as in trained model (provided example is N = 50)
nspikifications = FLAGS.nspikifications
E = nspikifications * C # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
batch_size = 1 # gives one example per ntrial
model = IntegrationToBoundModel(N)
inputs_ph_t = [tf.placeholder(tf.float32,
shape=[None, 1]) for _ in range(ntimesteps)]
state = tf.zeros([batch_size, N])
saver = tf.train.Saver()
P_nxn = rng.randn(N,N) / np.sqrt(N) # random projections
# unroll RNN for T timesteps
outputs_t = []
states_t = []
for inp in inputs_ph_t:
output, state = model.call(state, inp)
outputs_t.append(output)
states_t.append(state)
with tf.Session() as sess:
# restore the latest model ckpt
if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
dir_path = os.path.dirname(os.path.realpath(__file__))
model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
else:
model_checkpoint_path = FLAGS.checkpoint_path
try:
saver.restore(sess, model_checkpoint_path)
print ('Model restored from', model_checkpoint_path)
except:
assert False, ("No checkpoints to restore from, is the path %s correct?"
%model_checkpoint_path)
# generate data for trials
data_e = []
u_e = []
outs_e = []
for c in range(C):
u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng, FLAGS.u_std)
feed_dict = {}
for t in xrange(ntimesteps):
feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:,t], (batch_size,-1))
states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
feed_dict=feed_dict)
states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
r_sxt = np.dot(P_nxn, states_nxt)
for s in xrange(nspikifications):
data_e.append(r_sxt)
u_e.append(u_1xt)
outs_e.append(outputs_t_bxn)
truth_data_e = normalize_rates(data_e, E, N)
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
train_inds,
valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
train_inds,
valid_inds)
data_train_truth = nparray_and_transpose(data_train_truth)
data_valid_truth = nparray_and_transpose(data_valid_truth)
data_train_spiking = nparray_and_transpose(data_train_spiking)
data_valid_spiking = nparray_and_transpose(data_valid_spiking)
# save down the inputs used to generate this data
train_inputs_u, valid_inputs_u = split_list_by_inds(u_e,
train_inds,
valid_inds)
train_inputs_u = nparray_and_transpose(train_inputs_u)
valid_inputs_u = nparray_and_transpose(valid_inputs_u)
# save down the network outputs (may be useful later)
train_outputs_u, valid_outputs_u = split_list_by_inds(outs_e,
train_inds,
valid_inds)
train_outputs_u = np.array(train_outputs_u)
valid_outputs_u = np.array(valid_outputs_u)
data = { 'train_truth': data_train_truth,
'valid_truth': data_valid_truth,
'train_data' : data_train_spiking,
'valid_data' : data_valid_spiking,
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : FLAGS.dt,
'u_std' : FLAGS.u_std,
'max_firing_rate': FLAGS.max_firing_rate,
'train_inputs_u': train_inputs_u,
'valid_inputs_u': valid_inputs_u,
'train_outputs_u': train_outputs_u,
'valid_outputs_u': valid_outputs_u,
'conversion_factor' : FLAGS.max_firing_rate/(1.0/FLAGS.dt) }
# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data
# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print ('Saved to ', os.path.join(FLAGS.save_dir,
FLAGS.datafile_name + '_' + dataset_name))
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import os
import h5py
import numpy as np
from synthetic_data_utils import generate_data, generate_rnn
from synthetic_data_utils import get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import tensorflow as tf
from utils import write_datasets
DATA_DIR = "rnn_synth_data_v1.0"
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
"Directory for saving data.")
flags.DEFINE_string("datafile_name", "conditioned_rnn_data",
"Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 400, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 10,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
"Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
np.random.RandomState(seed=FLAGS.synth_data_seed+2)]
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
nspikifications = FLAGS.nspikifications
E = nspikifications * C
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
rnn_a = generate_rnn(rnn_rngs[0], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
FLAGS.max_firing_rate)
rnn_b = generate_rnn(rnn_rngs[1], N, FLAGS.g, FLAGS.tau, FLAGS.dt,
FLAGS.max_firing_rate)
rnns = [rnn_a, rnn_b]
# pick which RNN is used on each trial
rnn_to_use = rng.randint(2, size=E)
ext_input = np.repeat(np.expand_dims(rnn_to_use, axis=1), ntimesteps, axis=1)
ext_input = np.expand_dims(ext_input, axis=2) # these are "a's" in the paper
x0s = []
condition_labels = []
condition_number = 0
for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications))
for ns in range(nspikifications):
condition_labels.append(condition_number)
condition_number += 1
x0s = np.concatenate(x0s, axis=1)
P_nxn = rng.randn(N, N) / np.sqrt(N)
# generate trials for both RNNs
rates_a, x0s_a, _ = generate_data(rnn_a, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
input_magnitude=0.0, input_times=None)
spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])
rates_b, x0s_b, _ = generate_data(rnn_b, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
input_magnitude=0.0, input_times=None)
spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])
# not the best way to do this but E is small enough
rates = []
spikes = []
for trial in xrange(E):
if rnn_to_use[trial] == 0:
rates.append(rates_a[trial])
spikes.append(spikes_a[trial])
else:
rates.append(rates_b[trial])
spikes.append(spikes_b[trial])
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
ext_input, train_inds, valid_inds)
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)
# add train_ext_input and valid_ext input
data = {'train_truth': rates_train,
'valid_truth': rates_valid,
'train_data' : spikes_train,
'valid_data' : spikes_valid,
'train_ext_input' : np.array(ext_input_train),
'valid_ext_input': np.array(ext_input_valid),
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'dt' : FLAGS.dt,
'P_sxn' : P_nxn,
'condition_labels_train' : condition_labels_train,
'condition_labels_valid' : condition_labels_valid,
'conversion_factor': 1.0 / rnn_a['conversion_factor']}
# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data
# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print ('Saved to ', os.path.join(FLAGS.save_dir,
FLAGS.datafile_name + '_' + dataset_name))
#!/bin/bash
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/
echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs_gaussian --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'
echo "Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating Integration-to-bound RNN data"
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
from utils import write_datasets
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
"""Create a (vanilla) RNN with a bunch of hyper parameters for generating
chaotic data.
Args:
rng: numpy random number generator
N: number of hidden units
g: scaling of recurrent weight matrix in g W, with W ~ N(0,1/N)
tau: time scale of individual unit dynamics
dt: time step for equation updates
max_firing_rate: how to resecale the -1,1 firing rates
Returns:
the dictionary of these parameters, plus some others.
"""
rnn = {}
rnn['N'] = N
rnn['W'] = rng.randn(N,N)/np.sqrt(N)
rnn['Bin'] = rng.randn(N)/np.sqrt(1.0)
rnn['Bin2'] = rng.randn(N)/np.sqrt(1.0)
rnn['b'] = np.zeros(N)
rnn['g'] = g
rnn['tau'] = tau
rnn['dt'] = dt
rnn['max_firing_rate'] = max_firing_rate
mfr = rnn['max_firing_rate'] # spikes / sec
nbins_per_sec = 1.0/rnn['dt'] # bins / sec
# Used for plotting in LFADS
rnn['conversion_factor'] = mfr / nbins_per_sec # spikes / bin
return rnn
def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
input_times=None):
""" Generates data from an randomly initialized RNN.
Args:
rnn: the rnn
T: Time in seconds to run (divided by rnn['dt'] to get steps, rounded down.
E: total number of examples
S: number of samples (subsampling N)
Returns:
A list of length E of NxT tensors of the network being run.
"""
N = rnn['N']
def run_rnn(rnn, x0, ntime_steps, input_time=None):
rs = np.zeros([N,ntime_steps])
x_tm1 = x0
r_tm1 = np.tanh(x0)
tau = rnn['tau']
dt = rnn['dt']
alpha = (1.0-dt/tau)
W = dt/tau*rnn['W']*rnn['g']
Bin = dt/tau*rnn['Bin']
Bin2 = dt/tau*rnn['Bin2']
b = dt/tau*rnn['b']
us = np.zeros([1, ntime_steps])
for t in range(ntime_steps):
x_t = alpha*x_tm1 + np.dot(W,r_tm1) + b
if input_time is not None and t == input_time:
us[0,t] = input_magnitude
x_t += Bin * us[0,t] # DCS is this what was used?
r_t = np.tanh(x_t)
x_tm1 = x_t
r_tm1 = r_t
rs[:,t] = r_t
return rs, us
if P_sxn is None:
P_sxn = np.eye(N)
ntime_steps = int(T / rnn['dt'])
data_e = []
inputs_e = []
for e in range(E):
input_time = input_times[e] if input_times is not None else None
r_nxt, u_uxt = run_rnn(rnn, x0s[:,e], ntime_steps, input_time)
r_sxt = np.dot(P_sxn, r_nxt)
inputs_e.append(u_uxt)
data_e.append(r_sxt)
S = P_sxn.shape[0]
data_e = normalize_rates(data_e, E, S)
return data_e, x0s, inputs_e
def normalize_rates(data_e, E, S):
# Normalization, made more complex because of the P matrices.
# Normalize by min and max in each channel. This normalization will
# cause offset differences between identical rnn runs, but different
# t hits.
for e in range(E):
r_sxt = data_e[e]
for i in range(S):
rmin = np.min(r_sxt[i,:])
rmax = np.max(r_sxt[i,:])
assert rmax - rmin != 0, 'Something wrong'
r_sxt[i,:] = (r_sxt[i,:] - rmin)/(rmax-rmin)
data_e[e] = r_sxt
return data_e
def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
""" Apply spikes to a continuous dataset whose values are between 0.0 and 1.0
Args:
data_e: nexamples length list of NxT trials
dt: how often the data are sampled
max_firing_rate: the firing rate that is associated with a value of 1.0
Returns:
spikified_e: a list of length b of the data represented as spikes,
sampled from the underlying poisson process.
"""
E = len(data_e)
spikes_e = []
for e in range(E):
data = data_e[e]
N,T = data.shape
data_s = np.zeros([N,T]).astype(np.int)
for n in range(N):
f = data[n,:]
s = rng.poisson(f*max_firing_rate*dt, size=T)
data_s[n,:] = s
spikes_e.append(data_s)
return spikes_e
def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
""" Apply gaussian noise to a continuous dataset whose values are between
0.0 and 1.0
Args:
data_e: nexamples length list of NxT trials
dt: how often the data are sampled
max_firing_rate: the firing rate that is associated with a value of 1.0
Returns:
gauss_e: a list of length b of the data with noise.
"""
E = len(data_e)
mfr = max_firing_rate
gauss_e = []
for e in range(E):
data = data_e[e]
N,T = data.shape
noisy_data = data * mfr + np.random.randn(N,T) * (5.0*mfr) * np.sqrt(dt)
gauss_e.append(noisy_data)
return gauss_e
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
"""Split the numbers between 0 and num_trials-1 into two portions for
training and validation, based on the train fraction.
Args:
num_trials: the number of trials
train_fraction: (e.g. .80)
nspikifications: the number of spiking trials per initial condition
Returns:
a 2-tuple of two lists: the training indices and validation indices
"""
train_inds = []
valid_inds = []
for i in range(num_trials):
# This line divides up the trials so that within one initial condition,
# the randomness of spikifying the condition is shared among both
# training and validation data splits.
if (i % nspikifications)+1 > train_fraction * nspikifications:
valid_inds.append(i)
else:
train_inds.append(i)
return train_inds, valid_inds
def split_list_by_inds(data, inds1, inds2):
"""Take the data, a list, and split it up based on the indices in inds1 and
inds2.
Args:
data: the list of data to split
inds1, the first list of indices
inds2, the second list of indices
Returns: a 2-tuple of two lists.
"""
if data is None or len(data) == 0:
return [], []
else:
dout1 = [data[i] for i in inds1]
dout2 = [data[i] for i in inds2]
return dout1, dout2
def nparray_and_transpose(data_a_b_c):
"""Convert the list of items in data to a numpy array, and transpose it
Args:
data: data_asbsc: a nested, nested list of length a, with sublist length
b, with sublist length c.
Returns:
a numpy 3-tensor with dimensions a x c x b
"""
data_axbxc = np.array([datum_b_c for datum_b_c in data_a_b_c])
data_axcxb = np.transpose(data_axbxc, axes=[0,2,1])
return data_axcxb
def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
"""Create a matrix that aligns the datasets a bit, under
the assumption that each dataset is observing the same underlying dynamical
system.
Args:
datasets: The dictionary of dataset structures.
npcs: The number of pcs for each, basically like lfads factors.
nsamples (optional): Number of samples to take for each dataset.
ntime (optional): Number of time steps to take in each sample.
Returns:
The dataset structures, with the field alignment_matrix_cxf added.
This is # channels x npcs dimension
"""
nchannels_all = 0
channel_idxs = {}
conditions_all = {}
nconditions_all = 0
for name, dataset in datasets.items():
cidxs = np.where(dataset['P_sxn'])[1] # non-zero entries in columns
channel_idxs[name] = [cidxs[0], cidxs[-1]+1]
nchannels_all += cidxs[-1]+1 - cidxs[0]
conditions_all[name] = np.unique(dataset['condition_labels_train'])
all_conditions_list = \
np.unique(np.ndarray.flatten(np.array(conditions_all.values())))
nconditions_all = all_conditions_list.shape[0]
if ntime is None:
ntime = dataset['train_data'].shape[1]
if nsamples is None:
nsamples = dataset['train_data'].shape[0]
# In the data workup in the paper, Chethan did intra condition
# averaging, so let's do that here.
avg_data_all = {}
for name, conditions in conditions_all.items():
dataset = datasets[name]
avg_data_all[name] = {}
for cname in conditions:
td_idxs = np.argwhere(np.array(dataset['condition_labels_train'])==cname)
data = np.squeeze(dataset['train_data'][td_idxs,:,:], axis=1)
avg_data = np.mean(data, axis=0)
avg_data_all[name][cname] = avg_data
# Visualize this in the morning.
all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
for name, dataset in datasets.items():
cidx_s = channel_idxs[name][0]
cidx_f = channel_idxs[name][1]
for cname in conditions_all[name]:
cidxs = np.argwhere(all_conditions_list == cname)
if cidxs.shape[0] > 0:
cidx = cidxs[0][0]
all_tidxs = np.arange(0, ntime+1) + cidx*ntime
all_data_nxtc[cidx_s:cidx_f, all_tidxs[0]:all_tidxs[-1]] = \
avg_data_all[name][cname].T
# A bit of filtering. We don't care about spectral properties, or
# filtering artifacts, simply correlate time steps a bit.
filt_len = 6
bc_filt = np.ones([filt_len])/float(filt_len)
for c in range(nchannels_all):
all_data_nxtc[c,:] = scipy.signal.filtfilt(bc_filt, [1.0], all_data_nxtc[c,:])
# Compute the PCs.
all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
sidxs = np.flipud(np.argsort(evals_n)) # sort such that 0th is highest
evals_n = evals_n[sidxs]
evecs_nxn = evecs_nxn[:,sidxs]
# Project all the channels data onto the low-D PCA basis, where
# low-d is the npcs parameter.
all_data_pca_pxtc = np.dot(evecs_nxn[:, 0:npcs].T, all_data_zm_nxtc)
# Now for each dataset, we regress the channel data onto the top
# pcs, and this will be our alignment matrix for that dataset.
# |B - A*W|^2
for name, dataset in datasets.items():
cidx_s = channel_idxs[name][0]
cidx_f = channel_idxs[name][1]
all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f,:] # ch for channel
W_chxp, _, _, _ = \
np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T)
dataset['alignment_matrix_cxf'] = W_chxp
alignment_bias_cx1 = all_data_mean_nx1[cidx_s:cidx_f]
dataset['alignment_bias_c'] = np.squeeze(alignment_bias_cx1, axis=1)
do_debug_plot = False
if do_debug_plot:
pc_vecs = evecs_nxn[:,0:npcs]
ntoplot = 400
plt.figure()
plt.plot(np.log10(evals_n), '-x')
plt.figure()
plt.subplot(311)
plt.imshow(all_data_pca_pxtc)
plt.colorbar()
plt.subplot(312)
plt.imshow(np.dot(W_chxp.T, all_data_zm_chxtc))
plt.colorbar()
plt.subplot(313)
plt.imshow(np.dot(all_data_zm_chxtc.T, W_chxp).T - all_data_pca_pxtc)
plt.colorbar()
import pdb
pdb.set_trace()
return datasets
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment