Unverified Commit 9ff6d51b authored by David Sussillo's avatar David Sussillo Committed by GitHub
Browse files

Merge pull request #3314 from djoshea/do_train_readin

lfads: Adding do_train_readin option allowing fixed readin matrices
parents 079d67d9 4070fa63
...@@ -365,7 +365,12 @@ class LFADS(object): ...@@ -365,7 +365,12 @@ class LFADS(object):
if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name] dataset = datasets[name]
print("Using alignment matrix provided for dataset:", name) if hps.do_train_readin:
print("Initializing trainable readin matrix with alignment matrix" \
" provided for dataset:", name)
else:
print("Setting non-trainable readin matrix to alignment matrix" \
" provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim): if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d raise ValueError("""Alignment matrix must have dimensions %d x %d
...@@ -374,7 +379,12 @@ class LFADS(object): ...@@ -374,7 +379,12 @@ class LFADS(object):
in_mat_cxf.shape[1])) in_mat_cxf.shape[1]))
if datasets and 'alignment_bias_c' in datasets[name].keys(): if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name] dataset = datasets[name]
print("Using alignment bias provided for dataset:", name) if hps.do_train_readin:
print("Initializing trainable readin bias with alignment bias " \
"provided for dataset:", name)
else:
print("Setting non-trainable readin bias to alignment bias " \
"provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32) align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0) align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim: if align_bias_1xc.shape[1] != data_dim:
...@@ -387,12 +397,22 @@ class LFADS(object): ...@@ -387,12 +397,22 @@ class LFADS(object):
# So b = -alignment_bias * W_in to accommodate PCA style offset. # So b = -alignment_bias * W_in to accommodate PCA style offset.
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf) in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True, if hps.do_train_readin:
# only add to IO transformations collection only if we want it to be
# learnable, because IO_transformations collection will be trained
# when do_train_io_only
collections_readin=['IO_transformations']
else:
collections_readin=None
in_fac_lin = init_linear(data_dim, used_in_factors_dim,
do_bias=True,
mat_init_value=in_mat_cxf, mat_init_value=in_mat_cxf,
bias_init_value=in_bias_1xf, bias_init_value=in_bias_1xf,
identity_if_possible=in_identity_if_poss, identity_if_possible=in_identity_if_poss,
normalized=False, name="x_2_infac_"+name, normalized=False, name="x_2_infac_"+name,
collections=['IO_transformations']) collections=collections_readin,
trainable=hps.do_train_readin)
in_fac_W, in_fac_b = in_fac_lin in_fac_W, in_fac_b = in_fac_lin
fns_in_fac_Ws[d] = makelambda(in_fac_W) fns_in_fac_Ws[d] = makelambda(in_fac_W)
fns_in_fac_bs[d] = makelambda(in_fac_b) fns_in_fac_bs[d] = makelambda(in_fac_b)
...@@ -417,7 +437,7 @@ class LFADS(object): ...@@ -417,7 +437,7 @@ class LFADS(object):
out_mat_fxc = None out_mat_fxc = None
out_bias_1xc = None out_bias_1xc = None
if in_mat_cxf is not None: if in_mat_cxf is not None:
out_mat_fxc = np.linalg.pinv(in_mat_cxf) out_mat_fxc = in_mat_cxf.T
if align_bias_1xc is not None: if align_bias_1xc is not None:
out_bias_1xc = align_bias_1xc out_bias_1xc = align_bias_1xc
......
...@@ -23,6 +23,8 @@ import os ...@@ -23,6 +23,8 @@ import os
import tensorflow as tf import tensorflow as tf
import re import re
import utils import utils
import sys
MAX_INT = sys.maxsize
# Lots of hyperparameters, but most are pretty insensitive. The # Lots of hyperparameters, but most are pretty insensitive. The
# explanation of these hyperparameters is found below, in the flags # explanation of these hyperparameters is found below, in the flags
...@@ -35,7 +37,7 @@ OUTPUT_FILENAME_STEM = "" ...@@ -35,7 +37,7 @@ OUTPUT_FILENAME_STEM = ""
DEVICE = "gpu:0" # "cpu:0", or other gpus, e.g. "gpu:1" DEVICE = "gpu:0" # "cpu:0", or other gpus, e.g. "gpu:1"
MAX_CKPT_TO_KEEP = 5 MAX_CKPT_TO_KEEP = 5
MAX_CKPT_TO_KEEP_LVE = 5 MAX_CKPT_TO_KEEP_LVE = 5
PS_NEXAMPLES_TO_PROCESS = 1e8 # if larger than number of examples, process all PS_NEXAMPLES_TO_PROCESS = MAX_INT # if larger than number of examples, process all
EXT_INPUT_DIM = 0 EXT_INPUT_DIM = 0
IC_DIM = 64 IC_DIM = 64
FACTORS_DIM = 50 FACTORS_DIM = 50
...@@ -53,6 +55,7 @@ INJECT_EXT_INPUT_TO_GEN = False ...@@ -53,6 +55,7 @@ INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False DO_TRAIN_IO_ONLY = False
DO_RESET_LEARNING_RATE = False DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors" FEEDBACK_FACTORS_OR_RATES = "factors"
DO_TRAIN_READIN = True
# Calibrated just above the average value for the rnn synthetic data. # Calibrated just above the average value for the rnn synthetic data.
MAX_GRAD_NORM = 200.0 MAX_GRAD_NORM = 200.0
...@@ -60,7 +63,7 @@ CELL_CLIP_VALUE = 5.0 ...@@ -60,7 +63,7 @@ CELL_CLIP_VALUE = 5.0
KEEP_PROB = 0.95 KEEP_PROB = 0.95
TEMPORAL_SPIKE_JITTER_WIDTH = 0 TEMPORAL_SPIKE_JITTER_WIDTH = 0
OUTPUT_DISTRIBUTION = 'poisson' # 'poisson' or 'gaussian' OUTPUT_DISTRIBUTION = 'poisson' # 'poisson' or 'gaussian'
NUM_STEPS_FOR_GEN_IC = np.inf # set to num_steps if greater than num_steps NUM_STEPS_FOR_GEN_IC = MAX_INT # set to num_steps if greater than num_steps
DATA_DIR = "/tmp/rnn_synth_data_v1.0/" DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5" DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
...@@ -316,6 +319,16 @@ flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE, ...@@ -316,6 +319,16 @@ flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
"Reset the learning rate to initial value.") "Reset the learning rate to initial value.")
# for multi-session "stitching" models, the per-session readin matrices map from
# neurons to input factors which are fed into the shared encoder. These are
# initialized by alignment_matrix_cxf and alignment_bias_c in the input .h5
# files. They can be fixed or made trainable.
flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN, "Whether to train the \
readin matrices and bias vectors. False leaves them fixed \
at their initial values specified by the alignment \
matrices and vectors.")
# OVERFITTING # OVERFITTING
# Dropout is done on the input data, on controller inputs (from # Dropout is done on the input data, on controller inputs (from
# encoder), on outputs from generator to factors. # encoder), on outputs from generator to factors.
...@@ -429,7 +442,9 @@ def build_model(hps, kind="train", datasets=None): ...@@ -429,7 +442,9 @@ def build_model(hps, kind="train", datasets=None):
"write_model_params"]: "write_model_params"]:
print("Possible error!!! You are running ", kind, " on a newly \ print("Possible error!!! You are running ", kind, " on a newly \
initialized model!") initialized model!")
print("Are you sure you sure ", ckpt.model_checkpoint_path, " exists?") # cant print ckpt.model_check_point path if no ckpt
print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir,
" exists?")
tf.global_variables_initializer().run() tf.global_variables_initializer().run()
...@@ -536,6 +551,7 @@ def build_hyperparameter_dict(flags): ...@@ -536,6 +551,7 @@ def build_hyperparameter_dict(flags):
d['cell_clip_value'] = flags.cell_clip_value d['cell_clip_value'] = flags.cell_clip_value
d['do_train_io_only'] = flags.do_train_io_only d['do_train_io_only'] = flags.do_train_io_only
d['do_reset_learning_rate'] = flags.do_reset_learning_rate d['do_reset_learning_rate'] = flags.do_reset_learning_rate
d['do_train_readin'] = flags.do_train_readin
# Overfitting # Overfitting
d['keep_prob'] = flags.keep_prob d['keep_prob'] = flags.keep_prob
...@@ -775,4 +791,3 @@ def main(_): ...@@ -775,4 +791,3 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() tf.app.run()
...@@ -84,14 +84,15 @@ def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False, ...@@ -84,14 +84,15 @@ def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
bias_init_value=None, alpha=1.0, identity_if_possible=False, bias_init_value=None, alpha=1.0, identity_if_possible=False,
normalized=False, name=None, collections=None): normalized=False, name=None, collections=None, trainable=True):
"""Linear (affine) transformation, y = x W + b, for a variety of """Linear (affine) transformation, y = x W + b, for a variety of
configurations. configurations.
Args: Args:
in_size: The integer size of the non-batc input dimension. [(x),y] in_size: The integer size of the non-batc input dimension. [(x),y]
out_size: The integer size of non-batch output dimension. [x,(y)] out_size: The integer size of non-batch output dimension. [x,(y)]
do_bias (optional): Add a learnable bias vector to the operation. do_bias (optional): Add a (learnable) bias vector to the operation,
if false, b will be None
mat_init_value (optional): numpy constant for matrix initialization, if None mat_init_value (optional): numpy constant for matrix initialization, if None
, do random, with additional parameters. , do random, with additional parameters.
alpha (optional): A multiplicative scaling for the weight initialization alpha (optional): A multiplicative scaling for the weight initialization
...@@ -131,21 +132,22 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, ...@@ -131,21 +132,22 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
if collections: if collections:
w_collections += collections w_collections += collections
if mat_init_value is not None: if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections) w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
trainable=trainable)
else: else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init, w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections) collections=w_collections, trainable=trainable)
w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij
else: else:
w_collections = [tf.GraphKeys.GLOBAL_VARIABLES] w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
if collections: if collections:
w_collections += collections w_collections += collections
if mat_init_value is not None: if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections) w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
trainable=trainable)
else: else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init, w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections) collections=w_collections, trainable=trainable)
b = None b = None
if do_bias: if do_bias:
b_collections = [tf.GraphKeys.GLOBAL_VARIABLES] b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
...@@ -155,11 +157,12 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, ...@@ -155,11 +157,12 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
if bias_init_value is None: if bias_init_value is None:
b = tf.get_variable(bname, [1, out_size], b = tf.get_variable(bname, [1, out_size],
initializer=tf.zeros_initializer(), initializer=tf.zeros_initializer(),
collections=b_collections) collections=b_collections,
trainable=trainable)
else: else:
b = tf.Variable(bias_init_value, name=bname, b = tf.Variable(bias_init_value, name=bname,
collections=b_collections) collections=b_collections,
trainable=trainable)
return (w, b) return (w, b)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment