ModelZoo / ResNet50_tensorflow

Commit 32daecc9, authored Mar 22, 2018 by Dan O'Shea
Parent: 6e3e5c38

    Support for --kind=posterior_push_mean, an alternative to sample and average

Showing 2 changed files, with 189 additions and 22 deletions:

    research/lfads/lfads.py       +167   -9
    research/lfads/run_lfads.py    +22  -13
research/lfads/lfads.py
@@ -295,7 +295,8 @@ class LFADS(object):
       datasets: a dictionary of named data_dictionaries, see top of lfads.py
     """
     print("Building graph...")
-    all_kinds = ['train', 'posterior_sample_and_average', 'prior_sample']
+    all_kinds = ['train', 'posterior_sample_and_average',
+                 'posterior_push_mean', 'prior_sample']
     assert kind in all_kinds, 'Wrong kind'
     if hps.feedback_factors_or_rates == "rates":
       assert len(hps.dataset_names) == 1, \
@@ -622,7 +623,8 @@ class LFADS(object):
       self.posterior_zs_g0 = \
           DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
                                     var_min=hps.ic_post_var_min)
-      if kind in ["train", "posterior_sample_and_average"]:
+      if kind in ["train", "posterior_sample_and_average",
+                  "posterior_push_mean"]:
         zs_g0 = self.posterior_zs_g0
       else:
         zs_g0 = self.prior_zs_g0
@@ -665,7 +667,7 @@ class LFADS(object):
                                  recurrent_collections=['l2_con_reg'])
       with tf.variable_scope("con", reuse=False):
         self.con_ics = tf.tile(
-            tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]), \
+            tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
                         name="c0"),
             tf.stack([batch_size, 1]))
         self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape
@@ -711,8 +713,7 @@ class LFADS(object):
       else:
         assert False, "NIY"
-      # We support multiple output distributions, for example Poisson, and also
+      # We support mulitple output distributions, for example Poisson, and also
       # Gaussian. In these two cases respectively, there are one and two
       # parameters (rates vs. mean and variance). So the output_dist_params
       # tensor will variable sizes via tf.concat and tf.split, along the 1st
@@ -769,6 +770,8 @@ class LFADS(object):
             u_t[t] = posterior_zs_co[t].sample
           elif kind == "posterior_sample_and_average":
             u_t[t] = posterior_zs_co[t].sample
+          elif kind == "posterior_push_mean":
+            u_t[t] = posterior_zs_co[t].mean
           else:
             u_t[t] = prior_zs_ar_con.samples_t[t]
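This branch is the heart of the commit: instead of drawing u_t from the approximate posterior, the push-mean kind feeds the posterior mean through the generator. A minimal numpy sketch of the distinction (illustrative only; the repo's actual DiagonalGaussian classes live in distributions.py):

    import numpy as np

    rng = np.random.RandomState(0)
    mean = np.array([0.5, -1.0])        # posterior mean for two latent dims
    logvar = np.array([-2.0, -2.0])     # posterior log-variance

    # posterior_sample_and_average: draw many noisy samples, then average.
    samples = mean + np.exp(0.5 * logvar) * rng.randn(512, 2)
    print(samples.mean(axis=0))         # approaches `mean`, up to sampling noise

    # posterior_push_mean: use `mean` directly -- one deterministic pass.
    print(mean)

For a nonlinear generator the two need not agree exactly: pushing the mean through a nonlinearity is not the same as averaging the nonlinearity's outputs. The push-mean kind trades that small bias for determinism and a batch_size-fold reduction in decoder passes.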
@@ -836,7 +839,7 @@ class LFADS(object):
     self.recon_cost = tf.constant(0.0) # VAE reconstruction cost
     self.nll_bound_vae = tf.constant(0.0)
     self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost.
-    if kind in ["train", "posterior_sample_and_average"]:
+    if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
       kl_cost_g0_b = 0.0
       kl_cost_co_b = 0.0
       if ic_dim > 0:
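This branch accumulates the KL penalties (kl_cost_g0_b, kl_cost_co_b) between the approximate posterior and the prior, which are now also built for the push-mean kind so that nll_bound_vae can still be evaluated. For reference, the standard closed form of the KL divergence between two diagonal Gaussians, as a standalone numpy sketch (not the repo's TensorFlow implementation):

    import numpy as np

    def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
      """KL(q || p) for diagonal Gaussians, summed over latent dimensions."""
      return 0.5 * np.sum(
          logvar_p - logvar_q
          + (np.exp(logvar_q) + (mu_q - mu_p) ** 2) / np.exp(logvar_p)
          - 1.0)

    # Sanity check: KL is zero when q and p coincide.
    mu, lv = np.zeros(4), np.zeros(4)
    assert np.isclose(kl_diag_gaussians(mu, lv, mu, lv), 0.0)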
@@ -1595,6 +1598,9 @@ class LFADS(object):
                             do_eval_cost=False, do_average_batch=False):
     """Returns all the goodies for the entire model, per batch.
 
+    data_bxtxd and ext_input_bxtxi may have fewer than batch_size examples
+    along the first dimension, in which case this method handles the padding
+    and truncating automatically.
 
     Args:
       data_name: The name of the data dict, to select which in/out matrices
         to use.
@@ -1614,6 +1620,19 @@ class LFADS(object):
         enabled), the state of the generator, the factors, and the rates.
     """
     session = tf.get_default_session()
+    # if fewer than batch_size provided, pad to batch_size
+    hps = self.hps
+    batch_size = hps.batch_size
+    E, _, _ = data_bxtxd.shape
+    if E < hps.batch_size:
+      data_bxtxd = np.pad(data_bxtxd,
+                          ((0, hps.batch_size - E), (0, 0), (0, 0)),
+                          mode='constant', constant_values=0)
+      if ext_input_bxtxi is not None:
+        ext_input_bxtxi = np.pad(ext_input_bxtxi,
+                                 ((0, hps.batch_size - E), (0, 0), (0, 0)),
+                                 mode='constant', constant_values=0)
+
     feed_dict = self.build_feed_dict(data_name, data_bxtxd,
                                      ext_input_bxtxi, keep_prob=1.0)
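The padding added here lets callers pass fewer than batch_size trials: zeros are appended along the example dimension, and the surplus rows are discarded again further down. A quick self-contained numpy check (shapes invented for illustration):

    import numpy as np

    batch_size = 8
    data_bxtxd = np.ones((3, 10, 5))    # only E=3 examples provided
    E = data_bxtxd.shape[0]

    padded = np.pad(data_bxtxd,
                    ((0, batch_size - E), (0, 0), (0, 0)),
                    mode='constant', constant_values=0)
    print(padded.shape)                 # (8, 10, 5): zero trials appended on dim 0

    # The real trials are recovered later by slicing (see the hunk below).
    print(np.array_equal(padded[np.arange(E)], data_bxtxd))   # True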
@@ -1663,6 +1682,7 @@ class LFADS(object):
       factors = list_t_bxn_to_tensor_bxtxn(factors)
       out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
       if self.hps.ic_dim > 0:
+        # select first time point
         prior_g0_mean = prior_g0_mean[0]
         prior_g0_logvar = prior_g0_logvar[0]
         post_g0_mean = post_g0_mean[0]
@@ -1670,6 +1690,21 @@ class LFADS(object):
       if self.hps.co_dim > 0:
         controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)
+      # slice out the trials in case < batch_size provided
+      if E < hps.batch_size:
+        idx = np.arange(E)
+        gen_ics = gen_ics[idx, :]
+        gen_states = gen_states[idx, :]
+        factors = factors[idx, :, :]
+        out_dist_params = out_dist_params[idx, :, :]
+        if self.hps.ic_dim > 0:
+          prior_g0_mean = prior_g0_mean[idx, :]
+          prior_g0_logvar = prior_g0_logvar[idx, :]
+          post_g0_mean = post_g0_mean[idx, :]
+          post_g0_logvar = post_g0_logvar[idx, :]
+        if self.hps.co_dim > 0:
+          controller_outputs = controller_outputs[idx, :, :]
+
       if do_average_batch:
         gen_ics = np.mean(gen_ics, axis=0)
         gen_states = np.mean(gen_states, axis=0)
@@ -1806,7 +1841,121 @@ class LFADS(object):
     model_runs['train_steps'] = train_steps
     return model_runs
 
-  def write_model_runs(self, datasets, output_fname=None):
+  def eval_model_runs_push_mean(self, data_name, data_extxd,
+                                ext_input_extxi=None):
+    """Returns the goodies for the entire model, using the posterior means.
+
+    The expected value is taken over the hidden (z) variables, namely the
+    initial conditions and the control inputs, by pushing the mean values for
+    both through the model rather than by sampling (as in
+    eval_model_runs_avg_epoch). A total of batch_size trials are run at a time.
+
+    Args:
+      data_name: The name of the data dict, to select which in/out matrices
+        to use.
+      data_extxd: Numpy array training data with shape:
+        # examples x # time steps x # dimensions
+      ext_input_extxi (optional): Numpy array training external input with
+        shape: # examples x # time steps x # external input dims
+
+    Returns:
+      A dictionary with the outputs of the model decoder, namely:
+        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
+        posterior variance, the generator initial conditions, the control
+        inputs (if enabled), the state of the generator, the factors, and the
+        output distribution parameters, e.g. (rates or mean and variances).
+    """
+    hps = self.hps
+    batch_size = hps.batch_size
+    E, T, D = data_extxd.shape
+    E_to_process = hps.ps_nexamples_to_process
+    if E_to_process > E:
+      print("Setting number of posterior samples to process to : ", E)
+      E_to_process = E
+
+    if hps.ic_dim > 0:
+      prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+      prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+      post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+      post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+
+    if hps.co_dim > 0:
+      controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
+    gen_ics = np.zeros([E_to_process, hps.gen_dim])
+    gen_states = np.zeros([E_to_process, T, hps.gen_dim])
+    factors = np.zeros([E_to_process, T, hps.factors_dim])
+
+    if hps.output_dist == 'poisson':
+      out_dist_params = np.zeros([E_to_process, T, D])
+    elif hps.output_dist == 'gaussian':
+      out_dist_params = np.zeros([E_to_process, T, D+D])
+    else:
+      assert False, "NIY"
+
+    costs = np.zeros(E_to_process)
+    nll_bound_vaes = np.zeros(E_to_process)
+    nll_bound_iwaes = np.zeros(E_to_process)
+    train_steps = np.zeros(E_to_process)
+
+    def trial_batches(N, per):
+      for i in range(0, N, per):
+        yield np.arange(i, min(i+per, N), dtype=np.int32)
+
+    for batch_idx, es_idx in enumerate(trial_batches(E_to_process,
+                                                     hps.batch_size)):
+      print("Running trial batch %d with %d trials" % (batch_idx+1,
+                                                       len(es_idx)))
+      data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
+                                                   ext_input_extxi,
+                                                   batch_size=batch_size,
+                                                   example_idxs=es_idx)
+      model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
+                                                ext_input_bxtxi,
+                                                do_eval_cost=True,
+                                                do_average_batch=False)
+
+      if self.hps.ic_dim > 0:
+        prior_g0_mean[es_idx, :] = model_values['prior_g0_mean']
+        prior_g0_logvar[es_idx, :] = model_values['prior_g0_logvar']
+        post_g0_mean[es_idx, :] = model_values['post_g0_mean']
+        post_g0_logvar[es_idx, :] = model_values['post_g0_logvar']
+      gen_ics[es_idx, :] = model_values['gen_ics']
+
+      if self.hps.co_dim > 0:
+        controller_outputs[es_idx, :, :] = model_values['controller_outputs']
+      gen_states[es_idx, :, :] = model_values['gen_states']
+      factors[es_idx, :, :] = model_values['factors']
+      out_dist_params[es_idx, :, :] = model_values['output_dist_params']
+
+      # TODO: model_values['costs'] and the other costs come out as scalars,
+      # summed over all the trials in the batch; what we want is the
+      # per-trial costs.
+      costs[es_idx] = model_values['costs']
+      nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
+      nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
+      train_steps[es_idx] = model_values['train_steps']
+
+    model_runs = {}
+    if self.hps.ic_dim > 0:
+      model_runs['prior_g0_mean'] = prior_g0_mean
+      model_runs['prior_g0_logvar'] = prior_g0_logvar
+      model_runs['post_g0_mean'] = post_g0_mean
+      model_runs['post_g0_logvar'] = post_g0_logvar
+    model_runs['gen_ics'] = gen_ics
+    if self.hps.co_dim > 0:
+      model_runs['controller_outputs'] = controller_outputs
+    model_runs['gen_states'] = gen_states
+    model_runs['factors'] = factors
+    model_runs['output_dist_params'] = out_dist_params
+    model_runs['costs'] = costs
+    model_runs['nll_bound_vaes'] = nll_bound_vaes
+    model_runs['nll_bound_iwaes'] = nll_bound_iwaes
+    model_runs['train_steps'] = train_steps
+    return model_runs
+
+  def write_model_runs(self, datasets, output_fname=None, push_mean=False):
     """Run the model on the data in data_dict, and save the computed values.
 
     LFADS generates a number of outputs for each example, and these are all
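The trial_batches helper above yields batch-sized index arrays; the final array may be short, which is exactly the case the new padding in eval_model_runs_batch handles. An illustrative run:

    import numpy as np

    def trial_batches(N, per):
      # same generator as in eval_model_runs_push_mean above
      for i in range(0, N, per):
        yield np.arange(i, min(i + per, N), dtype=np.int32)

    print([b.tolist() for b in trial_batches(5, 2)])
    # [[0, 1], [2, 3], [4]] -- the last batch has fewer than `per` trials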
@@ -1822,6 +1971,11 @@ class LFADS(object):
...
@@ -1822,6 +1971,11 @@ class LFADS(object):
Args:
Args:
datasets: a dictionary of named data_dictionaries, see top of lfads.py
datasets: a dictionary of named data_dictionaries, see top of lfads.py
output_fname: a file name stem for the output files.
output_fname: a file name stem for the output files.
push_mean: if False (default), generates batch_size samples for each trial
and averages the results. if True, runs each trial once without noise,
pushing the posterior mean initial conditions and control inputs through
the trained model. False is used for posterior_sample_and_average, True
is used for posterior_push_mean.
"""
"""
hps
=
self
.
hps
hps
=
self
.
hps
kind
=
hps
.
kind
kind
=
hps
.
kind
@@ -1838,8 +1992,12 @@ class LFADS(object):
       fname = output_fname + data_name + '_' + data_kind + '_' + kind
       print("Writing data for %s data and kind %s." % (data_name, data_kind))
-      model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
-                                                  ext_input_extxi)
+      if push_mean:
+        model_runs = self.eval_model_runs_push_mean(data_name, data_extxd,
+                                                    ext_input_extxi)
+      else:
+        model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
+                                                    ext_input_extxi)
       full_fname = os.path.join(hps.lfads_save_dir, fname)
       write_data(full_fname, model_runs, compression='gzip')
       print("Done.")
research/lfads/run_lfads.py
@@ -99,6 +99,7 @@ flags = tf.app.flags
 flags.DEFINE_string("kind", "train",
                     "Type of model to build {train, \
                     posterior_sample_and_average, \
+                    posterior_push_mean, \
                     prior_sample, write_model_params")
 flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION,
                     "Type of output distribution, 'poisson' or 'gaussian'")
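With the flag registered, the push-mean evaluation is selected from the command line in the same way as the existing kinds, e.g. `python run_lfads.py --kind=posterior_push_mean --data_dir=... --data_filename_stem=... --lfads_save_dir=... --output_filename_stem=...` (paths elided; the other flag names all appear elsewhere in this diff).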
@@ -318,11 +319,10 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
 # This flag is used for an experiment where one wants to know if the dynamics
 # learned by the generator generalize across conditions. In that case, you might
-# train up a model on one set of data, and then only further train the encoder on
-# another set of data (the conditions to be tested) so that the model is forced
-# to use the same dynamics to describe that data.
-# If you don't care about that particular experiment, this flag should always be
-# false.
+# train up a model on one set of data, and then only further train the encoder
+# on another set of data (the conditions to be tested) so that the model is
+# forced to use the same dynamics to describe that data. If you don't care about
+# that particular experiment, this flag should always be false.
 flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
                      "Train only the encoder weights.")
@@ -449,11 +449,11 @@ def build_model(hps, kind="train", datasets=None):
       saver.restore(session, ckpt.model_checkpoint_path)
     else:
       print("Created model with fresh parameters.")
-      if kind in ["posterior_sample_and_average", "prior_sample",
-                  "write_model_params"]:
+      if kind in ["posterior_sample_and_average", "posterior_push_mean",
+                  "prior_sample", "write_model_params"]:
         print("Possible error!!! You are running ", kind, " on a newly \
               initialized model!")
-        # cant print ckpt.model_check_point path if no ckpt
+        # cannot print ckpt.model_check_point path if no ckpt
         print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir,
               " exists?")
@@ -609,7 +609,7 @@ def train(hps, datasets):
   model.train_model(datasets)
 
 
-def write_model_runs(hps, datasets, output_fname=None):
+def write_model_runs(hps, datasets, output_fname=None, push_mean=False):
   """Run the model on the data in data_dict, and save the computed values.
 
   LFADS generates a number of outputs for each example, and these are all
@@ -627,9 +627,14 @@ def write_model_runs(hps, datasets, output_fname=None):
     datasets: A dictionary of data dictionaries. The dataset dict is simply a
       name(string)-> data dictionary mapping (See top of lfads.py).
     output_fname (optional): output filename stem to write the model runs.
+    push_mean: if False (default), generates batch_size samples for each trial
+      and averages the results. if True, runs each trial once without noise,
+      pushing the posterior mean initial conditions and control inputs through
+      the trained model. False is used for posterior_sample_and_average, True
+      is used for posterior_push_mean.
   """
   model = build_model(hps, kind=hps.kind, datasets=datasets)
-  model.write_model_runs(datasets, output_fname)
+  model.write_model_runs(datasets, output_fname, push_mean)
 
 
 def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
@@ -759,8 +764,8 @@ def main(_):
   # Read the data, if necessary.
   train_set = valid_set = None
-  if kind in ["train", "posterior_sample_and_average", "prior_sample",
-              "write_model_params"]:
+  if kind in ["train", "posterior_sample_and_average", "posterior_push_mean",
+              "prior_sample", "write_model_params"]:
     datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
   else:
     raise ValueError('Kind {} is not supported.'.format(kind))
@@ -792,7 +797,11 @@ def main(_):
   if kind == "train":
     train(hps, datasets)
   elif kind == "posterior_sample_and_average":
-    write_model_runs(hps, datasets, hps.output_filename_stem)
+    write_model_runs(hps, datasets, hps.output_filename_stem,
+                     push_mean=False)
+  elif kind == "posterior_push_mean":
+    write_model_runs(hps, datasets, hps.output_filename_stem,
+                     push_mean=True)
   elif kind == "prior_sample":
     write_model_samples(hps, datasets, hps.output_filename_stem)
   elif kind == "write_model_params":
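Taken together, the two posterior kinds now differ only in the push_mean argument threaded through write_model_runs down to the model. A minimal sketch of the programmatic path (names are from this diff; building `hps` and `datasets` as in main() is assumed):

    # Sketch only -- assumes `hps` and `datasets` are already constructed.
    model = build_model(hps, kind=hps.kind, datasets=datasets)

    # posterior_sample_and_average: batch_size noisy samples per trial, averaged.
    model.write_model_runs(datasets, hps.output_filename_stem, push_mean=False)

    # posterior_push_mean: one deterministic pass per trial using posterior means.
    model.write_model_runs(datasets, hps.output_filename_stem, push_mean=True)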