"...resnet50_tensorflow.git" did not exist on "8e8f37e598f8676bb041c2887152ce5dc42e4f7b"
Commit 32afad9c authored by Dan O'Shea

lfads: Adding do_train_readin option defaulting to true, allowing fixed readin matrices

For stitched models, the readin matrices and bias vectors are initialized to the
"alignment" matrix and bias specified in each dataset's .h5 file. If do_train_readin is
True, these will be trainable, and if not, they will be fixed.
parent 31adae53
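
For readers preparing stitched-model inputs, here is a hedged sketch of how the alignment entries mentioned above might be written into a dataset's .h5 file. The key names ('alignment_matrix_cxf', 'alignment_bias_c') and their expected shapes come from the diff below; the use of PCA, the use of h5py, and all file names are illustrative assumptions, not part of this commit.

import h5py
import numpy as np

# Hypothetical per-session spike counts, shaped (time, channels).
spikes_txc = np.load("session1_spikes.npy")
data_dim, factors_dim = spikes_txc.shape[1], 50

# One plausible "alignment": top principal directions of the channel
# covariance, with the channel means as the bias (the PCA-style offset
# that lfads.py compensates for via b = -alignment_bias * W_in).
mean_c = spikes_txc.mean(axis=0)
_, _, vt = np.linalg.svd(spikes_txc - mean_c, full_matrices=False)
alignment_matrix_cxf = vt[:factors_dim].T.astype(np.float32)  # (channels, factors)
alignment_bias_c = mean_c.astype(np.float32)                  # (channels,)

with h5py.File("session1_lfads_input.h5", "a") as f:
    f.create_dataset("alignment_matrix_cxf", data=alignment_matrix_cxf)
    f.create_dataset("alignment_bias_c", data=alignment_bias_c)

At training time, a stitched model can then either leave do_train_readin at its default (trainable readins) or set it to false to keep these readins fixed at the alignment-derived values.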
@@ -365,7 +365,10 @@ class LFADS(object):
       if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
         dataset = datasets[name]
-        print("Using alignment matrix provided for dataset:", name)
+        if hps.do_train_readin:
+          print("Initializing trainable readin matrix with alignment matrix provided for dataset:", name)
+        else:
+          print("Setting non-trainable readin matrix to alignment matrix provided for dataset:", name)
         in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
         if in_mat_cxf.shape != (data_dim, factors_dim):
           raise ValueError("""Alignment matrix must have dimensions %d x %d
@@ -374,7 +377,10 @@ class LFADS(object):
                            in_mat_cxf.shape[1]))
       if datasets and 'alignment_bias_c' in datasets[name].keys():
         dataset = datasets[name]
-        print("Using alignment bias provided for dataset:", name)
+        if hps.do_train_readin:
+          print("Initializing trainable readin bias with alignment bias provided for dataset:", name)
+        else:
+          print("Setting non-trainable readin bias to alignment bias provided for dataset:", name)
         align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
         align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
         if align_bias_1xc.shape[1] != data_dim:
@@ -387,12 +393,20 @@ class LFADS(object):
           # So b = -alignment_bias * W_in to accommodate PCA style offset.
           in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
-      in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True,
+      if hps.do_train_readin:
+        # only add to the IO_transformations collection if we want the readin to be
+        # learnable, because the IO_transformations collection is trained when do_train_io_only is set
+        collections_readin = ['IO_transformations']
+      else:
+        collections_readin = None
+      in_fac_lin = init_linear(data_dim, used_in_factors_dim,
+                               do_bias=True,
                                mat_init_value=in_mat_cxf,
                                bias_init_value=in_bias_1xf,
                                identity_if_possible=in_identity_if_poss,
                                normalized=False, name="x_2_infac_"+name,
-                               collections=['IO_transformations'])
+                               collections=collections_readin,
+                               trainable=hps.do_train_readin)
       in_fac_W, in_fac_b = in_fac_lin
       fns_in_fac_Ws[d] = makelambda(in_fac_W)
       fns_in_fac_bs[d] = makelambda(in_fac_b)
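
The collection choice above only controls which optimizer sees the readin: per the comment in the hunk, variables in 'IO_transformations' are the ones trained when do_train_io_only is set. A minimal TF1 sketch of that mechanism, with the variable, loss, and learning rate invented here for illustration (the actual do_train_io_only code is outside this diff):

import tensorflow as tf

# Toy variable placed in the custom collection, as init_linear does when
# do_train_readin is True.
w = tf.get_variable("demo_readin_W", shape=[3, 2],
                    collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                                 "IO_transformations"])
loss = tf.reduce_sum(tf.square(w))

# Presumably what training "IO only" amounts to: restrict the optimizer's
# var_list to the variables gathered from that collection.
io_vars = tf.get_collection("IO_transformations")
train_io_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=io_vars)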
@@ -417,7 +431,7 @@ class LFADS(object):
       out_mat_fxc = None
       out_bias_1xc = None
       if in_mat_cxf is not None:
-        out_mat_fxc = np.linalg.pinv(in_mat_cxf)
+        out_mat_fxc = in_mat_cxf.T
       if align_bias_1xc is not None:
         out_bias_1xc = align_bias_1xc
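
A note on the pinv-to-transpose change above: if the alignment matrix has orthonormal columns (as PCA loadings do), its Moore-Penrose pseudoinverse equals its transpose, so the transpose is an equivalent and cheaper initializer for the readout. That reading is my assumption; for a non-orthonormal alignment matrix the two differ. A quick numpy check:

import numpy as np

c, f = 100, 50
# Random matrix with orthonormal columns: W.T @ W is the identity.
W, _ = np.linalg.qr(np.random.randn(c, f))
print(np.allclose(np.linalg.pinv(W), W.T))  # True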
...
@@ -23,6 +23,8 @@ import os
 import tensorflow as tf
 import re
 import utils
+import sys
+MAX_INT = sys.maxsize

 # Lots of hyperparameters, but most are pretty insensitive.  The
 # explanation of these hyperparameters is found below, in the flags
@@ -35,7 +37,7 @@ OUTPUT_FILENAME_STEM = ""
 DEVICE = "gpu:0"  # "cpu:0", or other gpus, e.g. "gpu:1"
 MAX_CKPT_TO_KEEP = 5
 MAX_CKPT_TO_KEEP_LVE = 5
-PS_NEXAMPLES_TO_PROCESS = 1e8  # if larger than number of examples, process all
+PS_NEXAMPLES_TO_PROCESS = MAX_INT  # if larger than number of examples, process all
 EXT_INPUT_DIM = 0
 IC_DIM = 64
 FACTORS_DIM = 50
@@ -53,6 +55,7 @@ INJECT_EXT_INPUT_TO_GEN = False
 DO_TRAIN_IO_ONLY = False
 DO_RESET_LEARNING_RATE = False
 FEEDBACK_FACTORS_OR_RATES = "factors"
+DO_TRAIN_READIN = True

 # Calibrated just above the average value for the rnn synthetic data.
 MAX_GRAD_NORM = 200.0
@@ -60,7 +63,7 @@ CELL_CLIP_VALUE = 5.0
 KEEP_PROB = 0.95
 TEMPORAL_SPIKE_JITTER_WIDTH = 0
 OUTPUT_DISTRIBUTION = 'poisson'  # 'poisson' or 'gaussian'
-NUM_STEPS_FOR_GEN_IC = np.inf  # set to num_steps if greater than num_steps
+NUM_STEPS_FOR_GEN_IC = MAX_INT  # set to num_steps if greater than num_steps
 DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
 DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
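
The two MAX_INT substitutions above swap float sentinels (1e8, np.inf) for an integer one. The commit does not state the motivation, but a plausible one is that these values stand in for example or step counts, and an integer sentinel survives int-only contexts where a float infinity cannot:

import sys
import numpy as np

MAX_INT = sys.maxsize
num_examples = 1000

print(min(num_examples, MAX_INT))  # 1000; the comparison stays in plain ints

try:
    int(np.inf)  # what a float sentinel would need to survive
except OverflowError as err:
    print("np.inf fails where an int is required:", err)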
@@ -316,6 +319,13 @@ flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
                      "Reset the learning rate to initial value.")

+# for multi-session "stitching" models, the per-session readin matrices map from
+# neurons to input factors which are fed into the shared encoder. These are initialized
+# by alignment_matrix_cxf and alignment_bias_c in the input .h5 files. They can be fixed or
+# made trainable.
+flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN,
+                     "Whether to train the readin matrices and bias vectors. False leaves "
+                     "them fixed at their initial values specified by the alignment matrices / vectors.")
 # OVERFITTING
 # Dropout is done on the input data, on controller inputs (from
 # encoder), on outputs from generator to factors.
@@ -429,7 +439,8 @@ def build_model(hps, kind="train", datasets=None):
                 "write_model_params"]:
     print("Possible error!!! You are running ", kind, " on a newly \
           initialized model!")
-    print("Are you sure you sure ", ckpt.model_checkpoint_path, " exists?")
+    # can't print ckpt.model_checkpoint_path if there is no ckpt
+    print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir, " exists?")
     tf.global_variables_initializer().run()
@@ -536,6 +547,7 @@ def build_hyperparameter_dict(flags):
   d['cell_clip_value'] = flags.cell_clip_value
   d['do_train_io_only'] = flags.do_train_io_only
   d['do_reset_learning_rate'] = flags.do_reset_learning_rate
+  d['do_train_readin'] = flags.do_train_readin
   # Overfitting
   d['keep_prob'] = flags.keep_prob
@@ -659,7 +671,7 @@ def write_model_parameters(hps, output_fname=None, datasets=None):
   fname = os.path.join(hps.lfads_save_dir, output_fname)
   print("Writing model parameters to: ", fname)
   # save the optimizer params as well
   model = build_model(hps, kind="write_model_params", datasets=datasets)
   model_params = model.eval_model_parameters(use_nested=False,
                                              include_strs="LFADS")
   utils.write_data(fname, model_params, compression=None)
...
@@ -84,14 +84,15 @@ def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
 def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
                 bias_init_value=None, alpha=1.0, identity_if_possible=False,
-                normalized=False, name=None, collections=None):
+                normalized=False, name=None, collections=None, trainable=True):
   """Linear (affine) transformation, y = x W + b, for a variety of
   configurations.

   Args:
     in_size: The integer size of the non-batch input dimension. [(x),y]
     out_size: The integer size of non-batch output dimension. [x,(y)]
-    do_bias (optional): Add a learnable bias vector to the operation.
+    do_bias (optional): Add a (learnable) bias vector to the operation;
+      if False, b will be an appropriately sized, non-trainable vector of zeros.
     mat_init_value (optional): numpy constant for matrix initialization, if None
       , do random, with additional parameters.
     alpha (optional): A multiplicative scaling for the weight initialization
@@ -131,35 +132,37 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
     if collections:
       w_collections += collections
     if mat_init_value is not None:
-      w = tf.Variable(mat_init_value, name=wname, collections=w_collections)
+      w = tf.Variable(mat_init_value, name=wname, collections=w_collections, trainable=trainable)
     else:
       w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
-                          collections=w_collections)
+                          collections=w_collections, trainable=trainable)
     w = tf.nn.l2_normalize(w, dim=0)  # x W, so xW_j = \sum_i x_bi W_ij
   else:
     w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
     if collections:
       w_collections += collections
     if mat_init_value is not None:
-      w = tf.Variable(mat_init_value, name=wname, collections=w_collections)
+      w = tf.Variable(mat_init_value, name=wname, collections=w_collections, trainable=trainable)
     else:
       w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
-                          collections=w_collections)
+                          collections=w_collections, trainable=trainable)

-  b = None
+  b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
+  if collections:
+    b_collections += collections
+  bname = (name + "/b") if name else "/b"
   if do_bias:
-    b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
-    if collections:
-      b_collections += collections
-    bname = (name + "/b") if name else "/b"
     if bias_init_value is None:
       b = tf.get_variable(bname, [1, out_size],
                           initializer=tf.zeros_initializer(),
-                          collections=b_collections)
+                          collections=b_collections, trainable=trainable)
     else:
       b = tf.Variable(bias_init_value, name=bname,
-                      collections=b_collections)
+                      collections=b_collections, trainable=trainable)
+  else:
+    # construct a non-learnable vector of zeros as the bias
+    b = tf.get_variable(bname, [1, out_size],
+                        initializer=tf.zeros_initializer(), trainable=False)
   return (w, b)
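
A minimal TF1 usage sketch of the new trainable argument, assuming this repository's utils.init_linear is importable from the working directory; it mirrors the do_train_readin=False path in lfads.py and checks that the fixed readin stays out of the trainable collection. All shapes and names below are illustrative.

import numpy as np
import tensorflow as tf
from utils import init_linear

data_dim, factors_dim = 100, 50
align_cxf = np.random.randn(data_dim, factors_dim).astype(np.float32)
bias_1xf = np.zeros((1, factors_dim), dtype=np.float32)

# Fixed (non-trainable) readin, initialized from alignment-style values.
W, b = init_linear(data_dim, factors_dim, do_bias=True,
                   mat_init_value=align_cxf, bias_init_value=bias_1xf,
                   normalized=False, name="x_2_infac_demo",
                   collections=None, trainable=False)

# Neither W nor b should appear among the trainable variables.
assert not any("x_2_infac_demo" in v.name for v in tf.trainable_variables())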
...