Unverified Commit 3022f945 authored by David Sussillo's avatar David Sussillo Committed by GitHub
Browse files

Merge pull request #2898 from cpandar/master

change to lfads to allow training of encoder weights only
parents 99400da5 9b3a7754
......@@ -915,13 +915,25 @@ class LFADS(object):
return
# OPTIMIZATION
if not self.hps.do_train_io_only:
# train the io matrices only
if self.hps.do_train_io_only:
self.train_vars = tvars = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_collection('IO_transformations',
scope=tf.get_variable_scope().name)
# train the encoder only
elif self.hps.do_train_encoder_only:
tvars1 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/ic_enc_*')
tvars2 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/z/ic_enc_*')
self.train_vars = tvars = tvars1 + tvars2
# train all variables
else:
self.train_vars = tvars = \
tf.get_collection('IO_transformations',
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=tf.get_variable_scope().name)
print("done.")
print("Model Variables (to be optimized): ")
......
......@@ -53,6 +53,7 @@ LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE = 6
INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False
DO_TRAIN_ENCODER_ONLY = False
DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors"
DO_TRAIN_READIN = True
......@@ -315,6 +316,16 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
"Train only the input (readin) and output (readout) \
affine functions.")
# This flag is used for an experiment where one wants to know if the dynamics
# learned by the generator generalize across conditions. In that case, you might
# train up a model on one set of data, and then only further train the encoder on
# another set of data (the conditions to be tested) so that the model is forced
# to use the same dynamics to describe that data.
# If you don't care about that particular experiment, this flag should always be
# false.
flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
"Train only the encoder weights.")
flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
"Reset the learning rate to initial value.")
......@@ -550,6 +561,7 @@ def build_hyperparameter_dict(flags):
d['max_grad_norm'] = flags.max_grad_norm
d['cell_clip_value'] = flags.cell_clip_value
d['do_train_io_only'] = flags.do_train_io_only
d['do_train_encoder_only'] = flags.do_train_encoder_only
d['do_reset_learning_rate'] = flags.do_reset_learning_rate
d['do_train_readin'] = flags.do_train_readin
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment