Unverified commit 9ff6d51b authored by David Sussillo, committed by GitHub

Merge pull request #3314 from djoshea/do_train_readin

lfads: Adding do_train_readin option allowing fixed readin matrices
parents 079d67d9 4070fa63
@@ -365,7 +365,12 @@ class LFADS(object):
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
print("Using alignment matrix provided for dataset:", name)
if hps.do_train_readin:
print("Initializing trainable readin matrix with alignment matrix" \
" provided for dataset:", name)
else:
print("Setting non-trainable readin matrix to alignment matrix" \
" provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d
@@ -374,7 +379,12 @@ class LFADS(object):
in_mat_cxf.shape[1]))
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
print("Using alignment bias provided for dataset:", name)
if hps.do_train_readin:
print("Initializing trainable readin bias with alignment bias " \
"provided for dataset:", name)
else:
print("Setting non-trainable readin bias to alignment bias " \
"provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim:
@@ -387,12 +397,22 @@ class LFADS(object):
# So b = -alignment_bias * W_in to accommodate PCA style offset.
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
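# (Derivation: the intended readin is y = (x - alignment_bias) W_in
#  = x W_in - alignment_bias W_in, so writing it in affine form
#  y = x W_in + b requires b = -alignment_bias W_in, which is the
#  in_bias_1xf computed above.)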
if hps.do_train_readin:
  # Add to the IO_transformations collection only if we want it to be
  # learnable, because the IO_transformations collection will be trained
  # when do_train_io_only is set.
  collections_readin = ['IO_transformations']
else:
  collections_readin = None
in_fac_lin = init_linear(data_dim, used_in_factors_dim,
do_bias=True,
mat_init_value=in_mat_cxf,
bias_init_value=in_bias_1xf,
identity_if_possible=in_identity_if_poss,
normalized=False, name="x_2_infac_"+name,
collections=collections_readin,
trainable=hps.do_train_readin)
in_fac_W, in_fac_b = in_fac_lin
fns_in_fac_Ws[d] = makelambda(in_fac_W)
fns_in_fac_bs[d] = makelambda(in_fac_b)
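For intuition about what the two settings do, here is a minimal TF1-style sketch (not part of this diff): a variable created with trainable=False never enters tf.GraphKeys.TRAINABLE_VARIABLES, so optimizers that minimize over tf.trainable_variables() cannot update it; leaving it out of the IO_transformations collection additionally keeps it fixed when do_train_io_only is set.

  import numpy as np
  import tensorflow as tf

  # Minimal sketch: a frozen "readin" variable, initialized from an
  # alignment matrix, is excluded from TRAINABLE_VARIABLES and so is
  # never touched by optimizers built over tf.trainable_variables().
  alignment = np.eye(3, dtype=np.float32)
  readin = tf.Variable(alignment, name="frozen_readin", trainable=False)
  other = tf.Variable(tf.zeros([3]), name="learnable", trainable=True)

  train_vars = tf.trainable_variables()
  assert readin not in train_vars
  assert other in train_vars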
@@ -417,7 +437,7 @@ class LFADS(object):
out_mat_fxc = None
out_bias_1xc = None
if in_mat_cxf is not None:
out_mat_fxc = in_mat_cxf.T
if align_bias_1xc is not None:
out_bias_1xc = align_bias_1xc
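The switch from np.linalg.pinv to the transpose is exact when the alignment matrix has orthonormal columns, as PCA loadings do; for general matrices the two differ. A small numpy check under the orthonormal-columns assumption:

  import numpy as np

  # If in_mat_cxf has orthonormal columns (in_mat_cxf.T @ in_mat_cxf == I),
  # the Moore-Penrose pseudoinverse reduces to the transpose, so the new
  # readout init matches the old pinv-based one.
  rng = np.random.RandomState(0)
  in_mat_cxf, _ = np.linalg.qr(rng.randn(100, 20))  # 100 channels, 20 factors
  assert np.allclose(np.linalg.pinv(in_mat_cxf), in_mat_cxf.T, atol=1e-10)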
@@ -23,6 +23,8 @@ import os
import tensorflow as tf
import re
import utils
import sys
MAX_INT = sys.maxsize
# Lots of hyperparameters, but most are pretty insensitive. The
# explanation of these hyperparameters is found below, in the flags
@@ -35,7 +37,7 @@ OUTPUT_FILENAME_STEM = ""
DEVICE = "gpu:0" # "cpu:0", or other gpus, e.g. "gpu:1"
MAX_CKPT_TO_KEEP = 5
MAX_CKPT_TO_KEEP_LVE = 5
PS_NEXAMPLES_TO_PROCESS = MAX_INT # if larger than number of examples, process all
EXT_INPUT_DIM = 0
IC_DIM = 64
FACTORS_DIM = 50
@@ -53,6 +55,7 @@ INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False
DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors"
DO_TRAIN_READIN = True
# Calibrated just above the average value for the rnn synthetic data.
MAX_GRAD_NORM = 200.0
@@ -60,7 +63,7 @@ CELL_CLIP_VALUE = 5.0
KEEP_PROB = 0.95
TEMPORAL_SPIKE_JITTER_WIDTH = 0
OUTPUT_DISTRIBUTION = 'poisson' # 'poisson' or 'gaussian'
NUM_STEPS_FOR_GEN_IC = MAX_INT # set to num_steps if greater than num_steps
DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
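The two MAX_INT substitutions above (for PS_NEXAMPLES_TO_PROCESS and NUM_STEPS_FOR_GEN_IC) presumably exist because sys.maxsize is a genuine int, whereas 1e8 and np.inf are floats that fail in integer-only contexts. A minimal illustration:

  import sys

  MAX_INT = sys.maxsize
  assert isinstance(MAX_INT, int)

  # Safe as a sentinel count: both arguments to min() are ints, and
  # range() accepts the result.
  steps = range(min(MAX_INT, 500))

  # A float sentinel is not safe: range(1e8) raises
  # TypeError: 'float' object cannot be interpreted as an integer,
  # and np.inf behaves the same way.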
@@ -209,9 +212,9 @@ flags.DEFINE_float("co_prior_var_scale", CO_PRIOR_VAR_SCALE,
"Variance of control input prior distribution.")
flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION,
flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION,
"Initial autocorrelation of AR(1) priors.")
flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR,
flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR,
"Initial noise variance for AR(1) priors.")
flags.DEFINE_boolean("do_train_prior_ar_atau", DO_TRAIN_PRIOR_AR_ATAU,
"Is the value for atau an init, or the constant value?")
@@ -254,13 +257,13 @@ flags.DEFINE_boolean("do_causal_controller",
# Strictly speaking, feeding either the factors or the rates to the controller
# violates causality, since the g0 gets to see all the data. This may or may not
# be only a theoretical concern.
flags.DEFINE_boolean("do_feed_factors_to_controller",
DO_FEED_FACTORS_TO_CONTROLLER,
flags.DEFINE_boolean("do_feed_factors_to_controller",
DO_FEED_FACTORS_TO_CONTROLLER,
"Should factors[t-1] be input to controller at time t?")
flags.DEFINE_string("feedback_factors_or_rates", FEEDBACK_FACTORS_OR_RATES,
"Feedback the factors or the rates to the controller? \
Acceptable values: 'factors' or 'rates'.")
flags.DEFINE_integer("controller_input_lag", CONTROLLER_INPUT_LAG,
flags.DEFINE_integer("controller_input_lag", CONTROLLER_INPUT_LAG,
"Time lag on the encoding to controller t-lag for \
forward, t+lag for reverse.")
@@ -316,6 +319,16 @@ flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
"Reset the learning rate to initial value.")
# For multi-session "stitching" models, the per-session readin matrices map
# from neurons to input factors, which are fed into the shared encoder. These
# are initialized by alignment_matrix_cxf and alignment_bias_c in the input
# .h5 files. They can be fixed or made trainable.
flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN, "Whether to train the \
readin matrices and bias vectors. False leaves them fixed \
at their initial values specified by the alignment \
matrices and vectors.")
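As a usage note: to stitch multiple sessions with the readins frozen at their alignment-based initial values, the flag would be switched off at launch, along these lines (run_lfads.py is the LFADS training entry point; the remaining flags are elided):

  python run_lfads.py --do_train_readin=false [other flags]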
# OVERFITTING
# Dropout is done on the input data, on controller inputs (from
# encoder), on outputs from generator to factors.
@@ -429,7 +442,9 @@ def build_model(hps, kind="train", datasets=None):
"write_model_params"]:
print("Possible error!!! You are running ", kind, " on a newly \
initialized model!")
print("Are you sure you sure ", ckpt.model_checkpoint_path, " exists?")
# cant print ckpt.model_check_point path if no ckpt
print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir,
" exists?")
tf.global_variables_initializer().run()
@@ -451,7 +466,7 @@ def jsonify_dict(d):
Creates a shallow-copied dictionary first, then accomplishes string
conversion.
Args:
d: hyperparameter dictionary
Returns: hyperparameter dictionary with bools as strings
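Concretely, only boolean values are converted; a short illustration (the lowercase 'true'/'false' output strings are an assumption about the conversion's exact format):

  hps = {'do_train_readin': True, 'keep_prob': 0.95}
  print(jsonify_dict(hps))
  # {'do_train_readin': 'true', 'keep_prob': 0.95}  (assumed casing)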
@@ -536,6 +551,7 @@ def build_hyperparameter_dict(flags):
d['cell_clip_value'] = flags.cell_clip_value
d['do_train_io_only'] = flags.do_train_io_only
d['do_reset_learning_rate'] = flags.do_reset_learning_rate
d['do_train_readin'] = flags.do_train_readin
# Overfitting
d['keep_prob'] = flags.keep_prob
@@ -659,7 +675,7 @@ def write_model_parameters(hps, output_fname=None, datasets=None):
fname = os.path.join(hps.lfads_save_dir, output_fname)
print("Writing model parameters to: ", fname)
# save the optimizer params as well
model = build_model(hps, kind="write_model_params", datasets=datasets)
model_params = model.eval_model_parameters(use_nested=False,
include_strs="LFADS")
utils.write_data(fname, model_params, compression=None)
@@ -775,4 +791,3 @@ def main(_):
if __name__ == "__main__":
tf.app.run()
@@ -84,14 +84,15 @@ def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
bias_init_value=None, alpha=1.0, identity_if_possible=False,
normalized=False, name=None, collections=None, trainable=True):
"""Linear (affine) transformation, y = x W + b, for a variety of
configurations.
Args:
in_size: The integer size of the non-batch input dimension. [(x),y]
out_size: The integer size of non-batch output dimension. [x,(y)]
do_bias (optional): Add a (learnable) bias vector to the operation;
  if False, b will be None.
mat_init_value (optional): numpy constant for matrix initialization; if
  None, initialize randomly, with additional parameters.
alpha (optional): A multiplicative scaling for the weight initialization
@@ -131,21 +132,22 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
if collections:
w_collections += collections
if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
trainable=trainable)
else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections, trainable=trainable)
w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij
else:
w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
if collections:
w_collections += collections
if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
trainable=trainable)
else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections, trainable=trainable)
b = None
if do_bias:
b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
@@ -155,11 +157,12 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
if bias_init_value is None:
b = tf.get_variable(bname, [1, out_size],
initializer=tf.zeros_initializer(),
collections=b_collections,
trainable=trainable)
else:
b = tf.Variable(bias_init_value, name=bname,
collections=b_collections,
trainable=trainable)
return (w, b)
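To tie the pieces together, a sketch of calling the updated init_linear the way the LFADS readin path does, with a fixed alignment-based initialization (the dimensions, array contents, and the name "x_2_infac_demo" are illustrative):

  import numpy as np
  from utils import init_linear

  # C channels -> F factors, frozen at the alignment matrix/bias.
  # 'alignment_cxf' and 'alignment_bias_c' stand in for the arrays
  # loaded from the dataset .h5 files.
  C, F = 100, 50
  alignment_cxf = np.random.randn(C, F).astype(np.float32)
  alignment_bias_c = np.zeros(C, dtype=np.float32)
  bias_1xf = -np.dot(np.expand_dims(alignment_bias_c, axis=0), alignment_cxf)

  w, b = init_linear(C, F, do_bias=True,
                     mat_init_value=alignment_cxf,
                     bias_init_value=bias_1xf,
                     normalized=False, name="x_2_infac_demo",
                     collections=None, trainable=False)
  # Both returned tf.Variables are excluded from TRAINABLE_VARIABLES.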