Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3022f945
Unverified
Commit
3022f945
authored
Feb 09, 2018
by
David Sussillo
Committed by
GitHub
Feb 09, 2018
Browse files
Merge pull request #2898 from cpandar/master
change to lfads to allow training of encoder weights only
parents
99400da5
9b3a7754
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
3 deletions
+27
-3
research/lfads/lfads.py
research/lfads/lfads.py
+15
-3
research/lfads/run_lfads.py
research/lfads/run_lfads.py
+12
-0
No files found.
research/lfads/lfads.py
View file @
3022f945
...
...
@@ -915,13 +915,25 @@ class LFADS(object):
return
# OPTIMIZATION
if
not
self
.
hps
.
do_train_io_only
:
# train the io matrices only
if
self
.
hps
.
do_train_io_only
:
self
.
train_vars
=
tvars
=
\
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
get_collection
(
'IO_transformations'
,
scope
=
tf
.
get_variable_scope
().
name
)
# train the encoder only
elif
self
.
hps
.
do_train_encoder_only
:
tvars1
=
\
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
scope
=
'LFADS/ic_enc_*'
)
tvars2
=
\
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
scope
=
'LFADS/z/ic_enc_*'
)
self
.
train_vars
=
tvars
=
tvars1
+
tvars2
# train all variables
else
:
self
.
train_vars
=
tvars
=
\
tf
.
get_collection
(
'IO_transformations'
,
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
scope
=
tf
.
get_variable_scope
().
name
)
print
(
"done."
)
print
(
"Model Variables (to be optimized): "
)
...
...
research/lfads/run_lfads.py
View file @
3022f945
...
...
@@ -53,6 +53,7 @@ LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE
=
6
INJECT_EXT_INPUT_TO_GEN
=
False
DO_TRAIN_IO_ONLY
=
False
DO_TRAIN_ENCODER_ONLY
=
False
DO_RESET_LEARNING_RATE
=
False
FEEDBACK_FACTORS_OR_RATES
=
"factors"
DO_TRAIN_READIN
=
True
...
...
@@ -315,6 +316,16 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
"Train only the input (readin) and output (readout)
\
affine functions."
)
# This flag is used for an experiment where one wants to know if the dynamics
# learned by the generator generalize across conditions. In that case, you might
# train up a model on one set of data, and then only further train the encoder on
# another set of data (the conditions to be tested) so that the model is forced
# to use the same dynamics to describe that data.
# If you don't care about that particular experiment, this flag should always be
# false.
flags
.
DEFINE_boolean
(
"do_train_encoder_only"
,
DO_TRAIN_ENCODER_ONLY
,
"Train only the encoder weights."
)
flags
.
DEFINE_boolean
(
"do_reset_learning_rate"
,
DO_RESET_LEARNING_RATE
,
"Reset the learning rate to initial value."
)
...
...
@@ -550,6 +561,7 @@ def build_hyperparameter_dict(flags):
d
[
'max_grad_norm'
]
=
flags
.
max_grad_norm
d
[
'cell_clip_value'
]
=
flags
.
cell_clip_value
d
[
'do_train_io_only'
]
=
flags
.
do_train_io_only
d
[
'do_train_encoder_only'
]
=
flags
.
do_train_encoder_only
d
[
'do_reset_learning_rate'
]
=
flags
.
do_reset_learning_rate
d
[
'do_train_readin'
]
=
flags
.
do_train_readin
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment