"docs/vscode:/vscode.git/clone" did not exist on "29903c5c41a70ac7cf20c91a1af926464c405c0d"
Commit b1354256 authored by Dan O'Shea's avatar Dan O'Shea
Browse files

Formatting fixes

parent 32afad9c
...@@ -366,9 +366,11 @@ class LFADS(object): ...@@ -366,9 +366,11 @@ class LFADS(object):
if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name] dataset = datasets[name]
if hps.do_train_readin: if hps.do_train_readin:
print("Initializing trainable readin matrix with alignment matrix provided for dataset:", name) print("Initializing trainable readin matrix with alignment matrix \
provided for dataset:", name)
else: else:
print("Setting non-trainable readin matrix to alignment matrix provided for dataset:", name) print("Setting non-trainable readin matrix to alignment matrix \
provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim): if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d raise ValueError("""Alignment matrix must have dimensions %d x %d
...@@ -378,9 +380,11 @@ class LFADS(object): ...@@ -378,9 +380,11 @@ class LFADS(object):
if datasets and 'alignment_bias_c' in datasets[name].keys(): if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name] dataset = datasets[name]
if hps.do_train_readin: if hps.do_train_readin:
print("Initializing trainable readin bias with alignment bias provided for dataset:", name) print("Initializing trainable readin bias with alignment bias \
provided for dataset:", name)
else: else:
print("Setting non-trainable readin bias to alignment bias provided for dataset:", name) print("Setting non-trainable readin bias to alignment bias \
provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32) align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0) align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim: if align_bias_1xc.shape[1] != data_dim:
......
...@@ -320,10 +320,13 @@ flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE, ...@@ -320,10 +320,13 @@ flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
# for multi-session "stitching" models, the per-session readin matrices map from # for multi-session "stitching" models, the per-session readin matrices map from
# neurons to input factors which are fed into the shared encoder. These are initialized # neurons to input factors which are fed into the shared encoder. These are
# by alignment_matrix_cxf and alignment_bias_c in the input .h5 files. They can be fixed or # initialized by alignment_matrix_cxf and alignment_bias_c in the input .h5
# made trainable. # files. They can be fixed or made trainable.
flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN, "Whether to train the readin matrices and bias vectors. False leaves them fixed at their initial values specified by the alignment matrices / vectors.") flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN, "Whether to train the \
readin matrices and bias vectors. False leaves them fixed \
at their initial values specified by the alignment \
matrices and vectors.")
# OVERFITTING # OVERFITTING
...@@ -440,7 +443,8 @@ def build_model(hps, kind="train", datasets=None): ...@@ -440,7 +443,8 @@ def build_model(hps, kind="train", datasets=None):
print("Possible error!!! You are running ", kind, " on a newly \ print("Possible error!!! You are running ", kind, " on a newly \
initialized model!") initialized model!")
# cant print ckpt.model_check_point path if no ckpt # cant print ckpt.model_check_point path if no ckpt
print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir, " exists?") print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir,
" exists?")
tf.global_variables_initializer().run() tf.global_variables_initializer().run()
...@@ -787,4 +791,3 @@ def main(_): ...@@ -787,4 +791,3 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() tf.app.run()
...@@ -132,7 +132,8 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, ...@@ -132,7 +132,8 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
if collections: if collections:
w_collections += collections w_collections += collections
if mat_init_value is not None: if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections, trainable=trainable) w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
trainable=trainable)
else: else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init, w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections, trainable=trainable) collections=w_collections, trainable=trainable)
...@@ -142,7 +143,8 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, ...@@ -142,7 +143,8 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
if collections: if collections:
w_collections += collections w_collections += collections
if mat_init_value is not None: if mat_init_value is not None:
w = tf.Variable(mat_init_value, name=wname, collections=w_collections, trainable=trainable) w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
trainable=trainable)
else: else:
w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init, w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
collections=w_collections, trainable=trainable) collections=w_collections, trainable=trainable)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment