Unverified commit 3022f945, authored by David Sussillo, committed by GitHub

Merge pull request #2898 from cpandar/master

Change to LFADS to allow training of the encoder weights only
parents 99400da5 9b3a7754
```diff
@@ -915,13 +915,25 @@ class LFADS(object):
       return
     # OPTIMIZATION
-    if not self.hps.do_train_io_only:
+    # train the io matrices only
+    if self.hps.do_train_io_only:
       self.train_vars = tvars = \
-        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+        tf.get_collection('IO_transformations',
                           scope=tf.get_variable_scope().name)
+    # train the encoder only
+    elif self.hps.do_train_encoder_only:
+      tvars1 = \
+        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+                          scope='LFADS/ic_enc_*')
+      tvars2 = \
+        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+                          scope='LFADS/z/ic_enc_*')
+      self.train_vars = tvars = tvars1 + tvars2
+    # train all variables
     else:
       self.train_vars = tvars = \
-        tf.get_collection('IO_transformations',
+        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                           scope=tf.get_variable_scope().name)
     print("done.")
     print("Model Variables (to be optimized): ")
```
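For readers less familiar with TF1 collections: the `scope` argument to `tf.get_collection` is matched as a regular expression against the start of each variable name (via `re.match`), so `'LFADS/ic_enc_*'` acts as a prefix filter. A minimal standalone sketch, using hypothetical scope names rather than the real LFADS graph:

```python
import tensorflow as tf

# Build a toy graph with encoder-like and generator-like scopes.
with tf.variable_scope('LFADS'):
  for sub in ('ic_enc_fwd', 'ic_enc_rev', 'gen'):
    with tf.variable_scope(sub):
      tf.get_variable('w', shape=[2, 2])

# The scope string is treated as a regex anchored at the start of the
# variable name, so this picks up both encoder scopes but not 'gen'.
enc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                             scope='LFADS/ic_enc_*')
print([v.name for v in enc_vars])
# ['LFADS/ic_enc_fwd/w:0', 'LFADS/ic_enc_rev/w:0']
```

The second collection, scoped to `'LFADS/z/ic_enc_*'`, presumably catches encoder variables created under a separate `z` parent scope; the two lists are concatenated into `self.train_vars`, which restricts which variables the optimizer updates.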
```diff
@@ -53,6 +53,7 @@ LEARNING_RATE_STOP = 0.00001
 LEARNING_RATE_N_TO_COMPARE = 6
 INJECT_EXT_INPUT_TO_GEN = False
 DO_TRAIN_IO_ONLY = False
+DO_TRAIN_ENCODER_ONLY = False
 DO_RESET_LEARNING_RATE = False
 FEEDBACK_FACTORS_OR_RATES = "factors"
 DO_TRAIN_READIN = True
```
```diff
@@ -315,6 +316,16 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
                      "Train only the input (readin) and output (readout) \
                      affine functions.")
+# This flag is used for an experiment where one wants to know if the dynamics
+# learned by the generator generalize across conditions. In that case, you might
+# train up a model on one set of data, and then only further train the encoder on
+# another set of data (the conditions to be tested) so that the model is forced
+# to use the same dynamics to describe that data.
+# If you don't care about that particular experiment, this flag should always be
+# false.
+flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
+                     "Train only the encoder weights.")
 flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
                      "Reset the learning rate to initial value.")
```
```diff
@@ -550,6 +561,7 @@ def build_hyperparameter_dict(flags):
   d['max_grad_norm'] = flags.max_grad_norm
   d['cell_clip_value'] = flags.cell_clip_value
   d['do_train_io_only'] = flags.do_train_io_only
+  d['do_train_encoder_only'] = flags.do_train_encoder_only
   d['do_reset_learning_rate'] = flags.do_reset_learning_rate
   d['do_train_readin'] = flags.do_train_readin
```
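A small plumbing note: `build_hyperparameter_dict` produces a plain dict, while the model code reads `self.hps.do_train_encoder_only` as an attribute, so the dict is wrapped in an attribute-access object somewhere in between. A generic sketch of that wrapper pattern (not the repo's actual helper):

```python
class HParams(object):
  """Minimal dict-to-attribute wrapper (illustrative only)."""
  def __init__(self, d):
    self.__dict__.update(d)

hps = HParams({'do_train_io_only': False,
               'do_train_encoder_only': True})
assert hps.do_train_encoder_only and not hps.do_train_io_only
```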