Commit 44fa1d37 authored by Alex Lee's avatar Alex Lee
Browse files

Merge remote-tracking branch 'upstream/master'

parents d3628a74 6e367f67
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from lfads import LFADS
import numpy as np
import os
import tensorflow as tf
import re
import utils
# Default values for the hyperparameters. There are lots of them, but most
# are pretty insensitive. Each one is explained next to the flag that
# exposes it, in the flags section below.
CHECKPOINT_PB_LOAD_NAME = "checkpoint"
CHECKPOINT_NAME = "lfads_vae"
CSV_LOG = "fitlog"
OUTPUT_FILENAME_STEM = ""
DEVICE = "gpu:0" # "cpu:0", or other gpus, e.g. "gpu:1"
MAX_CKPT_TO_KEEP = 5
MAX_CKPT_TO_KEEP_LVE = 5
# Kept as a true int (1e8 is a float) because it is the default of an
# integer flag. If larger than number of examples, process all.
PS_NEXAMPLES_TO_PROCESS = int(1e8)
EXT_INPUT_DIM = 0
IC_DIM = 64
FACTORS_DIM = 50
IC_ENC_DIM = 128
GEN_DIM = 200
GEN_CELL_INPUT_WEIGHT_SCALE = 1.0
GEN_CELL_REC_WEIGHT_SCALE = 1.0
CELL_WEIGHT_SCALE = 1.0
BATCH_SIZE = 128
LEARNING_RATE_INIT = 0.01
LEARNING_RATE_DECAY_FACTOR = 0.95
LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE = 6
INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False
DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors"
# Calibrated just above the average value for the rnn synthetic data.
MAX_GRAD_NORM = 200.0
CELL_CLIP_VALUE = 5.0
KEEP_PROB = 0.95
TEMPORAL_SPIKE_JITTER_WIDTH = 0
OUTPUT_DISTRIBUTION = 'poisson' # 'poisson' or 'gaussian'
# Kept as a true int (np.inf is a float) because it is the default of an
# integer flag. Any value greater than num_steps is set to num_steps.
NUM_STEPS_FOR_GEN_IC = int(np.iinfo(np.int32).max)
DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
LFADS_SAVE_DIR = "/tmp/lfads_chaotic_rnn_inputs_g1p5/"
CO_DIM = 1
DO_CAUSAL_CONTROLLER = False
DO_FEED_FACTORS_TO_CONTROLLER = True
CONTROLLER_INPUT_LAG = 1
PRIOR_AR_AUTOCORRELATION = 10.0
PRIOR_AR_PROCESS_VAR = 0.1
DO_TRAIN_PRIOR_AR_ATAU = True
DO_TRAIN_PRIOR_AR_NVAR = True
CI_ENC_DIM = 128
CON_DIM = 128
CO_PRIOR_VAR_SCALE = 0.1
KL_INCREASE_STEPS = 2000
L2_INCREASE_STEPS = 2000
L2_GEN_SCALE = 2000.0
L2_CON_SCALE = 0.0
# scale of regularizer on time correlation of inferred inputs
CO_MEAN_CORR_SCALE = 0.0
KL_IC_WEIGHT = 1.0
KL_CO_WEIGHT = 1.0
KL_START_STEP = 0
L2_START_STEP = 0
IC_PRIOR_VAR_MIN = 0.1
IC_PRIOR_VAR_SCALE = 0.1
IC_PRIOR_VAR_MAX = 0.1
IC_POST_VAR_MIN = 0.0001 # protection from KL blowing up
flags = tf.app.flags

# What to build/run, and how to model the observations.
flags.DEFINE_string(
    "kind", "train",
    "Type of model to build {train, "
    "posterior_sample_and_average, "
    "prior_sample, write_model_params")
flags.DEFINE_string(
    "output_dist", OUTPUT_DISTRIBUTION,
    "Type of output distribution, 'poisson' or 'gaussian'")
flags.DEFINE_boolean(
    "allow_gpu_growth", False,
    "If true, only allocate amount of memory needed for "
    "Session. Otherwise, use full GPU memory.")
# DATA
flags.DEFINE_string(
    "data_dir", DATA_DIR,
    "Data for training")
flags.DEFINE_string(
    "data_filename_stem", DATA_FILENAME_STEM,
    "Filename stem for data dictionaries.")
flags.DEFINE_string(
    "lfads_save_dir", LFADS_SAVE_DIR,
    "model save dir")
flags.DEFINE_string(
    "checkpoint_pb_load_name", CHECKPOINT_PB_LOAD_NAME,
    "Name of checkpoint files, use 'checkpoint_lve' for best "
    "error")
flags.DEFINE_string(
    "checkpoint_name", CHECKPOINT_NAME,
    "Name of checkpoint files (.ckpt appended)")
flags.DEFINE_string(
    "output_filename_stem", OUTPUT_FILENAME_STEM,
    "Name of output file (postfix will be added)")
flags.DEFINE_string(
    "device", DEVICE,
    'Which device to use (default: "gpu:0", can also be '
    '"cpu:0", "gpu:1", etc)')
flags.DEFINE_string(
    "csv_log", CSV_LOG,
    "Name of file to keep running log of fit likelihoods, "
    "etc (.csv appended)")
flags.DEFINE_integer(
    "max_ckpt_to_keep", MAX_CKPT_TO_KEEP,
    "Max # of checkpoints to keep (rolling)")
flags.DEFINE_integer(
    "ps_nexamples_to_process", PS_NEXAMPLES_TO_PROCESS,
    "Number of examples to process for posterior sample and "
    "average (not number of samples to average over).")
flags.DEFINE_integer(
    "max_ckpt_to_keep_lve", MAX_CKPT_TO_KEEP_LVE,
    "Max # of checkpoints to keep for lowest validation error "
    "models (rolling)")
flags.DEFINE_integer(
    "ext_input_dim", EXT_INPUT_DIM,
    "Dimension of external "
    "inputs")
flags.DEFINE_integer(
    "num_steps_for_gen_ic", NUM_STEPS_FOR_GEN_IC,
    "Number of steps to train the generator initial conditon.")

# Observed inputs can reach the model in two ways. Either they are treated
# as something to be inferred: encoded by the encoders and then passed to
# the generator through the "inferred inputs" channel. Or they are fed
# straight into the generator. The drawback of the direct route is that
# generation then strictly depends on knowing the observed input for every
# generated trial.
flags.DEFINE_boolean(
    "inject_ext_input_to_gen", INJECT_EXT_INPUT_TO_GEN,
    "Should observed inputs be input to model via encoders, "
    "or injected directly into generator?")
# CELL
# The combined recurrent and input weights of the encoder and
# controller cells are by default set to scale at ws/sqrt(#inputs),
# with ws=1.0. You can change this scaling with this parameter.
# (The generator cell has its own scale flags, gen_cell_*_weight_scale.)
flags.DEFINE_float("cell_weight_scale", CELL_WEIGHT_SCALE,
                   "Scaling (ws) of the combined recurrent and input "
                   "weights in the encoder and controller cells.")
# GENERATION
# The dimension of the initial conditions is kept separate from the
# dimension of the generator's initial state (a linear matrix adapts the
# shapes when they differ). This is simply one more knob on model
# complexity; in all likelihood, making ic_dim equal to the generator
# hidden state size is just fine.
flags.DEFINE_integer(
    "ic_dim", IC_DIM,
    "Dimension of h0")

# Choosing fewer factors than data dimensions yields a reduced
# dimensionality representation of the data.
flags.DEFINE_integer(
    "factors_dim", FACTORS_DIM,
    "Number of factors from generator")
flags.DEFINE_integer(
    "ic_enc_dim", IC_ENC_DIM,
    "Cell hidden size, encoder of h0")

# The generator size is one way to control the complexity of the dynamics
# (l2 regularization is another; it squeezes out unnecessary dynamics).
# The modern deep learning approach is to make cells as large as one can
# tolerate waiting for, then regularize heavily (dropout or whatever).
# Whether that is the right approach for LFADS is not known.
flags.DEFINE_integer(
    "gen_dim", GEN_DIM,
    "Cell hidden size, generator.")

# Generator cell weights scale as ws/sqrt(#inputs) with ws=1.0 by default.
# The two flags below set ws separately for the input weights and for the
# recurrent weights.
flags.DEFINE_float(
    "gen_cell_input_weight_scale", GEN_CELL_INPUT_WEIGHT_SCALE,
    "Input scaling for input weights in generator.")
flags.DEFINE_float(
    "gen_cell_rec_weight_scale", GEN_CELL_REC_WEIGHT_SCALE,
    "Input scaling for rec weights in generator.")
# KL DISTRIBUTIONS
# If you don't know what you are doing here, please leave these alone; the
# defaults should be fine for most cases, regardless of other parameters.
#
# If you don't want the prior variance to be learned, set the
# following values to the same thing: ic_prior_var_min,
# ic_prior_var_scale, ic_prior_var_max. The prior mean will be
# learned regardless.
flags.DEFINE_float("ic_prior_var_min", IC_PRIOR_VAR_MIN,
                   "Minimum variance of IC prior distribution.")
flags.DEFINE_float("ic_prior_var_scale", IC_PRIOR_VAR_SCALE,
                   "Variance of ic prior distribution")
flags.DEFINE_float("ic_prior_var_max", IC_PRIOR_VAR_MAX,
                   "Maximum variance of IC prior distribution.")
# If you really want to limit the information from encoder to decoder,
# increase ic_post_var_min above 0.0.
flags.DEFINE_float("ic_post_var_min", IC_POST_VAR_MIN,
                   "Minimum variance of IC posterior distribution.")
flags.DEFINE_float("co_prior_var_scale", CO_PRIOR_VAR_SCALE,
                   "Variance of control input prior distribution.")
flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION,
                   "Initial autocorrelation of AR(1) priors.")
flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR,
                   "Initial noise variance for AR(1) priors.")
flags.DEFINE_boolean("do_train_prior_ar_atau", DO_TRAIN_PRIOR_AR_ATAU,
                     "Is the value for atau an init, or the constant value?")
flags.DEFINE_boolean("do_train_prior_ar_nvar", DO_TRAIN_PRIOR_AR_NVAR,
                     "Is the value for noise variance an init, or the "
                     "constant value?")
# CONTROLLER
# co_dim critically decides whether a controller (together with the
# controller encoders) is placed into the LFADS graph at all: co_dim is the
# number of controller outputs, a value > 0 builds that part of the graph,
# and 0 means there is no controller.
flags.DEFINE_integer(
    "co_dim", CO_DIM,
    "Number of control net outputs (>0 builds that graph).")

# A controller that sees the encoding of the entire trial is more powerful,
# but it can then produce inferred inputs that are acausal with respect to
# the actual data generation process. E.g. the data generator could have an
# input at time t, yet the controller, having seen the whole trial, could
# infer that the input arrives a little before time t, because nothing
# restricts the data it sees.
#
# The controller can be forced to be causal (with respect to perturbations
# in the data generator), so that at time t it only sees forward encodings
# of data originating at or before time t. What the controller sees can
# also be shifted with an input lag: the forward encoding at time [t-tlag]
# is used as controller input at time t, and likewise in the reverse
# direction (reverse encoding at time [t+tlag]) for an acausal controller.
# A lag > 0 (even lag=1) can be a powerful way of avoiding very spiky
# decodes. Finally, whether the factors at time t-1 are fed to the
# controller at time t can be set by hand.
#
# If you don't care about any of this, and just want to smooth your data,
# set
#   do_causal_controller = False
#   do_feed_factors_to_controller = True
#   controller_input_lag = 0
flags.DEFINE_boolean(
    "do_causal_controller", DO_CAUSAL_CONTROLLER,
    "Restrict the controller create only causal inferred "
    "inputs?")

# Strictly speaking, feeding either the factors or the rates to the
# controller violates causality, since the g0 gets to see all the data.
# This may or may not be only a theoretical concern.
flags.DEFINE_boolean(
    "do_feed_factors_to_controller", DO_FEED_FACTORS_TO_CONTROLLER,
    "Should factors[t-1] be input to controller at time t?")
flags.DEFINE_string(
    "feedback_factors_or_rates", FEEDBACK_FACTORS_OR_RATES,
    "Feedback the factors or the rates to the controller? "
    "Acceptable values: 'factors' or 'rates'.")
flags.DEFINE_integer(
    "controller_input_lag", CONTROLLER_INPUT_LAG,
    "Time lag on the encoding to controller t-lag for "
    "forward, t+lag for reverse.")
flags.DEFINE_integer(
    "ci_enc_dim", CI_ENC_DIM,
    "Cell hidden size, encoder of control inputs")
flags.DEFINE_integer(
    "con_dim", CON_DIM,
    "Cell hidden size, controller")
# OPTIMIZATION
flags.DEFINE_integer(
    "batch_size", BATCH_SIZE,
    "Batch size to use during training.")
flags.DEFINE_float(
    "learning_rate_init", LEARNING_RATE_INIT,
    "Learning rate initial value")
flags.DEFINE_float(
    "learning_rate_decay_factor", LEARNING_RATE_DECAY_FACTOR,
    "Learning rate decay, decay by this fraction every so "
    "often.")
flags.DEFINE_float(
    "learning_rate_stop", LEARNING_RATE_STOP,
    "The lr is adaptively reduced, stop training at this value.")

# Rather than put the learning rate on an exponentially decreasing
# schedule, the current algorithm watches the cost and lowers the learning
# rate when the cost is not regularly decreasing. So far this works fine,
# though it is not perfect.
flags.DEFINE_integer(
    "learning_rate_n_to_compare", LEARNING_RATE_N_TO_COMPARE,
    "Number of previous costs current cost has to be worse "
    "than, to lower learning rate.")

# Gradients with a norm above this value are clipped. This hyperparameter
# is extremely useful against an infrequent but highly pathological
# problem: a gradient so large that it destroys the optimization by making
# parameters too large, leading to a vicious cycle that ends in NaNs. Too
# large a value is useless; too small a value essentially becomes the
# learning rate. It's pretty insensitive, though.
flags.DEFINE_float(
    "max_grad_norm", MAX_GRAD_NORM,
    "Max norm of gradient before clipping.")

# If your optimizations start "NaN-ing out", reduce this value so that the
# values of the network don't grow out of control. Typically, once it is
# set to a reasonable value, the numerical problems stop.
flags.DEFINE_float(
    "cell_clip_value", CELL_CLIP_VALUE,
    "Max value recurrent cell can take before being clipped.")

# Used for an experiment that checks whether a model trained on many days
# of data can be used to learn the dynamics from a held-out day's data. If
# you don't care about that particular experiment, always leave this false.
flags.DEFINE_boolean(
    "do_train_io_only", DO_TRAIN_IO_ONLY,
    "Train only the input (readin) and output (readout) "
    "affine functions.")
flags.DEFINE_boolean(
    "do_reset_learning_rate", DO_RESET_LEARNING_RATE,
    "Reset the learning rate to initial value.")
# OVERFITTING
# Dropout is applied to the input data, to the controller inputs (from the
# encoder), and to the outputs from generator to factors.
flags.DEFINE_float(
    "keep_prob", KEEP_PROB,
    "Dropout keep probability.")

# The system appears to happily fit individual spikes (a blessing or a
# curse, depending). If that is not wanted, jittering the spikes a bit
# helps (-/+ bin size, as specified here).
flags.DEFINE_integer(
    "temporal_spike_jitter_width", TEMPORAL_SPIKE_JITTER_WIDTH,
    "Shuffle spikes around this window.")

# General note on attributing structure to controller inputs vs dynamics:
#
# A heavily penalized controller won't have any output; heavily penalized
# dynamics mean the generator won't make dynamics. This l2 penalty applies
# only to the recurrent portion of the RNNs, since dropout is also
# available and penalizes the feed-forward connections.
flags.DEFINE_float(
    "l2_gen_scale", L2_GEN_SCALE,
    "L2 regularization cost for the generator only.")
flags.DEFINE_float(
    "l2_con_scale", L2_CON_SCALE,
    "L2 regularization cost for the controller only.")
flags.DEFINE_float(
    "co_mean_corr_scale", CO_MEAN_CORR_SCALE,
    "Cost of correlation (thru time)in the means of "
    "controller output.")
# UNDERFITTING
# When LFADS is used mainly for "filtering" data rather than generating
# it, the KL penalty can be too strong; empirically this has been found to
# be the case. Hence a weight in front of each of the two KL terms (one
# for the initial conditions to the generator, one for the controller
# outputs). Think of the default values as 1.0: that gives the standard
# VAE formulation, where the optimized numbers are a lower bound on the
# log-likelihood of the data. Once these two weights deviate from 1.0, no
# statement can be made about what those LL lower bounds mean, and they
# cannot be compared (AFAIK).
flags.DEFINE_float(
    "kl_ic_weight", KL_IC_WEIGHT,
    "Strength of KL weight on initial conditions KL penatly.")
flags.DEFINE_float(
    "kl_co_weight", KL_CO_WEIGHT,
    "Strength of KL weight on controller output KL penalty.")

# A sufficiently hard task can lead the optimizer to take the 'easy
# route': it simply minimizes the KL divergence, drives it to near zero,
# and the optimization gets stuck. The two parameters below help avoid
# that by letting the optimization 'latch' on to the main objective first
# and only turning on the regularizers later.
flags.DEFINE_integer(
    "kl_start_step", KL_START_STEP,
    "Start increasing weight after this many steps.")
# training passes, not epochs, increase by 0.5 every kl_increase_steps
flags.DEFINE_integer(
    "kl_increase_steps", KL_INCREASE_STEPS,
    "Increase weight of kl cost to avoid local minimum.")

# Same story for the l2 regularizer: a simple generator is desirable for
# scientific reasons, but not at the expense of hosing the optimization.
flags.DEFINE_integer(
    "l2_start_step", L2_START_STEP,
    "Start increasing l2 weight after this many steps.")
flags.DEFINE_integer(
    "l2_increase_steps", L2_INCREASE_STEPS,
    "Increase weight of l2 cost to avoid local minimum.")

FLAGS = flags.FLAGS
def build_model(hps, kind="train", datasets=None):
    """Builds a model from either random initialization, or saved parameters.

    The LFADS graph is created, parameters are restored from a checkpoint in
    hps.lfads_save_dir when one is found (otherwise all variables are freshly
    initialized), and the hyperparameters are written as JSON next to the
    checkpoints.

    Args:
      hps: The hyper parameters for the model.
      kind: (optional) The kind of model to build. Training vs inference
        require different graphs. "write_model_params" is built as a
        "train" graph.
      datasets: The datasets structure (see top of lfads.py).

    Returns:
      an LFADS model.
    """
    build_kind = kind
    if build_kind == "write_model_params":
        build_kind = "train"
    with tf.variable_scope("LFADS", reuse=None):
        model = LFADS(hps, kind=build_kind, datasets=datasets)

    if not os.path.exists(hps.lfads_save_dir):
        print("Save directory %s does not exist, creating it."
              % hps.lfads_save_dir)
        os.makedirs(hps.lfads_save_dir)

    # Pick the checkpoint state file and the saver that goes with it.
    cp_pb_ln = hps.checkpoint_pb_load_name
    cp_pb_ln = 'checkpoint' if cp_pb_ln == "" else cp_pb_ln
    if cp_pb_ln == 'checkpoint':
        print("Loading latest training checkpoint in: ", hps.lfads_save_dir)
        saver = model.seso_saver
    elif cp_pb_ln == 'checkpoint_lve':
        print("Loading lowest validation checkpoint in: ", hps.lfads_save_dir)
        saver = model.lve_saver
    else:
        print("Loading checkpoint: ", cp_pb_ln, ", in: ", hps.lfads_save_dir)
        saver = model.seso_saver

    ckpt = tf.train.get_checkpoint_state(hps.lfads_save_dir,
                                         latest_filename=cp_pb_ln)
    session = tf.get_default_session()
    print("ckpt: ", ckpt)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        if kind in ["posterior_sample_and_average", "prior_sample",
                    "write_model_params"]:
            print("Possible error!!! You are running ", kind,
                  " on a newly initialized model!")
            # ckpt may be None here, so its path cannot be reported; point
            # at the directory that was searched instead.
            print("Are you sure a checkpoint in ", hps.lfads_save_dir,
                  " exists?")
        tf.global_variables_initializer().run()

    # Tag the hyperparameter file with the training step of the checkpoint,
    # falling back to step 0 when there is no checkpoint or its path does
    # not end in "-<step>".
    train_step_str = '-0'
    if ckpt:
        step_match = re.search('-[0-9]+$', ckpt.model_checkpoint_path)
        if step_match is not None:
            train_step_str = step_match.group()

    fname = 'hyperparameters' + train_step_str + '.txt'
    hp_fname = os.path.join(hps.lfads_save_dir, fname)
    hps_for_saving = jsonify_dict(hps)
    utils.write_data(hp_fname, hps_for_saving, use_json=True)
    return model
def jsonify_dict(d):
    """Return a copy of d in which every bool is a "true"/"false" string.

    The copy is shallow, which is sufficient because the hyperparameter
    dictionary is assumed to be flat. The input itself is left untouched.

    Args:
      d: hyperparameter dictionary

    Returns: hyperparameter dictionary with bool's as strings
    """
    converted = d.copy()
    # Collect the keys first so the dict is not modified while iterating it.
    bool_keys = [name for name, value in converted.items()
                 if isinstance(value, bool)]
    for name in bool_keys:
        converted[name] = "true" if converted[name] else "false"
    return converted
def build_hyperparameter_dict(flags):
    """Collect the hyperparameter flags into a plain dictionary.

    The flags structure is not a dictionary under the hood, so the values of
    interest are copied out by name; this lets them be viewed and saved as
    text.

    Args:
      flags: From tf.app.flags

    Returns:
      dictionary of hyper parameters (ignoring other flag types).
    """
    hp_names = (
        # Data
        'output_dist',
        'data_dir',
        'lfads_save_dir',
        'checkpoint_pb_load_name',
        'checkpoint_name',
        'output_filename_stem',
        'max_ckpt_to_keep',
        'max_ckpt_to_keep_lve',
        'ps_nexamples_to_process',
        'ext_input_dim',
        'data_filename_stem',
        'device',
        'csv_log',
        'num_steps_for_gen_ic',
        'inject_ext_input_to_gen',
        # Cell
        'cell_weight_scale',
        # Generation
        'ic_dim',
        'factors_dim',
        'ic_enc_dim',
        'gen_dim',
        'gen_cell_input_weight_scale',
        'gen_cell_rec_weight_scale',
        # KL distributions
        'ic_prior_var_min',
        'ic_prior_var_scale',
        'ic_prior_var_max',
        'ic_post_var_min',
        'co_prior_var_scale',
        'prior_ar_atau',
        'prior_ar_nvar',
        'do_train_prior_ar_atau',
        'do_train_prior_ar_nvar',
        # Controller
        'do_causal_controller',
        'controller_input_lag',
        'do_feed_factors_to_controller',
        'feedback_factors_or_rates',
        'co_dim',
        'ci_enc_dim',
        'con_dim',
        'co_mean_corr_scale',
        # Optimization
        'batch_size',
        'learning_rate_init',
        'learning_rate_decay_factor',
        'learning_rate_stop',
        'learning_rate_n_to_compare',
        'max_grad_norm',
        'cell_clip_value',
        'do_train_io_only',
        'do_reset_learning_rate',
        # Overfitting
        'keep_prob',
        'temporal_spike_jitter_width',
        'l2_gen_scale',
        'l2_con_scale',
        # Underfitting
        'kl_ic_weight',
        'kl_co_weight',
        'kl_start_step',
        'kl_increase_steps',
        'l2_start_step',
        'l2_increase_steps',
    )
    return {name: getattr(flags, name) for name in hp_names}
class hps_dict_to_obj(dict):
    """Helper class allowing us to access hps dictionary more easily.

    Dictionary entries can be read and written as attributes, e.g. hps.kind
    is the same as hps['kind'].
    """

    def __getattr__(self, key):
        """Return self[key]; raise AttributeError if key is not present.

        AttributeError (rather than an assert) keeps hasattr() and
        getattr(obj, name, default) working, and still fails loudly when
        Python runs with assertions disabled (-O).
        """
        if key in self:
            return self[key]
        raise AttributeError("%s does not exist." % key)

    def __setattr__(self, key, value):
        """Store value under key as a dictionary entry."""
        self[key] = value
def train(hps, datasets):
    """Build the LFADS training model and fit it to the datasets.

    If hps.do_reset_learning_rate is set, the model's learning rate variable
    is re-run through its initializer before training starts.

    Args:
      hps: The dictionary of hyperparameters.
      datasets: A dictionary of data dictionaries. The dataset dict is
        simply a name(string)-> data dictionary mapping (See top of
        lfads.py).
    """
    lfads_model = build_model(hps, kind="train", datasets=datasets)
    if hps.do_reset_learning_rate:
        tf.get_default_session().run(lfads_model.learning_rate.initializer)
    lfads_model.train_model(datasets)
def write_model_runs(hps, datasets, output_fname=None):
    """Run the model on the given datasets and save the computed values.

    LFADS generates a number of outputs for each example, and these are all
    saved. They are:
      The mean and variance of the prior of g0.
      The mean and variance of approximate posterior of g0.
      The control inputs (if enabled)
      The initial conditions, g0, for all examples.
      The generator states for all time.
      The factors for all time.
      The rates for all time.

    Args:
      hps: The dictionary of hyperparameters.
      datasets: A dictionary of data dictionaries. The dataset dict is
        simply a name(string)-> data dictionary mapping (See top of
        lfads.py).
      output_fname (optional): output filename stem to write the model
        runs.
    """
    lfads_model = build_model(hps, kind=hps.kind, datasets=datasets)
    lfads_model.write_model_runs(datasets, output_fname)
def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
    """Use the prior distribution to generate samples from the model.

    Generates batch_size number of samples (set through FLAGS).

    LFADS generates a number of outputs for each examples, and these are all
    saved. They are:
      The mean and variance of the prior of g0.
      The control inputs (if enabled)
      The initial conditions, g0, for all examples.
      The generator states for all time.
      The factors for all time.
      The output distribution parameters (e.g. rates) for all time.

    Args:
      hps: The dictionary of hyperparameters.
      datasets: A dictionary of data dictionaries. The dataset dict is
        simply a name(string)-> data dictionary mapping (See top of
        lfads.py).
      dataset_name: The name of the dataset to grab the factors -> rates
        alignment matrices from. Only a concern with models trained on
        multi-session data. By default, uses the first dataset in the data
        dict.
      output_fname: The name prefix of the file in which to save the
        generated samples.

    Raises:
      ValueError: if dataset_name is given but is not a key of datasets.
    """
    if not output_fname:
        output_fname = "model_runs_" + hps.kind
    else:
        output_fname = output_fname + "model_runs_" + hps.kind

    if not dataset_name:
        # next(iter(...)) works on both Python 2 and 3; keys()[0] fails on
        # Python 3, where keys() is a non-subscriptable view.
        dataset_name = next(iter(datasets))
    elif dataset_name not in datasets:
        raise ValueError("Invalid dataset name '%s'."%(dataset_name))

    model = build_model(hps, kind=hps.kind, datasets=datasets)
    model.write_model_samples(dataset_name, output_fname)
def write_model_parameters(hps, output_fname=None, datasets=None):
    """Evaluate all model parameters and write them to hps.lfads_save_dir.

    Args:
      hps: The dictionary of hyperparameters.
      output_fname: The prefix of the file in which to save the parameters.
        When empty, the file is simply named "model_params".
      datasets: A dictionary of data dictionaries. The dataset dict is
        simply a name(string)-> data dictionary mapping (See top of
        lfads.py).
    """
    file_stem = (output_fname + "_model_params" if output_fname
                 else "model_params")
    target_path = os.path.join(hps.lfads_save_dir, file_stem)
    print("Writing model parameters to: ", target_path)

    # save the optimizer params as well
    lfads_model = build_model(hps, kind="write_model_params",
                              datasets=datasets)
    params = lfads_model.eval_model_parameters(use_nested=False,
                                               include_strs="LFADS")
    utils.write_data(target_path, params, compression=None)
    print("Done.")
def clean_data_dict(data_dict):
    """Fill in optional entries of the data dict with None when absent.

    The dictionary is modified in place and also returned.

    Args:
      data_dict - dictionary containing data for LFADS

    Returns:
      data_dict with some keys filled in, if they are absent.
    """
    optional_keys = ('train_truth', 'train_ext_input', 'valid_data',
                     'valid_truth', 'valid_ext_input', 'valid_train')
    for name in optional_keys:
        # Existing entries are kept as they are.
        data_dict.setdefault(name, None)
    return data_dict
def load_datasets(data_dir, data_filename_stem):
    """Load the datasets from a specified directory.

    Example files look like
      >data_dir/my_dataset_first_day
      >data_dir/my_dataset_second_day

    If my_dataset (filename) stem is in the directory, the read routine will
    try and load it. The datasets dictionary will then look like
      dataset['first_day'] -> (first day data dictionary)
      dataset['second_day'] -> (second day data dictionary)

    Args:
      data_dir: The directory from which to load the datasets.
      data_filename_stem: The stem of the filename for the datasets.

    Returns:
      datasets: a dataset dictionary, with one name->data dictionary pair
        for each dataset file.
    """
    print("Reading data from ", data_dir)
    datasets = utils.read_datasets(data_dir, data_filename_stem)
    for k, data_dict in datasets.items():
        data_dict = clean_data_dict(data_dict)
        datasets[k] = data_dict

        # A split that was not loaded is either absent or stored as None
        # (clean_data_dict fills missing entries with None), so len()
        # cannot be applied to it directly.
        train_data = data_dict.get('train_data')
        train_total_size = 0 if train_data is None else len(train_data)
        if train_total_size == 0:
            print("Did not load training set.")
        else:
            print("Found training set with number examples: ",
                  train_total_size)

        valid_data = data_dict.get('valid_data')
        valid_total_size = 0 if valid_data is None else len(valid_data)
        if valid_total_size == 0:
            print("Did not load validation set.")
        else:
            print("Found validation set with number examples: ",
                  valid_total_size)
    return datasets
def main(_):
    """Get this whole shindig off the ground.

    Collects the hyperparameters from FLAGS, loads the datasets, derives the
    dataset-dependent hyperparameters, and then runs the job selected by
    FLAGS.kind inside a TensorFlow session.

    Raises:
      ValueError: if FLAGS.kind is not a supported kind, or if no datasets
        were found.
    """
    d = build_hyperparameter_dict(FLAGS)
    hps = hps_dict_to_obj(d)  # hyper parameters
    kind = FLAGS.kind

    # Every supported kind needs the data.
    if kind not in ["train", "posterior_sample_and_average", "prior_sample",
                    "write_model_params"]:
        raise ValueError('Kind {} is not supported.'.format(kind))
    datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
    if not datasets:
        raise ValueError('No datasets found in {} with stem {}.'.format(
            hps.data_dir, hps.data_filename_stem))

    # infer the dataset names and dataset dimensions from the loaded files
    hps.kind = kind  # needs to be added here, cuz not saved as hyperparam
    hps.dataset_names = []
    hps.dataset_dims = {}
    for key in datasets:
        hps.dataset_names.append(key)
        hps.dataset_dims[key] = datasets[key]['data_dim']

    # also store down the dimensionality of the data
    # - just pull from one set, required to be same for all sets
    # (next(iter(...)) works on Python 2 and 3; values()[0] fails on 3.)
    hps.num_steps = next(iter(datasets.values()))['num_steps']
    hps.ndatasets = len(hps.dataset_names)

    if hps.num_steps_for_gen_ic > hps.num_steps:
        hps.num_steps_for_gen_ic = hps.num_steps

    # Build and run the model, for varying purposes.
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    if FLAGS.allow_gpu_growth:
        config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    with sess.as_default():
        with tf.device(hps.device):
            if kind == "train":
                train(hps, datasets)
            elif kind == "posterior_sample_and_average":
                write_model_runs(hps, datasets, hps.output_filename_stem)
            elif kind == "prior_sample":
                # Pass the stem by keyword: the third positional parameter
                # of write_model_samples is dataset_name, not output_fname.
                write_model_samples(hps, datasets,
                                    output_fname=hps.output_filename_stem)
            elif kind == "write_model_params":
                write_model_parameters(hps, hps.output_filename_stem,
                                       datasets)
            else:
                raise ValueError("Kind %s is not implemented. " % kind)


if __name__ == "__main__":
    tf.app.run()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf # used for flags here
from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
# Show images without interpolation between pixels.
matplotlib.rcParams['image.interpolation'] = 'nearest'
# Name of the output sub-directory; used below to build the save_dir default.
DATA_DIR = "rnn_synth_data_v1.0"

# Command-line flags for the chaotic-RNN synthetic data generator.
flags = tf.app.flags
# Where the data goes and what the files are called.
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data",
                    "Name of data file for input case.")
# Seed for the numpy RandomState created further down in this script.
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
# Trial length, trial count and network / sampling sizes.
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
# train_percentage is a fraction in [0, 1] (default 4/5), not a percentage.
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 40,
                     "Number of spikifications of the same underlying rates.")
# Parameters of the RNN dynamics.
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
                   "Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
# A value of 0.0 disables the input pulse (see the input_magnitude > 0.0 test
# in the dataset loop below).
flags.DEFINE_float("input_magnitude", 20.0,
                   "For the input case, what is the value of the input?")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
# Note that with N small (the default above is 50), the finite size effects
# will have pretty dramatic effects on the dynamics of the random RNN.
# If you want more complex dynamics, you'll have to run the script a
# lot, or increase N (or g).

# Getting hard vs. easy data can be a little stochastic, so we set the seed.

# Pull out some commonly used parameters.
# These are user parameters (configuration)
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
S = FLAGS.S
input_magnitude = FLAGS.input_magnitude
nspikifications = FLAGS.nspikifications
E = nspikifications * C         # total number of trials
# S is the number of measurements in each datasets, w/ each
# dataset having a different set of observations.
# Floor division: this file does not import `division` from __future__, so
# `N/S` was an int on Python 2 but is a float on Python 3, where it breaks
# the `range(ndatasets)` loop below. `//` rounds down on both.
ndatasets = N // S
train_percentage = FLAGS.train_percentage
ntime_steps = int(T / FLAGS.dt)
# End of user parameters

rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)

# Check to make sure the RNN is the one we used in the paper.
# NOTE(review): the indentation of this file was lost in the copy being
# reviewed; both checks are kept under the N == 50 guard -- confirm that the
# integrality check was not meant to run for every N.
if N == 50:
  assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
  rem_check = nspikifications * train_percentage
  # Compare against the nearest integer: int() truncates, so a product such
  # as 28.999999999999996 would be rejected even though it is integral up to
  # floating point error.
  assert abs(rem_check - round(rem_check)) < 1e-8, \
    'Train percentage * nspikifications should be integral number.'
# Draw the initial states and build the per-trial condition labels once,
# before the dataset loop, so every dataset shares the same conditions
# (similar to a neurophys setup).
condition_number = 0
x0s = []
condition_labels = []
for c in range(C):
  # One random N x 1 initial state per condition ...
  x0 = FLAGS.x0_std * rng.randn(N, 1)
  # ... copied into nspikifications columns, one per trial of the condition.
  x0s.append(np.tile(x0, nspikifications))
  # Every trial of this condition carries the same label.
  condition_labels.extend([condition_number] * nspikifications)
  condition_number += 1
# Stack the per-condition blocks side by side (one column per trial).
x0s = np.concatenate(x0s, axis=1)
# Maps dataset name -> dict of arrays and metadata for that dataset.
datasets = {}
for n in range(ndatasets):
  print(n+1, " of ", ndatasets)

  # Rates are generated for all trials first and spikified afterwards, so the
  # random state used for rate generation does not depend on the number of
  # spikifications.
  dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
  if S < N:
    # Several datasets sample the same network: tag each with its index.
    dataset_name += '_n' + str(n+1)

  # Sample neuron subsets. The assumption is the PC axes of the RNN
  # are not unit aligned, so sampling units is adequate to sample all
  # the high-variance PCs. Dataset n observes the S units starting at
  # column n*S: the S x N identity pattern shifted right by S columns n times.
  P_sxn = np.roll(np.eye(S, N), S * n, axis=1)

  input_times = None
  if input_magnitude > 0.0:
    # time of "hits" randomly chosen between [1/4 and 3/4] of total time
    input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)

  rates, x0s, inputs = generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
                                     input_magnitude=input_magnitude,
                                     input_times=input_times)
  spikes = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])

  # Split the data, inputs, labels and times into train vs. validation.
  train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                  nspikifications)
  rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
  spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds,
                                                  valid_inds)
  inputs_train, inputs_valid = split_list_by_inds(inputs, train_inds,
                                                  valid_inds)
  condition_labels_train, condition_labels_valid = split_list_by_inds(
      condition_labels, train_inds, valid_inds)
  input_times_train, input_times_valid = split_list_by_inds(
      input_times, train_inds, valid_inds)

  # Note that we put these 'truth' rates and input into this
  # structure, the only data that is used in LFADS are the spike
  # trains.  The rest is either for printing or posterity.
  # Rates, spikes and inputs are converted to numpy arrays here.
  datasets[dataset_name] = {
      'train_truth': nparray_and_transpose(rates_train),
      'valid_truth': nparray_and_transpose(rates_valid),
      'input_train_truth' : nparray_and_transpose(inputs_train),
      'input_valid_truth' : nparray_and_transpose(inputs_valid),
      'train_data' : nparray_and_transpose(spikes_train),
      'valid_data' : nparray_and_transpose(spikes_valid),
      'train_percentage' : train_percentage,
      'nspikifications' : nspikifications,
      'dt' : rnn['dt'],
      'input_magnitude' : input_magnitude,
      'input_times_train' : input_times_train,
      'input_times_valid' : input_times_valid,
      'P_sxn' : P_sxn,
      'condition_labels_train' : condition_labels_train,
      'condition_labels_valid' : condition_labels_valid,
      'conversion_factor': 1.0 / rnn['conversion_factor'],
  }

if S < N:
  # Note that this isn't necessary for this synthetic example, but
  # it's useful to see how the input factor matrices were initialized
  # for actual neurophysiology data.
  datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)

# Write out the datasets.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
import tensorflow as tf
from utils import write_datasets
from synthetic_data_utils import normalize_rates
from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
# Name of the output sub-directory; used below to build the save_dir default.
DATA_DIR = "rnn_synth_data_v1.0"

# Command-line flags for the integration-to-bound (itb) data generator.
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
                    "Name of data file for input case.")
# Seeds `rng`; `u_rng` further down uses this value + 1.
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
# train_percentage is a fraction in [0, 1] (default 4/5), not a percentage.
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 5,
                     "Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0,
                   "Map 1.0 of RNN to a spikes per second")
flags.DEFINE_float("u_std", 0.25,
                   "Std dev of input to integration to bound model")
# The literal value "SAMPLE_CHECKPOINT" is a sentinel: the script then loads
# trained_itb/model-65000 from the directory containing this file.
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
                    """Path to directory with checkpoints of model
trained on integration to bound task. Currently this
is a placeholder which tells the code to grab the
checkpoint that is provided with the code
(in /trained_itb/..). If you have your own checkpoint
you would like to restore, you would point it to
that path.""")
FLAGS = flags.FLAGS
class IntegrationToBoundModel:
  """Single-layer tanh RNN driven by a scalar input, with a linear readout.

  NOTE(review): the variables are created without explicit names and this
  script restores them through a default tf.train.Saver(), so the checkpoint
  presumably keys on TensorFlow's auto-generated variable names. Keep the
  creation order in __init__ unchanged unless that is confirmed otherwise.
  """

  def __init__(self, N):
    """Create the trainable variables for a network with N hidden units."""
    # Std dev of the random-normal initial weights: 0.8 / sqrt(N).
    scale = 0.8 / float(N**0.5)
    self.N = N
    # Recurrent weights (N x N), random-normal initialized.
    self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
    # Recurrent bias (1 x N), zero initialized.
    self.b_1xn = tf.Variable(tf.zeros([1, N]))
    # Input weights (1 x N), zero initialized.
    self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
    # Readout weights (N x 1) and readout bias (1,).
    self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
    self.bro_o = tf.Variable(tf.zeros([1]))

  def call(self, h_tm1_bxn, u_bx1):
    """Advance the network by one time step.

    Args:
      h_tm1_bxn: previous hidden state (batch x N, going by the name and the
        matmul with Wh_nxn).
      u_bx1: input for this step (batch x 1, going by the name).

    Returns:
      A 2-tuple (z_t, h_t_bxn): the linear readout of the new hidden state,
      and the new hidden state itself.
    """
    # Pre-activation: recurrent drive + bias + input scaled by input weights.
    act_t_bxn = tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn + u_bx1 * self.Bu_1xn
    h_t_bxn = tf.nn.tanh(act_t_bxn)
    # Readout: h_t * Wro + bro.
    z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
    return z_t, h_t_bxn
def get_data_batch(batch_size, T, rng, u_std):
  """Draw random inputs and the matching integrate-to-bound targets.

  Args:
    batch_size: number of trials to generate.
    T: number of time steps per trial.
    rng: numpy RandomState used to draw the inputs.
    u_std: standard deviation of the Gaussian input at each time step.

  Returns:
    A 2-tuple (u_bxt, labels_bxt) of batch_size x T arrays: the inputs, and
    their running sum over time clipped to [-1, 1].
  """
  u_bxt = rng.randn(batch_size, T) * u_std
  running_sum_b = np.zeros([batch_size])
  labels_bxt = np.zeros([batch_size, T])
  # range, not xrange: xrange does not exist on Python 3.
  for t in range(T):
    running_sum_b += u_bxt[:, t]
    labels_bxt[:, t] += running_sum_b
  # The bound: the integrated input saturates at +/-1.
  labels_bxt = np.clip(labels_bxt, -1, 1)
  return u_bxt, labels_bxt
# Two separate random streams: `rng` is used for the projection matrix below
# and for spikification later on; `u_rng` is used only for drawing the inputs
# passed to get_data_batch.
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N # must be same N as in trained model (provided example is N = 50)
nspikifications = FLAGS.nspikifications
E = nspikifications * C # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
batch_size = 1 # gives one example per ntrial

# Build the model variables first, then the Saver, so that the Saver covers
# them when the checkpoint is restored in the session below.
model = IntegrationToBoundModel(N)
# One scalar-input placeholder per time step of the unrolled network.
inputs_ph_t = [tf.placeholder(tf.float32,
                              shape=[None, 1]) for _ in range(ntimesteps)]
# The hidden state starts at zero.
state = tf.zeros([batch_size, N])
saver = tf.train.Saver()

P_nxn = rng.randn(N,N) / np.sqrt(N) # random projections

# unroll RNN for T timesteps
outputs_t = []
states_t = []

for inp in inputs_ph_t:
  # `state` is threaded through: each step consumes the previous step's state.
  output, state = model.call(state, inp)
  outputs_t.append(output)
  states_t.append(state)
with tf.Session() as sess:
  # restore the latest model ckpt
  if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
    # Sentinel value: use the checkpoint shipped next to this script.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
  else:
    model_checkpoint_path = FLAGS.checkpoint_path
  try:
    saver.restore(sess, model_checkpoint_path)
  except Exception as err:
    # Raise explicitly instead of `assert False` inside a bare `except:`. The
    # assert is stripped under `python -O` (the script would then carry on
    # with unrestored variables), and the bare except also swallowed
    # KeyboardInterrupt / SystemExit.
    raise RuntimeError(
        "No checkpoints to restore from, is the path %s correct? (%s)"
        % (model_checkpoint_path, err))
  print ('Model restored from', model_checkpoint_path)

  # generate data for trials
  data_e = []
  u_e = []
  outs_e = []
  for c in range(C):
    u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng, FLAGS.u_std)

    # Feed each time step's input column to its placeholder.
    # (range, not xrange: xrange does not exist on Python 3.)
    feed_dict = {}
    for t in range(ntimesteps):
      feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:,t], (batch_size,-1))

    states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
                                           feed_dict=feed_dict)
    # Stack the per-step results, drop the singleton axes and put units
    # first, giving one column of hidden state per time step.
    states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
    outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
    # Mix the hidden units through the fixed random projection.
    r_sxt = np.dot(P_nxn, states_nxt)

    # Each condition contributes nspikifications trials that share the same
    # underlying arrays (the same objects are appended repeatedly).
    for s in range(nspikifications):
      data_e.append(r_sxt)
      u_e.append(u_1xt)
      outs_e.append(outputs_t_bxn)
# Normalize the projected states into rates and sample spikes from them.
truth_data_e = normalize_rates(data_e, E, N)
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)

# The same train / validation index split is applied to every per-trial list.
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nspikifications)
truth_train, truth_valid = split_list_by_inds(truth_data_e, train_inds,
                                              valid_inds)
spiking_train, spiking_valid = split_list_by_inds(spiking_data_e, train_inds,
                                                  valid_inds)
# save down the inputs used to generate this data
u_train, u_valid = split_list_by_inds(u_e, train_inds, valid_inds)
# save down the network outputs (may be useful later)
outs_train, outs_valid = split_list_by_inds(outs_e, train_inds, valid_inds)

# Rates, spikes and inputs are converted with nparray_and_transpose; the
# network outputs are only converted to arrays.
data = {
    'train_truth': nparray_and_transpose(truth_train),
    'valid_truth': nparray_and_transpose(truth_valid),
    'train_data' : nparray_and_transpose(spiking_train),
    'valid_data' : nparray_and_transpose(spiking_valid),
    'train_percentage' : train_percentage,
    'nspikifications' : nspikifications,
    'dt' : FLAGS.dt,
    'u_std' : FLAGS.u_std,
    'max_firing_rate': FLAGS.max_firing_rate,
    'train_inputs_u': nparray_and_transpose(u_train),
    'valid_inputs_u': nparray_and_transpose(u_valid),
    'train_outputs_u': np.array(outs_train),
    'valid_outputs_u': np.array(outs_valid),
    'conversion_factor' : FLAGS.max_firing_rate/(1.0/FLAGS.dt),
}

# just one dataset here
dataset_name = 'dataset_N' + str(N)
datasets = {dataset_name: data}

# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
saved_path = os.path.join(FLAGS.save_dir,
                          FLAGS.datafile_name + '_' + dataset_name)
print ('Saved to ', saved_path)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import os
import h5py
import numpy as np
from synthetic_data_utils import generate_data, generate_rnn
from synthetic_data_utils import get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds
import tensorflow as tf
from utils import write_datasets
# Name of the output sub-directory; used below to build the save_dir default.
DATA_DIR = "rnn_synth_data_v1.0"

# Command-line flags for the labeled (two-RNN) synthetic data generator.
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "conditioned_rnn_data",
                    "Name of data file for input case.")
# Seeds `rng`; the two RNN generators further down use this value +1 and +2.
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 400, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
# train_percentage is a fraction in [0, 1] (default 4/5), not a percentage.
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 10,
                     "Number of spikifications of the same underlying rates.")
# Parameters of the RNN dynamics.
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
                   "Volume from which to pull initial conditions (affects diversity of dynamics.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
# `rng` drives everything except the two networks themselves, which each get
# their own generator (seed + 1 and seed + 2).
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
            np.random.RandomState(seed=FLAGS.synth_data_seed+2)]
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
nspikifications = FLAGS.nspikifications
E = nspikifications * C  # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)

# Two networks with identical hyperparameters but different random weights.
rnns = [generate_rnn(rnn_rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt,
                     FLAGS.max_firing_rate)
        for rnn_rng in rnn_rngs]
rnn_a, rnn_b = rnns

# pick which RNN is used on each trial
rnn_to_use = rng.randint(2, size=E)
# Broadcast the per-trial choice over time and add a trailing singleton axis,
# giving an E x ntimesteps x 1 array; these are "a's" in the paper.
ext_input = np.expand_dims(
    np.repeat(np.expand_dims(rnn_to_use, axis=1), ntimesteps, axis=1),
    axis=2)

# One random initial state per condition, shared by all of its trials, plus
# the matching per-trial condition labels.
x0s = []
condition_labels = []
condition_number = 0
for c in range(C):
  x0 = FLAGS.x0_std * rng.randn(N, 1)
  x0s.append(np.tile(x0, nspikifications))
  condition_labels.extend([condition_number] * nspikifications)
  condition_number += 1
x0s = np.concatenate(x0s, axis=1)

# Random N x N projection applied to the rates of both networks.
P_nxn = rng.randn(N, N) / np.sqrt(N)
# generate trials for both RNNs
# Every trial is simulated with both networks from the same initial states;
# the per-trial choice in rnn_to_use then selects which result is kept.
rates_a, x0s_a, _ = generate_data(rnn_a, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
                                  input_magnitude=0.0, input_times=None)
spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])
rates_b, x0s_b, _ = generate_data(rnn_b, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
                                  input_magnitude=0.0, input_times=None)
spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])

# not the best way to do this but E is small enough
# (range, not xrange: xrange does not exist on Python 3.)
rates = []
spikes = []
for trial in range(E):
  if rnn_to_use[trial] == 0:
    rates.append(rates_a[trial])
    spikes.append(spikes_a[trial])
  else:
    rates.append(rates_b[trial])
    spikes.append(spikes_b[trial])
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nspikifications)

# Apply the same index split to rates, spikes, labels and external inputs.
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
    condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
    ext_input, train_inds, valid_inds)

# Rates and spikes are converted with nparray_and_transpose; the external
# inputs are only converted to arrays.
data = {
    'train_truth': nparray_and_transpose(rates_train),
    'valid_truth': nparray_and_transpose(rates_valid),
    'train_data' : nparray_and_transpose(spikes_train),
    'valid_data' : nparray_and_transpose(spikes_valid),
    'train_ext_input' : np.array(ext_input_train),
    'valid_ext_input': np.array(ext_input_valid),
    'train_percentage' : train_percentage,
    'nspikifications' : nspikifications,
    'dt' : FLAGS.dt,
    'P_sxn' : P_nxn,
    'condition_labels_train' : condition_labels_train,
    'condition_labels_valid' : condition_labels_valid,
    'conversion_factor': 1.0 / rnn_a['conversion_factor'],
}

# just one dataset here
dataset_name = 'dataset_N' + str(N)
datasets = {dataset_name: data}

# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
saved_path = os.path.join(FLAGS.save_dir,
                          FLAGS.datafile_name + '_' + dataset_name)
print ('Saved to ', saved_path)
#!/bin/bash
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
# Output directory shared by every generator invocation below.
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/

# "$SYNTH_PATH" is quoted everywhere so that a path containing spaces or glob
# characters is passed to python as a single, literal argument.

echo "Generating chaotic rnn data with no input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir="$SYNTH_PATH" \
  --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 \
  --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 \
  --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 \
  --max_firing_rate=30.0

echo "Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir="$SYNTH_PATH" \
  --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 \
  --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 \
  --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 \
  --max_firing_rate=30.0

echo "Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py --save_dir="$SYNTH_PATH" \
  --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 \
  --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 \
  --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 \
  --max_firing_rate=30.0

echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_chaotic_rnn_data.py --save_dir="$SYNTH_PATH" \
  --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 \
  --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 \
  --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 \
  --input_magnitude=0.0 --max_firing_rate=30.0

echo "Generating Integration-to-bound RNN data"
python generate_itb_data.py --save_dir="$SYNTH_PATH" \
  --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT \
  --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 \
  --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0

echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir="$SYNTH_PATH" \
  --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 \
  --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 \
  --tau=0.025 --dt=0.01 --max_firing_rate=30.0
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import h5py
import numpy as np
import os
from utils import write_datasets
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
  """Build the parameter dictionary of a vanilla RNN used to generate
  chaotic data.

  Args:
    rng: numpy random number generator
    N: number of hidden units
    g: scaling of recurrent weight matrix in g W, with W ~ N(0,1/N)
    tau: time scale of individual unit dynamics
    dt: time step for equation updates
    max_firing_rate: how to rescale the -1,1 firing rates

  Returns:
    A dict holding the arguments above together with the random weights
    ('W', 'Bin', 'Bin2'), a zero bias 'b', and 'conversion_factor'.
  """
  # The draws happen in the order W, Bin, Bin2, so a given rng state always
  # yields the same network.
  recurrent_nxn = rng.randn(N, N) / np.sqrt(N)
  input_n = rng.randn(N) / np.sqrt(1.0)
  input2_n = rng.randn(N) / np.sqrt(1.0)

  bins_per_sec = 1.0 / dt  # bins / sec
  return {
      'N': N,
      'W': recurrent_nxn,
      'Bin': input_n,
      'Bin2': input2_n,
      'b': np.zeros(N),
      'g': g,
      'tau': tau,
      'dt': dt,
      'max_firing_rate': max_firing_rate,  # spikes / sec
      # Used for plotting in LFADS: spikes / bin.
      'conversion_factor': max_firing_rate / bins_per_sec,
  }
def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
                  input_times=None):
  """Run a randomly initialized RNN for E trials and return normalized rates.

  Args:
    rnn: the rnn parameter dictionary (see generate_rnn).
    T: time in seconds to run (divided by rnn['dt'] to get steps, rounded
      down).
    E: total number of examples (trials).
    x0s: N x E array of initial states, one column per trial. Required; the
      None default is kept only so existing keyword calls keep working.
    P_sxn (optional): S x N matrix applied to the N unit rates of every
      trial. Defaults to the N x N identity.
    input_magnitude (optional): value of the input pulse delivered at the
      trial's input time.
    input_times (optional): per-trial time step index at which the pulse is
      delivered; None means no pulse on any trial.

  Returns:
    A 3-tuple (data_e, x0s, inputs_e): data_e is a list of length E of S x T
    arrays of projected rates after normalize_rates, x0s is the argument
    passed in, and inputs_e is a list of length E of 1 x T input arrays.

  Raises:
    ValueError: if x0s is None.
  """
  if x0s is None:
    # Previously this surfaced as an opaque TypeError on x0s[:, e].
    raise ValueError('x0s (an N x E array of initial states) is required.')

  N = rnn['N']
  if P_sxn is None:
    P_sxn = np.eye(N)
  ntime_steps = int(T / rnn['dt'])

  # Discretized dynamics, computed once rather than once per trial:
  #   x_t = (1 - dt/tau) x_{t-1} + (dt/tau) (g W r_{t-1} + b [+ Bin u_t])
  tau = rnn['tau']
  dt = rnn['dt']
  alpha = (1.0-dt/tau)
  W = dt/tau*rnn['W']*rnn['g']
  Bin = dt/tau*rnn['Bin']
  b = dt/tau*rnn['b']

  def run_rnn(x0, input_time=None):
    """Simulate one trial; returns (N x T rates, 1 x T inputs)."""
    rs = np.zeros([N,ntime_steps])
    us = np.zeros([1, ntime_steps])
    x_tm1 = x0
    r_tm1 = np.tanh(x0)
    for t in range(ntime_steps):
      x_t = alpha*x_tm1 + np.dot(W,r_tm1) + b
      if input_time is not None and t == input_time:
        # The input is zero at every other step, so only this step is kicked.
        us[0,t] = input_magnitude
        x_t += Bin * us[0,t] # DCS is this what was used?
      r_t = np.tanh(x_t)
      x_tm1 = x_t
      r_tm1 = r_t
      rs[:,t] = r_t
    return rs, us

  data_e = []
  inputs_e = []
  for e in range(E):
    input_time = input_times[e] if input_times is not None else None
    r_nxt, u_uxt = run_rnn(x0s[:,e], input_time)
    data_e.append(np.dot(P_sxn, r_nxt))
    inputs_e.append(u_uxt)

  S = P_sxn.shape[0]
  data_e = normalize_rates(data_e, E, S)

  return data_e, x0s, inputs_e
def normalize_rates(data_e, E, S):
  """Rescale each channel of each trial to span [0, 1], in place.

  Normalization is by the min and max of each channel within each trial
  (made more complex because of the P matrices). This will cause offset
  differences between identical rnn runs with different t hits.

  Args:
    data_e: sequence of at least E 2-D arrays (channels x time); the arrays
      are modified in place.
    E: number of trials to normalize.
    S: number of channels (leading rows) to normalize in each trial.

  Returns:
    data_e, the same object that was passed in.
  """
  for trial in range(E):
    rates_sxt = data_e[trial]
    for ch in range(S):
      row = rates_sxt[ch,:]
      lo = np.min(row)
      hi = np.max(row)
      # A constant channel cannot be rescaled.
      assert hi - lo != 0, 'Something wrong'
      rates_sxt[ch,:] = (row - lo)/(hi-lo)
    data_e[trial] = rates_sxt
  return data_e
def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
  """Apply spikes to a continuous dataset whose values are between 0.0 and 1.0

  Args:
    data_e: nexamples length list of NxT trials
    rng: numpy random number generator used to draw the spike counts
    dt: how often the data are sampled
    max_firing_rate: the firing rate that is associated with a value of 1.0

  Returns:
    spikes_e: a list with one integer NxT array per trial, holding spike
      counts sampled from a Poisson distribution whose mean in each bin is
      value * max_firing_rate * dt.
  """
  spikes_e = []
  for data in data_e:
    N,T = data.shape
    # Builtin int: the np.int alias was deprecated and later removed from
    # NumPy, and it only ever meant the builtin int.
    data_s = np.zeros([N,T], dtype=int)
    # Counts are drawn one channel at a time, in order, so a given rng state
    # always yields the same spikes.
    for n in range(N):
      f = data[n,:]
      data_s[n,:] = rng.poisson(f*max_firing_rate*dt, size=T)
    spikes_e.append(data_s)

  return spikes_e
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
  """Partition the trial indices 0..num_trials-1 into training and
  validation indices according to the train fraction.

  Args:
    num_trials: the number of trials
    train_fraction: (e.g. .80)
    nspikifications: the number of spiking trials per initial condition

  Returns:
    a 2-tuple of two lists: the training indices and validation indices
  """
  # Trials come in consecutive groups of nspikifications per initial
  # condition. Within every group the leading trials go to training and the
  # rest to validation, so both splits contain spikifications of every
  # condition.
  ntrain_per_condition = train_fraction * nspikifications
  train_inds, valid_inds = [], []
  for trial in range(num_trials):
    position_in_group = (trial % nspikifications) + 1
    if position_in_group > ntrain_per_condition:
      target = valid_inds
    else:
      target = train_inds
    target.append(trial)
  return train_inds, valid_inds
def split_list_by_inds(data, inds1, inds2):
  """Pick two sub-lists out of data, one per list of indices.

  Args:
    data: the list of data to split
    inds1, the first list of indices
    inds2, the second list of indices

  Returns: a 2-tuple of two lists. Both are empty when data is None or has
    length zero.
  """
  # len() rather than truthiness, so numpy arrays are accepted as data too.
  if data is None or len(data) == 0:
    return [], []

  def pick(inds):
    return [data[i] for i in inds]

  return pick(inds1), pick(inds2)
def nparray_and_transpose(data_a_b_c):
  """Stack the items of data into one numpy array and swap its last two axes.

  Args:
    data_a_b_c: a nested, nested list of length a, with sublist length
      b, with sublist length c.

  Returns:
    a numpy 3-tensor with dimensions a x c x b
  """
  stacked_axbxc = np.array(list(data_a_b_c))
  return stacked_axbxc.transpose(0, 2, 1)
def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
  """Create a matrix that aligns the datasets a bit, under
  the assumption that each dataset is observing the same underlying dynamical
  system.

  The training spikes of every dataset are averaged within condition, placed
  into one big (all channels) x (time * conditions) matrix, lightly smoothed
  and reduced with PCA. Each dataset's channels are then regressed onto the
  top npcs components.

  Args:
    datasets: The dictionary of dataset structures. Each one must provide
      'P_sxn', 'condition_labels_train' and 'train_data'. 'train_data' is
      indexed as trials x time x channels here.
    npcs: The number of pcs for each, basically like lfads factors.
    ntime (optional): Number of time steps to take in each sample. Defaults
      to the time dimension of the last dataset's 'train_data'.
    nsamples (optional): Accepted for backward compatibility; it is not used
      by the computation.

  Returns:
    The dataset structures, with the field alignment_matrix_cxf added.
    This is # channels x npcs dimension
  """
  nchannels_all = 0
  channel_idxs = {}
  conditions_all = {}
  for name, dataset in datasets.items():
    cidxs = np.where(dataset['P_sxn'])[1] # non-zero entries in columns
    # NOTE(review): only the first and last observed column are kept, which
    # presumably assumes each dataset observes one contiguous block of units.
    channel_idxs[name] = [cidxs[0], cidxs[-1]+1]
    nchannels_all += cidxs[-1]+1 - cidxs[0]
    conditions_all[name] = np.unique(dataset['condition_labels_train'])

  # list(...) is required on Python 3: dict.values() is a view there, and
  # np.array(view) is a 0-d object array rather than the stacked labels, so
  # no condition ever matched and every alignment matrix came out all zero.
  all_conditions_list = \
      np.unique(np.ndarray.flatten(np.array(list(conditions_all.values()))))
  nconditions_all = all_conditions_list.shape[0]

  if ntime is None:
    # `dataset` is the last one visited by the loop above.
    ntime = dataset['train_data'].shape[1]

  # In the data workup in the paper, Chethan did intra condition
  # averaging, so let's do that here.
  avg_data_all = {}
  for name, conditions in conditions_all.items():
    dataset = datasets[name]
    labels = np.array(dataset['condition_labels_train'])
    avg_data_all[name] = {}
    for cname in conditions:
      # Average all training trials of this condition: time x channels.
      trials = dataset['train_data'][labels == cname]
      avg_data_all[name][cname] = np.mean(trials, axis=0)

  # Lay the averages out as channels x (time * conditions), with one block
  # of ntime columns per condition.
  all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
  for name, dataset in datasets.items():
    cidx_s, cidx_f = channel_idxs[name]
    for cname in conditions_all[name]:
      cidxs = np.argwhere(all_conditions_list == cname)
      if cidxs.shape[0] > 0:
        t_start = cidxs[0][0] * ntime
        all_data_nxtc[cidx_s:cidx_f, t_start:t_start+ntime] = \
            avg_data_all[name][cname].T

  # A bit of filtering. We don't care about spectral properties, or
  # filtering artifacts, simply correlate time steps a bit.
  filt_len = 6
  bc_filt = np.ones([filt_len])/float(filt_len)
  for c in range(nchannels_all):
    all_data_nxtc[c,:] = scipy.signal.filtfilt(bc_filt, [1.0], all_data_nxtc[c,:])

  # Compute the PCs.
  all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
  all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
  corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
  evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
  sidxs = np.flipud(np.argsort(evals_n)) # sort such that 0th is highest
  evals_n = evals_n[sidxs]
  evecs_nxn = evecs_nxn[:,sidxs]

  # Project all the channels data onto the low-D PCA basis, where
  # low-d is the npcs parameter.
  all_data_pca_pxtc = np.dot(evecs_nxn[:, 0:npcs].T, all_data_zm_nxtc)

  # Now for each dataset, we regress the channel data onto the top
  # pcs, and this will be our alignment matrix for that dataset.
  # |B - A*W|^2
  for name, dataset in datasets.items():
    cidx_s, cidx_f = channel_idxs[name]
    all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f,:] # ch for channel
    # rcond=-1 pins the cut-off this code was written against (machine
    # precision), so the result does not depend on which default the
    # installed NumPy happens to use.
    W_chxp, _, _, _ = \
        np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T, rcond=-1)
    dataset['alignment_matrix_cxf'] = W_chxp

  # Developer-only diagnostics; flip the flag by hand to inspect the fit of
  # the last dataset processed above.
  do_debug_plot = False
  if do_debug_plot:
    pc_vecs = evecs_nxn[:,0:npcs]
    ntoplot = 400

    plt.figure()
    plt.plot(np.log10(evals_n), '-x')
    plt.figure()
    plt.subplot(311)
    plt.imshow(all_data_pca_pxtc)
    plt.colorbar()

    plt.subplot(312)
    plt.imshow(np.dot(W_chxp.T, all_data_zm_chxtc))
    plt.colorbar()

    plt.subplot(313)
    plt.imshow(np.dot(all_data_zm_chxtc.T, W_chxp).T - all_data_pca_pxtc)
    plt.colorbar()

    import pdb
    pdb.set_trace()

  return datasets
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import print_function
import os
import h5py
import json
import numpy as np
import tensorflow as tf
def log_sum_exp(x_k):
  r"""Computes log \sum exp in a numerically stable way.

  The maximum is factored out before exponentiating, so the largest value
  passed to exp() is 0:

    log ( sum_i exp(x_i) )
    = log ( sum_i exp(x_i - m + m) ),       with m = max(x_i)
    = log ( sum_i exp(x_i - m) * exp(m) )
    = log ( sum_i exp(x_i - m) ) + m

  Both the max and the sum are taken without an axis argument, i.e. over
  every element of x_k.

  Args:
    x_k: k-dimensional list of arguments to log_sum_exp.

  Returns:
    log_sum_exp of the arguments.
  """
  m = tf.reduce_max(x_k)
  x1_k = x_k - m  # shifted so that the largest entry is exactly 0
  u_k = tf.exp(x1_k)
  z = tf.reduce_sum(u_k)
  return tf.log(z) + m
def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
           normalized=False, name=None, collections=None):
  r"""Linear (affine) transformation, y = x W + b, for a variety of
  configurations.

  Args:
    x: The input tensor to transform. Its dimension 1 is read as the input
      size.
    out_size: The integer size of non-batch output dimension.
    do_bias (optional): Add a learnable bias vector to the operation.
    alpha (optional): A multiplicative scaling for the weight initialization
      of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
    identity_if_possible (optional): just return identity,
      if x.shape[1] == out_size.
    normalized (optional): Forwarded to init_linear, which L2-normalizes W
      when set.
    name (optional): The name prefix to add to variables.
    collections (optional): List of additional collections. (Placed in
      tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)

  Returns:
    In the equation, y = x W + b, returns the tensorflow op that yields y.
  """
  in_size = int(x.get_shape()[1])  # from Dimension(10) -> 10

  if identity_if_possible and in_size == out_size:
    # Sometimes linear layers are nothing more than size adapters.
    wname = (name + "/W") if name else "/W"
    return tf.identity(x, name=(wname+'_ident'))

  # Weight initialization (including the alpha scaling) is handled entirely
  # by init_linear, so no initializer is built here.
  W, b = init_linear(in_size, out_size, do_bias=do_bias, alpha=alpha,
                     normalized=normalized, name=name, collections=collections)

  if do_bias:
    return tf.matmul(x, W) + b
  return tf.matmul(x, W)
def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
                identity_if_possible=False, normalized=False,
                name=None, collections=None):
  r"""Create the parameters (W, b) of a linear (affine) transformation,
  y = x W + b, for a variety of configurations.

  Args:
    in_size: The integer size of the non-batch input dimension. [(x),y]
    out_size: The integer size of non-batch output dimension. [x,(y)]
    do_bias (optional): Add a learnable bias vector to the operation.
    mat_init_value (optional): numpy constant for matrix initialization, if
      None, do random, with additional parameters.
    alpha (optional): A multiplicative scaling for the weight initialization
      of the matrix, in the form \alpha * 1/\sqrt{in_size}.
    identity_if_possible (optional): if in_size == out_size, skip variable
      creation and return a constant identity matrix together with a zero
      vector of length in_size (regardless of do_bias).
    normalized (optional): L2-normalize W along dimension 0
      (tf.nn.l2_normalize(w, dim=0)) and additionally register the variable
      in the "norm-variables" collection.
    name (optional): The name prefix to add to variables.
    collections (optional): List of additional collections. (Placed in
      tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)

  Returns:
    In the equation, y = x W + b, returns the pair (W, b). b is None when
    do_bias is False (except in the identity case described above).

  Raises:
    ValueError: if mat_init_value is given and its shape is not
      (in_size, out_size).
  """
  if mat_init_value is not None and mat_init_value.shape != (in_size, out_size):
    raise ValueError(
        'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))

  if identity_if_possible and in_size == out_size:
    return (tf.constant(np.eye(in_size).astype(np.float32)),
            tf.zeros(in_size))

  wname = (name + "/W") if name else "/W"

  # The two configurations differ only in the extra "norm-variables"
  # collection and the final normalization, so the variable is built once.
  w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
  if normalized:
    w_collections.append("norm-variables")
  if collections:
    w_collections += collections

  # tf.Variable is used when an explicit initial value is supplied;
  # otherwise tf.get_variable is used with a random-normal initializer.
  if mat_init_value is not None:
    w = tf.Variable(mat_init_value, name=wname, collections=w_collections)
  else:
    stddev = alpha/np.sqrt(float(in_size))
    mat_init = tf.random_normal_initializer(0.0, stddev)
    w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
                        collections=w_collections)

  if normalized:
    w = tf.nn.l2_normalize(w, dim=0)  # x W, so xW_j = \sum_i x_bi W_ij

  b = None
  if do_bias:
    b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
    if collections:
      b_collections += collections
    bname = (name + "/b") if name else "/b"
    b = tf.get_variable(bname, [1, out_size],
                        initializer=tf.zeros_initializer(),
                        collections=b_collections)

  return (w, b)
def write_data(data_fname, data_dict, use_json=False, compression=None):
  """Write data in HDF5 format (or JSON, if requested).

  The parent directory of data_fname is created when it does not exist.

  Args:
    data_fname: The filename of the file in which to write the data.
    data_dict: The dictionary of data to write. The keys are strings
      and the values are numpy arrays.
    use_json (optional): human readable format for simple items
    compression (optional): The compression to use for h5py (disabled by
      default because the library borks on scalars, otherwise try 'gzip').

  Raises:
    IOError: re-raised, after printing a message, if an IOError occurs while
      writing the HDF5 file.
  """
  dir_name = os.path.dirname(data_fname)
  # dirname is '' for a bare filename (current directory); os.makedirs('')
  # would raise, so only create a directory when there is one to create.
  if dir_name and not os.path.exists(dir_name):
    os.makedirs(dir_name)

  if use_json:
    # The context manager closes the file even if json.dump raises.
    with open(data_fname, 'w') as the_file:
      json.dump(data_dict, the_file)
    return

  try:
    with h5py.File(data_fname, 'w') as hf:
      for k, v in data_dict.items():
        # '/' would be read by h5py as a group separator, so flatten it.
        clean_k = k.replace('/', '_')
        if clean_k != k:
          print('Warning: saving variable with name: ', k, ' as ', clean_k)
        else:
          print('Saving variable with name: ', clean_k)
        hf.create_dataset(clean_k, data=v, compression=compression)
  except IOError:
    print("Cannot open %s for writing." % data_fname)
    raise
def read_data(data_fname):
  """Load every top-level item of an HDF5 file into memory.

  Args:
    data_fname: The filename of the file from which to read the data.

  Returns:
    A dictionary mapping each top-level name in the file to a numpy array
    copy of its contents. The keys vary by dataset (but should always
    include 'train_data' and 'valid_data').

  Raises:
    IOError: re-raised, after printing a message, if the file cannot be
      read.
  """
  try:
    loaded = {}
    with h5py.File(data_fname, 'r') as hf:
      # Convert while the file is still open; the arrays outlive the handle.
      for key, value in hf.items():
        loaded[key] = np.array(value)
    return loaded
  except IOError:
    print("Cannot open %s for reading." % data_fname)
    raise
def write_datasets(data_path, data_fname_stem, dataset_dict, compression=None):
  """Write several datasets, one HDF5 file per dataset.

  dataset_dict is expected to map a dataset name (string) to a data
  dictionary. write_data is called once per entry, with the file name
  formed as <data_path>/<data_fname_stem>_<dataset name>.

  Args:
    data_path: The path to the save directory.
    data_fname_stem: The filename stem of the file in which to write the data.
    dataset_dict: The dictionary of datasets. The keys are strings
      and the values data dictionaries (str -> numpy arrays) associations.
    compression (optional): The compression to use for h5py (disabled by
      default because the library borks on scalars, otherwise try 'gzip').
  """
  stem_path = os.path.join(data_path, data_fname_stem)
  for dataset_name, dataset in dataset_dict.items():
    out_fname = stem_path + "_" + dataset_name
    write_data(out_fname, dataset, compression=compression)
def read_datasets(data_path, data_fname_stem):
  """Read all datasets in a directory whose file names start with a stem.

  Every file in data_path whose name begins with data_fname_stem is loaded
  with read_data. The dataset key is the part of the file name that follows
  the stem and one separator character. Two derived entries are added to
  each data dictionary: 'data_dim' (train_data.shape[2]) and 'num_steps'
  (train_data.shape[1]).

  Args:
    data_path: The path to the directory holding the data files.
    data_fname_stem: The filename stem of the files to read.

  Returns:
    A dictionary mapping dataset key -> data dictionary.

  Raises:
    ValueError: if no file in data_path matches the stem.
  """
  candidates = os.listdir(data_path)

  print ('loading data from ' + data_path + ' with stem ' + data_fname_stem)
  key_start = len(data_fname_stem) + 1  # skip the stem and its separator
  loaded = {}
  for fname in candidates:
    if not fname.startswith(data_fname_stem):
      continue
    data_dict = read_data(os.path.join(data_path, fname))
    train_shape = data_dict['train_data'].shape
    data_dict['data_dim'] = train_shape[2]
    data_dict['num_steps'] = train_shape[1]
    loaded[fname[key_start:]] = data_dict

  if not loaded:
    raise ValueError("Failed to load any datasets, are you sure that the "
                     "'--data_dir' and '--data_filename_stem' flag values "
                     "are correct?")

  print (str(len(loaded)) + ' datasets loaded')
  return loaded
# NUMPY utility functions
def list_t_bxn_to_list_b_txn(values_t_bxn):
  """Regroup per-timestep batches into per-example time series.

  Args:
    values_t_bxn: The length T list of BxN numpy tensors.

  Returns:
    The length B list of TxN numpy tensors. Each output is allocated with
    np.zeros, so it has numpy's default float dtype.
  """
  num_steps = len(values_t_bxn)
  batch_size, dim = values_t_bxn[0].shape

  values_b_txn = []
  for b in range(batch_size):
    series_txn = np.zeros([num_steps, dim])
    for t, step_bxn in enumerate(values_t_bxn):
      series_txn[t, :] = step_bxn[b, :]
    values_b_txn.append(series_txn)
  return values_b_txn
def list_t_bxn_to_tensor_bxtxn(values_t_bxn):
  """Stack a length T list of BxN numpy tensors into one BxTxN tensor.

  Args:
    values_t_bxn: The length T list of BxN numpy tensors.

  Returns:
    values_bxtxn: The BxTxN numpy tensor (allocated with np.zeros, so it has
      numpy's default float dtype).
  """
  num_steps = len(values_t_bxn)
  batch_size, dim = values_t_bxn[0].shape

  values_bxtxn = np.zeros([batch_size, num_steps, dim])
  for t, step_bxn in enumerate(values_t_bxn):
    values_bxtxn[:, t, :] = step_bxn
  return values_bxtxn
def tensor_bxtxn_to_list_t_bxn(tensor_bxtxn):
  """Split a BxTxN numpy tensor into a length T list of per-timestep slices.

  Each slice tensor_bxtxn[:, t, :] is passed through np.squeeze, so any
  size-1 axis is dropped (e.g. with B == 1 the entries have shape (N,)
  rather than (1, N)).

  Args:
    tensor_bxtxn: The BxTxN numpy tensor.

  Returns:
    A length T list of numpy tensors with shape BxN (before squeezing).
  """
  _, num_steps, _ = tensor_bxtxn.shape
  return [np.squeeze(tensor_bxtxn[:, t, :]) for t in range(num_steps)]
def flatten(list_of_lists):
  """Takes a list of lists and returns a list of the elements.

  Only one level of nesting is removed. Items that are not lists are kept
  as single elements.

  Args:
    list_of_lists: List whose items are either lists or plain values.

  Returns:
    flat_list: Flattened list.
    flat_list_idxs: One entry per input item: the list of positions in
      flat_list that the item's elements occupy. Always a list (also on
      Python 3, where a bare range() is not a list), so both kinds of item
      yield the same type.
  """
  flat_list = []
  flat_list_idxs = []
  start_idx = 0
  for item in list_of_lists:
    if isinstance(item, list):
      flat_list += item
      n_items = len(item)
      idxs = list(range(start_idx, start_idx + n_items))
      start_idx += n_items
    else:  # a value
      flat_list.append(item)
      idxs = [start_idx]
      start_idx += 1
    flat_list_idxs.append(idxs)
  return flat_list, flat_list_idxs
......@@ -73,7 +73,7 @@ LSTM-8192-2048 (50\% Dropout) | 32.2 | 3.3
<b>How To Run</b>
Pre-requesite:
Prerequisites:
* Install TensorFlow.
* Install Bazel.
......@@ -97,7 +97,7 @@ Pre-requesite:
[link](http://download.tensorflow.org/models/LM_LSTM_CNN/vocab-2016-09-10.txt)
* test dataset: link
[link](http://download.tensorflow.org/models/LM_LSTM_CNN/test/news.en.heldout-00000-of-00050)
* It is recommended to run on modern desktop instead of laptop.
* It is recommended to run on a modern desktop instead of a laptop.
```shell
# 1. Clone the code to your workspace.
......@@ -105,7 +105,7 @@ Pre-requesite:
# 3. Create an empty WORKSPACE file in your workspace.
# 4. Create an empty output directory in your workspace.
# Example directory structure below:
ls -R
$ ls -R
.:
data lm_1b output WORKSPACE
......@@ -121,13 +121,13 @@ BUILD data_utils.py lm_1b_eval.py README.md
./output:
# Build the codes.
bazel build -c opt lm_1b/...
$ bazel build -c opt lm_1b/...
# Run sample mode:
bazel-bin/lm_1b/lm_1b_eval --mode sample \
--prefix "I love that I" \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--ckpt 'data/ckpt-*'
$ bazel-bin/lm_1b/lm_1b_eval --mode sample \
--prefix "I love that I" \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--ckpt 'data/ckpt-*'
...(omitted some TensorFlow output)
I love
I love that
......@@ -138,11 +138,11 @@ I love that I find that amazing
...(omitted)
# Run eval mode:
bazel-bin/lm_1b/lm_1b_eval --mode eval \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--input_data data/news.en.heldout-00000-of-00050 \
--ckpt 'data/ckpt-*'
$ bazel-bin/lm_1b/lm_1b_eval --mode eval \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--input_data data/news.en.heldout-00000-of-00050 \
--ckpt 'data/ckpt-*'
...(omitted some TensorFlow output)
Loaded step 14108582.
# perplexity is high initially because words without context are harder to
......@@ -166,28 +166,28 @@ Eval Step: 4531, Average Perplexity: 29.285674.
...(omitted. At convergence, it should be around 30.)
# Run dump_emb mode:
bazel-bin/lm_1b/lm_1b_eval --mode dump_emb \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--ckpt 'data/ckpt-*' \
--save_dir output
$ bazel-bin/lm_1b/lm_1b_eval --mode dump_emb \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--ckpt 'data/ckpt-*' \
--save_dir output
...(omitted some TensorFlow output)
Finished softmax weights
Finished word embedding 0/793471
Finished word embedding 1/793471
Finished word embedding 2/793471
...(omitted)
ls output/
$ ls output/
embeddings_softmax.npy ...
# Run dump_lstm_emb mode:
bazel-bin/lm_1b/lm_1b_eval --mode dump_lstm_emb \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--ckpt 'data/ckpt-*' \
--sentence "I love who I am ." \
--save_dir output
ls output/
$ bazel-bin/lm_1b/lm_1b_eval --mode dump_lstm_emb \
--pbtxt data/graph-2016-09-10.pbtxt \
--vocab_file data/vocab-2016-09-10.txt \
--ckpt 'data/ckpt-*' \
--sentence "I love who I am ." \
--save_dir output
$ ls output/
lstm_emb_step_0.npy lstm_emb_step_2.npy lstm_emb_step_4.npy
lstm_emb_step_6.npy lstm_emb_step_1.npy lstm_emb_step_3.npy
lstm_emb_step_5.npy
......
......@@ -19,6 +19,7 @@ import os
import sys
import numpy as np
from six.moves import xrange
import tensorflow as tf
from google.protobuf import text_format
......@@ -83,7 +84,7 @@ def _LoadModel(gd_file, ckpt_file):
with tf.Graph().as_default():
sys.stderr.write('Recovering graph.\n')
with tf.gfile.FastGFile(gd_file, 'r') as f:
s = f.read()
s = f.read().decode()
gd = tf.GraphDef()
text_format.Merge(s, gd)
......@@ -230,7 +231,7 @@ def _DumpEmb(vocab):
sys.stderr.write('Finished softmax weights\n')
all_embs = np.zeros([vocab.size, 1024])
for i in range(vocab.size):
for i in xrange(vocab.size):
input_dict = {t['inputs_in']: inputs,
t['targets_in']: targets,
t['target_weights_in']: weights}
......
# NeuralGPU
Code for the Neural GPU model described in [[http://arxiv.org/abs/1511.08228]].
The extended version was described in [[https://arxiv.org/abs/1610.08613]].
Code for the Neural GPU model described in http://arxiv.org/abs/1511.08228.
The extended version was described in https://arxiv.org/abs/1610.08613.
Requirements:
* TensorFlow (see tensorflow.org for how to install)
......
......@@ -478,8 +478,10 @@ class NeuralGPU(object):
# This is just for running a baseline RNN seq2seq model.
if do_rnn:
self.after_enc_step.append(step) # Not meaningful here, but needed.
lstm_cell = tf.contrib.rnn.BasicLSTMCell(height * nmaps)
cell = tf.contrib.rnn.MultiRNNCell([lstm_cell] * nconvs)
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(height * nmaps)
cell = tf.contrib.rnn.MultiRNNCell(
[lstm_cell() for _ in range(nconvs)])
with tf.variable_scope("encoder"):
encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
cell, tf.reshape(step, [batch_size, length, height * nmaps]),
......
......@@ -12,17 +12,11 @@ Authors: Xin Pan (Github: panyx0718), Anelia Angelova
<b>Results:</b>
<left>
![Sample1](g3doc/cross_conv.png)
</left>
<left>
![Sample2](g3doc/cross_conv2.png)
</left>
<left>
![Loss](g3doc/cross_conv3.png)
</left>
<b>Prerequisite:</b>
......@@ -40,7 +34,7 @@ to tf.SequenceExample.
<b>How to run:</b>
```shell
ls -R
$ ls -R
.:
data next_frame_prediction WORKSPACE
......@@ -58,18 +52,18 @@ cross_conv2.png cross_conv3.png cross_conv.png
# Build everything.
bazel build -c opt next_frame_prediction/...
$ bazel build -c opt next_frame_prediction/...
# The following example runs the generated 2d objects.
# For Sprites dataset, image_size should be 60, norm_scale should be 255.0.
# Batch size is normally 16~64, depending on your memory size.
#
# Run training.
bazel-bin/next_frame_prediction/cross_conv/train \
--batch_size=1 \
--data_filepattern=data/tfrecords \
--image_size=64 \
--log_root=/tmp/predict
$ bazel-bin/next_frame_prediction/cross_conv/train \
--batch_size=1 \
--data_filepattern=data/tfrecords \
--image_size=64 \
--log_root=/tmp/predict
step: 1, loss: 24.428671
step: 2, loss: 19.211605
......@@ -81,11 +75,11 @@ step: 7, loss: 1.747665
step: 8, loss: 1.572436
step: 9, loss: 1.586816
step: 10, loss: 1.434191
#
# Run eval.
bazel-bin/next_frame_prediction/cross_conv/eval \
--batch_size=1 \
--data_filepattern=data/tfrecords_test \
--image_size=64 \
--log_root=/tmp/predict
$ bazel-bin/next_frame_prediction/cross_conv/eval \
--batch_size=1 \
--data_filepattern=data/tfrecords_test \
--image_size=64 \
--log_root=/tmp/predict
```
# Tensorflow Object Detection API: main runnables.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_binary(
name = "train",
srcs = [
"train.py",
],
deps = [
":trainer",
"//tensorflow",
"//tensorflow_models/object_detection/builders:input_reader_builder",
"//tensorflow_models/object_detection/builders:model_builder",
"//tensorflow_models/object_detection/protos:input_reader_py_pb2",
"//tensorflow_models/object_detection/protos:model_py_pb2",
"//tensorflow_models/object_detection/protos:pipeline_py_pb2",
"//tensorflow_models/object_detection/protos:train_py_pb2",
],
)
py_library(
name = "trainer",
srcs = ["trainer.py"],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/builders:optimizer_builder",
"//tensorflow_models/object_detection/builders:preprocessor_builder",
"//tensorflow_models/object_detection/core:batcher",
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:variables_helper",
"//tensorflow_models/slim:model_deploy",
],
)
py_test(
name = "trainer_test",
srcs = ["trainer_test.py"],
deps = [
":trainer",
"//tensorflow",
"//tensorflow_models/object_detection/core:losses",
"//tensorflow_models/object_detection/core:model",
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/protos:train_py_pb2",
],
)
py_library(
name = "eval_util",
srcs = [
"eval_util.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:label_map_util",
"//tensorflow_models/object_detection/utils:object_detection_evaluation",
"//tensorflow_models/object_detection/utils:visualization_utils",
],
)
py_library(
name = "evaluator",
srcs = ["evaluator.py"],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection:eval_util",
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_list_ops",
"//tensorflow_models/object_detection/core:prefetcher",
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/protos:eval_py_pb2",
],
)
py_binary(
name = "eval",
srcs = [
"eval.py",
],
deps = [
":evaluator",
"//tensorflow",
"//tensorflow_models/object_detection/builders:input_reader_builder",
"//tensorflow_models/object_detection/builders:model_builder",
"//tensorflow_models/object_detection/protos:eval_py_pb2",
"//tensorflow_models/object_detection/protos:input_reader_py_pb2",
"//tensorflow_models/object_detection/protos:model_py_pb2",
"//tensorflow_models/object_detection/protos:pipeline_py_pb2",
"//tensorflow_models/object_detection/utils:label_map_util",
],
)
py_library(
name = "exporter",
srcs = [
"exporter.py",
],
deps = [
"//tensorflow",
"//tensorflow/python/tools:freeze_graph_lib",
"//tensorflow_models/object_detection/builders:model_builder",
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/data_decoders:tf_example_decoder",
],
)
py_test(
name = "exporter_test",
srcs = [
"exporter_test.py",
],
deps = [
":exporter",
"//tensorflow",
"//tensorflow_models/object_detection/builders:model_builder",
"//tensorflow_models/object_detection/core:model",
"//tensorflow_models/object_detection/protos:pipeline_py_pb2",
],
)
py_binary(
name = "export_inference_graph",
srcs = [
"export_inference_graph.py",
],
deps = [
":exporter",
"//tensorflow",
"//tensorflow_models/object_detection/protos:pipeline_py_pb2",
],
)
py_binary(
name = "create_pascal_tf_record",
srcs = [
"create_pascal_tf_record.py",
],
deps = [
"//third_party/py/PIL:pil",
"//third_party/py/lxml",
"//tensorflow",
"//tensorflow_models/object_detection/utils:dataset_util",
"//tensorflow_models/object_detection/utils:label_map_util",
],
)
py_test(
name = "create_pascal_tf_record_test",
srcs = [
"create_pascal_tf_record_test.py",
],
deps = [
":create_pascal_tf_record",
"//tensorflow",
],
)
py_binary(
name = "create_pet_tf_record",
srcs = [
"create_pet_tf_record.py",
],
deps = [
"//third_party/py/PIL:pil",
"//third_party/py/lxml",
"//tensorflow",
"//tensorflow_models/object_detection/utils:dataset_util",
"//tensorflow_models/object_detection/utils:label_map_util",
],
)
# Contributing to the Tensorflow Object Detection API
Patches to the TensorFlow Object Detection API are welcome!
We require contributors to fill out either the individual or corporate
Contributor License Agreement (CLA).
* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
* If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
Please follow the
[Tensorflow contributing guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md)
when submitting pull requests.
# Tensorflow Object Detection API
Creating accurate machine learning models capable of localizing and identifying
multiple objects in a single image remains a core challenge in computer vision.
The TensorFlow Object Detection API is an open source framework built on top of
TensorFlow that makes it easy to construct, train and deploy object detection
models. At Google we’ve certainly found this codebase to be useful for our
computer vision needs, and we hope that you will as well.
<p align="center">
<img src="g3doc/img/kites_detections_output.jpg" width=676 height=450>
</p>
Contributions to the codebase are welcome and we would love to hear back from
you if you find this API useful. Finally, if you use the TensorFlow Object
Detection API for a research publication, please consider citing:
```
"Speed/accuracy trade-offs for modern convolutional object detectors."
Huang J, Rathod V, Sun C, Zhu M, Korattikara A, Fathi A, Fischer I, Wojna Z,
Song Y, Guadarrama S, Murphy K, CVPR 2017
```
\[[link](https://arxiv.org/abs/1611.10012)\]\[[bibtex](
https://scholar.googleusercontent.com/scholar.bib?q=info:l291WsrB-hQJ:scholar.google.com/&output=citation&scisig=AAGBfm0AAAAAWUIIlnPZ_L9jxvPwcC49kDlELtaeIyU-&scisf=4&ct=citation&cd=-1&hl=en&scfhb=1)\]
## Maintainers
* Jonathan Huang, github: [jch1](https://github.com/jch1)
* Vivek Rathod, github: [tombstone](https://github.com/tombstone)
* Derek Chow, github: [derekjchow](https://github.com/derekjchow)
* Chen Sun, github: [jesu9](https://github.com/jesu9)
* Menglong Zhu, github: [dreamdragon](https://github.com/dreamdragon)
## Table of contents
Quick Start:
* <a href='object_detection_tutorial.ipynb'>
Quick Start: Jupyter notebook for off-the-shelf inference</a><br>
* <a href="g3doc/running_pets.md">Quick Start: Training a pet detector</a><br>
Setup:
* <a href='g3doc/installation.md'>Installation</a><br>
* <a href='g3doc/configuring_jobs.md'>
Configuring an object detection pipeline</a><br>
* <a href='g3doc/preparing_inputs.md'>Preparing inputs</a><br>
Running:
* <a href='g3doc/running_locally.md'>Running locally</a><br>
* <a href='g3doc/running_on_cloud.md'>Running on the cloud</a><br>
Extras:
* <a href='g3doc/detection_model_zoo.md'>Tensorflow detection model zoo</a><br>
* <a href='g3doc/exporting_models.md'>
Exporting a trained model for inference</a><br>
* <a href='g3doc/defining_your_own_model.md'>
Defining your own model architecture</a><br>
## Release information
### June 15, 2017
In addition to our base Tensorflow detection model definitions, this
release includes:
* A selection of trainable detection models, including:
* Single Shot Multibox Detector (SSD) with MobileNet,
* SSD with Inception V2,
* Region-Based Fully Convolutional Networks (R-FCN) with Resnet 101,
* Faster RCNN with Resnet 101,
* Faster RCNN with Inception Resnet v2
* Frozen weights (trained on the COCO dataset) for each of the above models to
be used for out-of-the-box inference purposes.
* A [Jupyter notebook](object_detection_tutorial.ipynb) for performing
out-of-the-box inference with one of our released models
* Convenient [local training](g3doc/running_locally.md) scripts as well as
distributed training and evaluation pipelines via
[Google Cloud](g3doc/running_on_cloud.md).
<b>Thanks to contributors</b>: Jonathan Huang, Vivek Rathod, Derek Chow,
Chen Sun, Menglong Zhu, Matthew Tang, Anoop Korattikara, Alireza Fathi, Ian Fischer, Zbigniew Wojna, Yang Song, Sergio Guadarrama, Jasper Uijlings,
Viacheslav Kovalevskyi, Kevin Murphy
# Tensorflow Object Detection API: Anchor Generator implementations.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "grid_anchor_generator",
srcs = [
"grid_anchor_generator.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/core:anchor_generator",
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/utils:ops",
],
)
py_test(
name = "grid_anchor_generator_test",
srcs = [
"grid_anchor_generator_test.py",
],
deps = [
":grid_anchor_generator",
"//tensorflow",
],
)
py_library(
name = "multiple_grid_anchor_generator",
srcs = [
"multiple_grid_anchor_generator.py",
],
deps = [
":grid_anchor_generator",
"//tensorflow",
"//tensorflow_models/object_detection/core:anchor_generator",
"//tensorflow_models/object_detection/core:box_list_ops",
],
)
py_test(
name = "multiple_grid_anchor_generator_test",
srcs = [
"multiple_grid_anchor_generator_test.py",
],
deps = [
":multiple_grid_anchor_generator",
"//third_party/py/numpy",
],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment