Commit 20da056d authored by Dan O'Shea's avatar Dan O'Shea
Browse files

Cleaning comments for posterior_push_mean

parent 8c5c60ca
...@@ -490,11 +490,6 @@ class LFADS(object): ...@@ -490,11 +490,6 @@ class LFADS(object):
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws) pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs) pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
# def _case_with_no_default(pairs):
# def _default_value_fn():
# with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
# return tf.identity(pairs[0][1]())
# return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True) this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True) this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True) this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
...@@ -931,7 +926,7 @@ class LFADS(object): ...@@ -931,7 +926,7 @@ class LFADS(object):
tvars2 = \ tvars2 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/z/ic_enc_*') scope='LFADS/z/ic_enc_*')
self.train_vars = tvars = tvars1 + tvars2 self.train_vars = tvars = tvars1 + tvars2
# train all variables # train all variables
else: else:
...@@ -1765,7 +1760,6 @@ class LFADS(object): ...@@ -1765,7 +1760,6 @@ class LFADS(object):
E, T, D = data_extxd.shape E, T, D = data_extxd.shape
E_to_process = hps.ps_nexamples_to_process E_to_process = hps.ps_nexamples_to_process
if E_to_process > E: if E_to_process > E:
print("Setting number of posterior samples to process to : ", E)
E_to_process = E E_to_process = E
if hps.ic_dim > 0: if hps.ic_dim > 0:
...@@ -1843,12 +1837,16 @@ class LFADS(object): ...@@ -1843,12 +1837,16 @@ class LFADS(object):
def eval_model_runs_push_mean(self, data_name, data_extxd, def eval_model_runs_push_mean(self, data_name, data_extxd,
ext_input_extxi=None): ext_input_extxi=None):
"""Returns the value for goodies for the entire model using the means """Returns values of interest for the model by pushing the means through
The expected value is taken over hidden (z) variables, namely the initial The mean values for both initial conditions and the control inputs are
conditions and the control inputs, by pushing the mean values for both pushed through the model instead of sampling (as is done in
through the model rather than by sampling (as in eval_model_runs_avg_epoch) eval_model_runs_avg_epoch).
A total of batch_size trials are run at a time. This is a quick and approximate version of estimating these values instead
of sampling from the posterior many times and then averaging those values of
interest.
Internally, a total of batch_size trials are run through the model at once.
Args: Args:
data_name: The name of the data dict, to select which in/out matrices data_name: The name of the data dict, to select which in/out matrices
...@@ -1859,7 +1857,7 @@ class LFADS(object): ...@@ -1859,7 +1857,7 @@ class LFADS(object):
shape: # examples x # time steps x # external input dims shape: # examples x # time steps x # external input dims
Returns: Returns:
A dictionary with the averaged outputs of the model decoder, namely: A dictionary with the estimated outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx prior g0 mean, prior g0 variance, approx. posterior mean, approx
posterior mean, the generator initial conditions, the control inputs (if posterior mean, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the output enabled), the state of the generator, the factors, and the output
...@@ -1897,6 +1895,9 @@ class LFADS(object): ...@@ -1897,6 +1895,9 @@ class LFADS(object):
nll_bound_iwaes = np.zeros(E_to_process) nll_bound_iwaes = np.zeros(E_to_process)
train_steps = np.zeros(E_to_process) train_steps = np.zeros(E_to_process)
# generator that will yield 0:N in groups of per items, e.g.
# (0:per-1), (per:2*per-1), ..., with the last group containing <= per items
# this will be used to feed per=batch_size trials into the model at a time
def trial_batches(N, per): def trial_batches(N, per):
for i in range(0, N, per): for i in range(0, N, per):
yield np.arange(i, min(i+per, N), dtype=np.int32) yield np.arange(i, min(i+per, N), dtype=np.int32)
...@@ -1949,6 +1950,9 @@ class LFADS(object): ...@@ -1949,6 +1950,9 @@ class LFADS(object):
model_runs['gen_states'] = gen_states model_runs['gen_states'] = gen_states
model_runs['factors'] = factors model_runs['factors'] = factors
model_runs['output_dist_params'] = out_dist_params model_runs['output_dist_params'] = out_dist_params
# You probably do not want the LL associated values when pushing the mean
# instead of sampling.
model_runs['costs'] = costs model_runs['costs'] = costs
model_runs['nll_bound_vaes'] = nll_bound_vaes model_runs['nll_bound_vaes'] = nll_bound_vaes
model_runs['nll_bound_iwaes'] = nll_bound_iwaes model_runs['nll_bound_iwaes'] = nll_bound_iwaes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment