"vscode:/vscode.git/clone" did not exist on "834b09c4e8049d927a29b46435f69db5ab9cd8a5"
Commit 9b3a7754 authored by Chethan Pandarinath's avatar Chethan Pandarinath
Browse files

change to lfads to allow training of encoder weights only

parent d6099f10
......@@ -895,13 +895,25 @@ class LFADS(object):
return
# OPTIMIZATION
if not self.hps.do_train_io_only:
# train the io matrices only
if self.hps.do_train_io_only:
self.train_vars = tvars = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_collection('IO_transformations',
scope=tf.get_variable_scope().name)
# train the encoder only
elif self.hps.do_train_encoder_only:
tvars1 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/ic_enc_*')
tvars2 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/z/ic_enc_*')
self.train_vars = tvars = tvars1 + tvars2
# train all variables
else:
self.train_vars = tvars = \
tf.get_collection('IO_transformations',
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=tf.get_variable_scope().name)
print("done.")
print("Model Variables (to be optimized): ")
......
......@@ -51,6 +51,7 @@ LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE = 6
INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False
DO_TRAIN_ENCODER_ONLY = False
DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors"
......@@ -312,6 +313,16 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
"Train only the input (readin) and output (readout) \
affine functions.")
# This flag is used for an experiment where one wants to know if the dynamics
# learned by the generator generalize across conditions. In that case, you might
# train up a model on one set of data, and then only further train the encoder on
# another set of data (the conditions to be tested) so that the model is forced
# to use the same dynamics to describe that data.
# If you don't care about that particular experiment, this flag should always be
# false.
flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
"Train only the encoder weights.")
flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
"Reset the learning rate to initial value.")
......@@ -535,6 +546,7 @@ def build_hyperparameter_dict(flags):
d['max_grad_norm'] = flags.max_grad_norm
d['cell_clip_value'] = flags.cell_clip_value
d['do_train_io_only'] = flags.do_train_io_only
d['do_train_encoder_only'] = flags.do_train_encoder_only
d['do_reset_learning_rate'] = flags.do_reset_learning_rate
# Overfitting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment