ModelZoo / ResNet50_tensorflow · Commits

Commit 74e35eda (unverified), authored Jun 18, 2018 by David Sussillo, committed by GitHub on Jun 18, 2018

Merge pull request #4570 from lfads/posterior_push_mean

LFADS: Support for posterior_push_mean instead of sampling

Parents: 2206bf54, 20da056d

Changes: 2 changed files with 204 additions and 33 deletions (+204 -33)

  research/lfads/lfads.py      +182 -20
  research/lfads/run_lfads.py  +22  -13
research/lfads/lfads.py (view file @ 74e35eda)
...
...
@@ -295,7 +295,8 @@ class LFADS(object):
       datasets: a dictionary of named data_dictionaries, see top of lfads.py
     """
     print("Building graph...")
-    all_kinds = ['train', 'posterior_sample_and_average', 'prior_sample']
+    all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
+                 'prior_sample']
     assert kind in all_kinds, 'Wrong kind'
     if hps.feedback_factors_or_rates == "rates":
       assert len(hps.dataset_names) == 1, \
...
...
@@ -489,15 +490,10 @@ class LFADS(object):
           pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
           pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)

-          def _case_with_no_default(pairs):
-            def _default_value_fn():
-              with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
-                return tf.identity(pairs[0][1]())
-            return tf.case(pairs, _default_value_fn, exclusive=True)
-
-          this_in_fac_W = _case_with_no_default(pf_pairs_in_fac_Ws)
-          this_in_fac_b = _case_with_no_default(pf_pairs_in_fac_bs)
-          this_out_fac_W = _case_with_no_default(pf_pairs_out_fac_Ws)
-          this_out_fac_b = _case_with_no_default(pf_pairs_out_fac_bs)
+          this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
+          this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
+          this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
+          this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True)

       # External inputs (not changing by dataset, by definition).
       if hps.ext_input_dim > 0:
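Note: the per-dataset selection above relies on tf.case, where each dataset contributes a (predicate, function) pair and the pair whose predicate is true supplies the read-in/read-out matrix for the current feed. The sketch below is a minimal, self-contained illustration of that pattern, assuming TensorFlow 1.x graph mode; the variable and placeholder names are illustrative and are not taken from lfads.py.

import tensorflow as tf  # assumes TensorFlow 1.x graph-mode APIs

# Hypothetical per-dataset read-in matrices, keyed by dataset name.
in_fac_Ws = {"dataset_a": tf.Variable(tf.zeros([10, 5]), name="W_a"),
             "dataset_b": tf.Variable(tf.zeros([12, 5]), name="W_b")}

# Placeholder naming the dataset in the current feed.
dataset_name = tf.placeholder(tf.string, shape=[], name="dataset_name")

# Build (predicate, fn) pairs, mirroring the zip(preds, fns_in_fac_Ws) pattern
# above. The default-argument trick (W=W) avoids late binding of the closure.
pf_pairs_in_fac_Ws = [(tf.equal(dataset_name, name), lambda W=W: W)
                      for name, W in sorted(in_fac_Ws.items())]

# exclusive=True evaluates every predicate and errors out at runtime if more
# than one is true; no explicit default is passed, matching the call above.
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)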
...
...
@@ -622,7 +618,8 @@ class LFADS(object):
         self.posterior_zs_g0 = \
             DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
                                       var_min=hps.ic_post_var_min)
-      if kind in ["train", "posterior_sample_and_average"]:
+      if kind in ["train", "posterior_sample_and_average",
+                  "posterior_push_mean"]:
         zs_g0 = self.posterior_zs_g0
       else:
         zs_g0 = self.prior_zs_g0
...
...
@@ -665,7 +662,7 @@ class LFADS(object):
                                         recurrent_collections=['l2_con_reg'])
       with tf.variable_scope("con", reuse=False):
-        self.con_ics = tf.tile(tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]), \
+        self.con_ics = tf.tile(tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
                                            name="c0"), tf.stack([batch_size, 1]))
         self.con_ics.set_shape([None, con_cell.state_size])  # tile loses shape
...
...
@@ -711,8 +708,7 @@ class LFADS(object):
       else:
         assert False, "NIY"

-      # We support mulitple output distributions, for example Poisson, and also
+      # We support multiple output distributions, for example Poisson, and also
       # Gaussian. In these two cases respectively, there are one and two
       # parameters (rates vs. mean and variance). So the output_dist_params
       # tensor will variable sizes via tf.concat and tf.split, along the 1st
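Note: as that comment describes, the packed output_dist_params tensor changes width with the output distribution: one parameter per dimension for Poisson (rates) and two for Gaussian (mean and variance), packed and unpacked with tf.concat/tf.split. The following is a minimal sketch of that packing convention with illustrative shapes, assuming TensorFlow 1.x; it is not the LFADS code itself.

import tensorflow as tf  # assumes TensorFlow 1.x

B, D = 4, 3  # batch size and data dimensionality (illustrative)
means = tf.zeros([B, D])
logvars = tf.zeros([B, D])

# Gaussian: two parameters per output dimension, packed along axis 1,
# giving a [B, 2D] parameter tensor per time step.
params_bx2d = tf.concat([means, logvars], axis=1)

# ...and split back apart wherever the individual parameters are needed.
means_out, logvars_out = tf.split(params_bx2d, 2, axis=1)  # each [B, D]

# Poisson: a single parameter (the rate) per dimension, so the packed
# tensor is simply [B, D].
rates_bxd = tf.zeros([B, D])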
...
...
@@ -769,6 +765,8 @@ class LFADS(object):
               u_t[t] = posterior_zs_co[t].sample
             elif kind == "posterior_sample_and_average":
               u_t[t] = posterior_zs_co[t].sample
+            elif kind == "posterior_push_mean":
+              u_t[t] = posterior_zs_co[t].mean
             else:
               u_t[t] = prior_zs_ar_con.samples_t[t]
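Note: the new posterior_push_mean branch replaces a stochastic draw from the approximate posterior with its mean. The numpy sketch below illustrates the difference for a diagonal Gaussian parameterized by a mean and log-variance; it is a conceptual illustration, not the DiagonalGaussian classes used by LFADS, and since the downstream generator is nonlinear the two quantities agree only approximately, which is why the docstring added later in this diff calls it a quick and approximate shortcut.

import numpy as np

rng = np.random.RandomState(0)

# Illustrative posterior parameters for one controller-output time step.
post_mean = np.array([0.5, -1.0, 2.0])
post_logvar = np.array([-2.0, -2.0, -2.0])

def sample(mean, logvar, rng):
  """Reparameterized draw mean + std * eps (the stochastic path)."""
  eps = rng.randn(*mean.shape)
  return mean + np.exp(0.5 * logvar) * eps

# posterior_sample_and_average: many stochastic passes, then an average.
avg_of_samples = np.mean([sample(post_mean, post_logvar, rng)
                          for _ in range(512)], axis=0)

# posterior_push_mean: a single deterministic pass using the mean itself.
pushed_mean = post_mean

print(avg_of_samples)  # close to post_mean, up to Monte Carlo error
print(pushed_mean)     # exactly post_mean, computed once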
...
...
@@ -836,7 +834,7 @@ class LFADS(object):
     self.recon_cost = tf.constant(0.0)  # VAE reconstruction cost
     self.nll_bound_vae = tf.constant(0.0)
     self.nll_bound_iwae = tf.constant(0.0)  # for eval with IWAE cost.
-    if kind in ["train", "posterior_sample_and_average"]:
+    if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
       kl_cost_g0_b = 0.0
       kl_cost_co_b = 0.0
       if ic_dim > 0:
...
...
@@ -928,7 +926,7 @@ class LFADS(object):
       tvars2 = \
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                             scope='LFADS/z/ic_enc_*')
       self.train_vars = tvars = tvars1 + tvars2  # train all variables
     else:
...
...
@@ -1595,6 +1593,9 @@ class LFADS(object):
                             do_eval_cost=False, do_average_batch=False):
     """Returns all the goodies for the entire model, per batch.

+    data_bxtxd and ext_input_bxtxi can have fewer than batch_size along dim 1,
+    in which case this handles the padding and truncating automatically.
+
     Args:
       data_name: The name of the data dict, to select which in/out matrices
         to use.
...
...
@@ -1614,6 +1615,19 @@ class LFADS(object):
       enabled), the state of the generator, the factors, and the rates.
     """
     session = tf.get_default_session()

+    # if fewer than batch_size provided, pad to batch_size
+    hps = self.hps
+    batch_size = hps.batch_size
+    E, _, _ = data_bxtxd.shape
+    if E < hps.batch_size:
+      data_bxtxd = np.pad(data_bxtxd,
+                          ((0, hps.batch_size-E), (0, 0), (0, 0)),
+                          mode='constant', constant_values=0)
+      if ext_input_bxtxi is not None:
+        ext_input_bxtxi = np.pad(ext_input_bxtxi,
+                                 ((0, hps.batch_size-E), (0, 0), (0, 0)),
+                                 mode='constant', constant_values=0)
+
     feed_dict = self.build_feed_dict(data_name, data_bxtxd, ext_input_bxtxi,
                                      keep_prob=1.0)
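Note: the added padding lets a final, undersized batch flow through a graph built for a fixed batch_size; the complementary slicing further down removes the padded trials again. Below is a small numpy sketch of that pad-then-slice pattern with illustrative shapes, not the LFADS code itself.

import numpy as np

batch_size = 8                    # the graph's fixed batch size (illustrative)
E, T, D = 5, 20, 3                # an undersized batch of 5 real trials
data_bxtxd = np.ones([E, T, D])

# Pad with zero trials so the array matches the graph's batch_size.
if E < batch_size:
  data_bxtxd = np.pad(data_bxtxd, ((0, batch_size - E), (0, 0), (0, 0)),
                      mode='constant', constant_values=0)
assert data_bxtxd.shape == (batch_size, T, D)

# After running the model, keep only the E real trials.
idx = np.arange(E)
outputs = data_bxtxd[idx, :, :]   # stand-in for a model output of the same shape
assert outputs.shape == (E, T, D)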
...
...
@@ -1663,6 +1677,7 @@ class LFADS(object):
     factors = list_t_bxn_to_tensor_bxtxn(factors)
     out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
     if self.hps.ic_dim > 0:
+      # select first time point
       prior_g0_mean = prior_g0_mean[0]
       prior_g0_logvar = prior_g0_logvar[0]
       post_g0_mean = post_g0_mean[0]
...
...
@@ -1670,6 +1685,21 @@ class LFADS(object):
     if self.hps.co_dim > 0:
       controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)

+    # slice out the trials in case < batch_size provided
+    if E < hps.batch_size:
+      idx = np.arange(E)
+      gen_ics = gen_ics[idx, :]
+      gen_states = gen_states[idx, :]
+      factors = factors[idx, :, :]
+      out_dist_params = out_dist_params[idx, :, :]
+      if self.hps.ic_dim > 0:
+        prior_g0_mean = prior_g0_mean[idx, :]
+        prior_g0_logvar = prior_g0_logvar[idx, :]
+        post_g0_mean = post_g0_mean[idx, :]
+        post_g0_logvar = post_g0_logvar[idx, :]
+      if self.hps.co_dim > 0:
+        controller_outputs = controller_outputs[idx, :, :]
+
     if do_average_batch:
       gen_ics = np.mean(gen_ics, axis=0)
       gen_states = np.mean(gen_states, axis=0)
...
...
@@ -1730,7 +1760,6 @@ class LFADS(object):
     E, T, D = data_extxd.shape
     E_to_process = hps.ps_nexamples_to_process
     if E_to_process > E:
       print("Setting number of posterior samples to process to : ", E)
       E_to_process = E

     if hps.ic_dim > 0:
...
...
@@ -1806,7 +1835,131 @@ class LFADS(object):
       model_runs['train_steps'] = train_steps
     return model_runs

-  def write_model_runs(self, datasets, output_fname=None):
+  def eval_model_runs_push_mean(self, data_name, data_extxd,
+                                ext_input_extxi=None):
+    """Returns values of interest for the model by pushing the means through.
+
+    The mean values for both initial conditions and the control inputs are
+    pushed through the model instead of sampling (as is done in
+    eval_model_runs_avg_epoch). This is a quick and approximate version of
+    estimating these values instead of sampling from the posterior many times
+    and then averaging those values of interest.
+
+    Internally, a total of batch_size trials are run through the model at once.
+
+    Args:
+      data_name: The name of the data dict, to select which in/out matrices
+        to use.
+      data_extxd: Numpy array training data with shape:
+        # examples x # time steps x # dimensions
+      ext_input_extxi (optional): Numpy array training external input with
+        shape: # examples x # time steps x # external input dims
+
+    Returns:
+      A dictionary with the estimated outputs of the model decoder, namely:
+        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
+        posterior variance, the generator initial conditions, the control
+        inputs (if enabled), the state of the generator, the factors, and the
+        output distribution parameters, e.g. (rates or mean and variances).
+    """
+    hps = self.hps
+    batch_size = hps.batch_size
+    E, T, D = data_extxd.shape
+    E_to_process = hps.ps_nexamples_to_process
+    if E_to_process > E:
+      print("Setting number of posterior samples to process to : ", E)
+      E_to_process = E
+
+    if hps.ic_dim > 0:
+      prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+      prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+      post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+      post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+
+    if hps.co_dim > 0:
+      controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
+    gen_ics = np.zeros([E_to_process, hps.gen_dim])
+    gen_states = np.zeros([E_to_process, T, hps.gen_dim])
+    factors = np.zeros([E_to_process, T, hps.factors_dim])
+
+    if hps.output_dist == 'poisson':
+      out_dist_params = np.zeros([E_to_process, T, D])
+    elif hps.output_dist == 'gaussian':
+      out_dist_params = np.zeros([E_to_process, T, D+D])
+    else:
+      assert False, "NIY"
+
+    costs = np.zeros(E_to_process)
+    nll_bound_vaes = np.zeros(E_to_process)
+    nll_bound_iwaes = np.zeros(E_to_process)
+    train_steps = np.zeros(E_to_process)
+
+    # generator that will yield 0:N in groups of per items, e.g.
+    # (0:per-1), (per:2*per-1), ..., with the last group containing <= per items
+    # this will be used to feed per=batch_size trials into the model at a time
+    def trial_batches(N, per):
+      for i in range(0, N, per):
+        yield np.arange(i, min(i+per, N), dtype=np.int32)
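Note: for illustration, here is what the trial_batches generator above yields for a hypothetical run of 10 trials with a batch size of 4; the final group is smaller than the batch size, which is why the padding in eval_model_runs_batch matters. This is a standalone sketch, not part of the diff.

import numpy as np

def trial_batches(N, per):
  # Same generator as above: groups of at most `per` trial indices.
  for i in range(0, N, per):
    yield np.arange(i, min(i + per, N), dtype=np.int32)

print([b.tolist() for b in trial_batches(10, 4)])
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]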
+    for batch_idx, es_idx in enumerate(trial_batches(E_to_process,
+                                                     hps.batch_size)):
+      print("Running trial batch %d with %d trials" % (batch_idx+1,
+                                                       len(es_idx)))
+      data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
+                                                   ext_input_extxi,
+                                                   batch_size=batch_size,
+                                                   example_idxs=es_idx)
+      model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
+                                                ext_input_bxtxi,
+                                                do_eval_cost=True,
+                                                do_average_batch=False)
+
+      if self.hps.ic_dim > 0:
+        prior_g0_mean[es_idx, :] = model_values['prior_g0_mean']
+        prior_g0_logvar[es_idx, :] = model_values['prior_g0_logvar']
+        post_g0_mean[es_idx, :] = model_values['post_g0_mean']
+        post_g0_logvar[es_idx, :] = model_values['post_g0_logvar']
+      gen_ics[es_idx, :] = model_values['gen_ics']
+
+      if self.hps.co_dim > 0:
+        controller_outputs[es_idx, :, :] = model_values['controller_outputs']
+      gen_states[es_idx, :, :] = model_values['gen_states']
+      factors[es_idx, :, :] = model_values['factors']
+      out_dist_params[es_idx, :, :] = model_values['output_dist_params']
+
+      # TODO: model_values['costs'] and other costs come out as scalars, summed
+      # over all the trials in the batch. what we want is the per-trial costs
+      costs[es_idx] = model_values['costs']
+      nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
+      nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
+      train_steps[es_idx] = model_values['train_steps']
+
+    model_runs = {}
+    if self.hps.ic_dim > 0:
+      model_runs['prior_g0_mean'] = prior_g0_mean
+      model_runs['prior_g0_logvar'] = prior_g0_logvar
+      model_runs['post_g0_mean'] = post_g0_mean
+      model_runs['post_g0_logvar'] = post_g0_logvar
+    model_runs['gen_ics'] = gen_ics
+
+    if self.hps.co_dim > 0:
+      model_runs['controller_outputs'] = controller_outputs
+    model_runs['gen_states'] = gen_states
+    model_runs['factors'] = factors
+    model_runs['output_dist_params'] = out_dist_params
+
+    # You probably do not want the LL associated values when pushing the mean
+    # instead of sampling.
+    model_runs['costs'] = costs
+    model_runs['nll_bound_vaes'] = nll_bound_vaes
+    model_runs['nll_bound_iwaes'] = nll_bound_iwaes
+    model_runs['train_steps'] = train_steps
+    return model_runs
+  def write_model_runs(self, datasets, output_fname=None, push_mean=False):
     """Run the model on the data in data_dict, and save the computed values.

     LFADS generates a number of outputs for each examples, and these are all
...
...
@@ -1822,6 +1975,11 @@ class LFADS(object):
     Args:
       datasets: a dictionary of named data_dictionaries, see top of lfads.py
       output_fname: a file name stem for the output files.
+      push_mean: if False (default), generates batch_size samples for each trial
+        and averages the results. if True, runs each trial once without noise,
+        pushing the posterior mean initial conditions and control inputs through
+        the trained model. False is used for posterior_sample_and_average, True
+        is used for posterior_push_mean.
     """
     hps = self.hps
     kind = hps.kind
...
...
@@ -1838,8 +1996,12 @@ class LFADS(object):
       fname = output_fname + data_name + '_' + data_kind + '_' + kind
       print("Writing data for %s data and kind %s." % (data_name, data_kind))
-      model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
-                                                  ext_input_extxi)
+      if push_mean:
+        model_runs = self.eval_model_runs_push_mean(data_name, data_extxd,
+                                                    ext_input_extxi)
+      else:
+        model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
+                                                    ext_input_extxi)
       full_fname = os.path.join(hps.lfads_save_dir, fname)
       write_data(full_fname, model_runs, compression='gzip')
       print("Done.")
...
...
research/lfads/run_lfads.py (view file @ 74e35eda)
...
...
@@ -99,6 +99,7 @@ flags = tf.app.flags
 flags.DEFINE_string("kind", "train",
                     "Type of model to build {train, \
                     posterior_sample_and_average, \
+                    posterior_push_mean, \
                     prior_sample, write_model_params")
 flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION,
                     "Type of output distribution, 'poisson' or 'gaussian'")
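Note: the kind flag above is consumed through the standard tf.app.flags pattern, so the new mode is selected on the command line (e.g. --kind=posterior_push_mean) and read from FLAGS inside main. The following is a minimal, self-contained sketch of that pattern, assuming TensorFlow 1.x; the script is illustrative and not part of run_lfads.py.

import tensorflow as tf  # assumes TensorFlow 1.x

flags = tf.app.flags
flags.DEFINE_string("kind", "train",
                    "Type of model to build {train, "
                    "posterior_sample_and_average, posterior_push_mean, "
                    "prior_sample, write_model_params}")
FLAGS = flags.FLAGS

def main(_):
  # e.g. run as: python this_script.py --kind=posterior_push_mean
  print("Running with kind =", FLAGS.kind)

if __name__ == "__main__":
  tf.app.run()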
...
...
@@ -318,11 +319,10 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
 # This flag is used for an experiment where one wants to know if the dynamics
 # learned by the generator generalize across conditions. In that case, you might
-# train up a model on one set of data, and then only further train the encoder on
-# another set of data (the conditions to be tested) so that the model is forced
-# to use the same dynamics to describe that data.
-# If you don't care about that particular experiment, this flag should always be
-# false.
+# train up a model on one set of data, and then only further train the encoder
+# on another set of data (the conditions to be tested) so that the model is
+# forced to use the same dynamics to describe that data. If you don't care about
+# that particular experiment, this flag should always be false.
 flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
                      "Train only the encoder weights.")
...
...
@@ -449,11 +449,11 @@ def build_model(hps, kind="train", datasets=None):
     saver.restore(session, ckpt.model_checkpoint_path)
   else:
     print("Created model with fresh parameters.")
-    if kind in ["posterior_sample_and_average", "prior_sample",
-                "write_model_params"]:
+    if kind in ["posterior_sample_and_average", "posterior_push_mean",
+                "prior_sample", "write_model_params"]:
       print("Possible error!!! You are running ", kind, " on a newly \
             initialized model!")
-      # cant print ckpt.model_check_point path if no ckpt
+      # cannot print ckpt.model_check_point path if no ckpt
       print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir,
             " exists?")
...
...
@@ -609,7 +609,7 @@ def train(hps, datasets):
   model.train_model(datasets)

-def write_model_runs(hps, datasets, output_fname=None):
+def write_model_runs(hps, datasets, output_fname=None, push_mean=False):
   """Run the model on the data in data_dict, and save the computed values.

   LFADS generates a number of outputs for each examples, and these are all
...
...
@@ -627,9 +627,14 @@ def write_model_runs(hps, datasets, output_fname=None):
     datasets: A dictionary of data dictionaries. The dataset dict is simply a
       name(string)-> data dictionary mapping (See top of lfads.py).
     output_fname (optional): output filename stem to write the model runs.
+    push_mean: if False (default), generates batch_size samples for each trial
+      and averages the results. if True, runs each trial once without noise,
+      pushing the posterior mean initial conditions and control inputs through
+      the trained model. False is used for posterior_sample_and_average, True
+      is used for posterior_push_mean.
   """
   model = build_model(hps, kind=hps.kind, datasets=datasets)
-  model.write_model_runs(datasets, output_fname)
+  model.write_model_runs(datasets, output_fname, push_mean)

def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
...
...
@@ -759,8 +764,8 @@ def main(_):
   # Read the data, if necessary.
   train_set = valid_set = None
-  if kind in ["train", "posterior_sample_and_average", "prior_sample",
-              "write_model_params"]:
+  if kind in ["train", "posterior_sample_and_average", "posterior_push_mean",
+              "prior_sample", "write_model_params"]:
     datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
   else:
     raise ValueError('Kind {} is not supported.'.format(kind))
...
...
@@ -792,7 +797,11 @@ def main(_):
   if kind == "train":
     train(hps, datasets)
   elif kind == "posterior_sample_and_average":
-    write_model_runs(hps, datasets, hps.output_filename_stem)
+    write_model_runs(hps, datasets, hps.output_filename_stem,
+                     push_mean=False)
+  elif kind == "posterior_push_mean":
+    write_model_runs(hps, datasets, hps.output_filename_stem,
+                     push_mean=True)
   elif kind == "prior_sample":
     write_model_samples(hps, datasets, hps.output_filename_stem)
   elif kind == "write_model_params":
...
...