Commit 32daecc9 authored by Dan O'Shea

Support for --kind=posterior_push_mean alternative to sample and average

parent 6e3e5c38
@@ -295,7 +295,8 @@ class LFADS(object):
       datasets: a dictionary of named data_dictionaries, see top of lfads.py
     """
     print("Building graph...")
-    all_kinds = ['train', 'posterior_sample_and_average', 'prior_sample']
+    all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
+                 'prior_sample']
     assert kind in all_kinds, 'Wrong kind'
     if hps.feedback_factors_or_rates == "rates":
       assert len(hps.dataset_names) == 1, \
@@ -622,7 +623,8 @@ class LFADS(object):
       self.posterior_zs_g0 = \
           DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
                                     var_min=hps.ic_post_var_min)
-      if kind in ["train", "posterior_sample_and_average"]:
+      if kind in ["train", "posterior_sample_and_average",
+                  "posterior_push_mean"]:
         zs_g0 = self.posterior_zs_g0
       else:
         zs_g0 = self.prior_zs_g0
@@ -665,7 +667,7 @@ class LFADS(object):
                                recurrent_collections=['l2_con_reg'])
       with tf.variable_scope("con", reuse=False):
         self.con_ics = tf.tile(
-            tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]), \
+            tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
                         name="c0"),
             tf.stack([batch_size, 1]))
         self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape
@@ -711,8 +713,7 @@ class LFADS(object):
       else:
         assert False, "NIY"
-
-    # We support mulitple output distributions, for example Poisson, and also
+    # We support multiple output distributions, for example Poisson, and also
     # Gaussian. In these two cases respectively, there are one and two
     # parameters (rates vs. mean and variance). So the output_dist_params
     # tensor will variable sizes via tf.concat and tf.split, along the 1st
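As the comment above explains, the packed output_dist_params tensor has one parameter per data dimension for Poisson (rates) but two for Gaussian (means and variances), so downstream code concatenates and splits along the parameter axis. A minimal numpy sketch of that packing convention (shapes and names are illustrative, not the LFADS internals):

import numpy as np

B, T, D = 4, 10, 3  # batch, time steps, data dimensions (hypothetical)

# Poisson: one parameter (rate) per dimension -> last axis has size D.
poisson_params = np.random.rand(B, T, D)

# Gaussian: two parameters per dimension -> size 2*D, concatenated and
# split back apart whenever the individual parameters are needed.
means = np.random.randn(B, T, D)
logvars = np.random.randn(B, T, D)
gaussian_params = np.concatenate([means, logvars], axis=-1)  # (B, T, 2*D)
means_again, logvars_again = np.split(gaussian_params, 2, axis=-1)
assert np.allclose(means_again, means)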
@@ -769,6 +770,8 @@ class LFADS(object):
           u_t[t] = posterior_zs_co[t].sample
         elif kind == "posterior_sample_and_average":
           u_t[t] = posterior_zs_co[t].sample
+        elif kind == "posterior_push_mean":
+          u_t[t] = posterior_zs_co[t].mean
         else:
           u_t[t] = prior_zs_ar_con.samples_t[t]
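This branch is the heart of the new mode: posterior_sample_and_average draws a fresh stochastic sample of the controller output at each time step, while posterior_push_mean deterministically pushes the posterior mean through the decoder. A sketch of the distinction, using a hypothetical stand-in class (not the LFADS DiagonalGaussian itself):

import numpy as np

class DiagonalGaussian(object):
  """Stand-in exposing .mean and .sample like the LFADS posterior objects."""
  def __init__(self, mean, logvar):
    self.mean = mean
    self.logvar = logvar

  @property
  def sample(self):
    # Reparameterized draw: mean + std * eps, new noise on every access.
    eps = np.random.randn(*self.mean.shape)
    return self.mean + np.exp(0.5 * self.logvar) * eps

z = DiagonalGaussian(mean=np.zeros(3), logvar=np.zeros(3))
print(z.sample)  # stochastic: differs from call to call
print(z.mean)    # deterministic: what posterior_push_mean pushes through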
@@ -836,7 +839,7 @@ class LFADS(object):
       self.recon_cost = tf.constant(0.0) # VAE reconstruction cost
       self.nll_bound_vae = tf.constant(0.0)
       self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost.
-    if kind in ["train", "posterior_sample_and_average"]:
+    if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
       kl_cost_g0_b = 0.0
       kl_cost_co_b = 0.0
       if ic_dim > 0:
@@ -1595,6 +1598,9 @@ class LFADS(object):
                             do_eval_cost=False, do_average_batch=False):
     """Returns all the goodies for the entire model, per batch.
 
+    data_bxtxd and ext_input_bxtxi may have fewer than batch_size examples
+    along the first dimension, in which case padding and truncation are
+    handled automatically.
+
     Args:
       data_name: The name of the data dict, to select which in/out matrices
         to use.
@@ -1614,6 +1620,19 @@ class LFADS(object):
         enabled), the state of the generator, the factors, and the rates.
     """
     session = tf.get_default_session()
+
+    # If fewer than batch_size examples were provided, pad up to batch_size.
+    hps = self.hps
+    batch_size = hps.batch_size
+    E, _, _ = data_bxtxd.shape
+    if E < hps.batch_size:
+      data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)),
+                          mode='constant', constant_values=0)
+      if ext_input_bxtxi is not None:
+        ext_input_bxtxi = np.pad(ext_input_bxtxi,
+                                 ((0, hps.batch_size-E), (0, 0), (0, 0)),
+                                 mode='constant', constant_values=0)
+
     feed_dict = self.build_feed_dict(data_name, data_bxtxd,
                                      ext_input_bxtxi, keep_prob=1.0)
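The padding above lets callers hand eval_model_runs_batch a final, short batch. A quick numpy illustration of the pad (shapes hypothetical):

import numpy as np

batch_size = 8
E, T, D = 5, 20, 3                     # only 5 examples, short of batch_size
data_bxtxd = np.random.randn(E, T, D)

# Zero-pad along the example axis only; time and data axes are untouched.
padded = np.pad(data_bxtxd, ((0, batch_size - E), (0, 0), (0, 0)),
                mode='constant', constant_values=0)
print(padded.shape)  # (8, 20, 3); examples 5..7 are all zeros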
@@ -1663,6 +1682,7 @@ class LFADS(object):
     factors = list_t_bxn_to_tensor_bxtxn(factors)
     out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
     if self.hps.ic_dim > 0:
+      # select the first time point
       prior_g0_mean = prior_g0_mean[0]
       prior_g0_logvar = prior_g0_logvar[0]
       post_g0_mean = post_g0_mean[0]
@@ -1670,6 +1690,21 @@ class LFADS(object):
     if self.hps.co_dim > 0:
       controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)
+
+    # Slice out the original trials in case fewer than batch_size were provided.
+    if E < hps.batch_size:
+      idx = np.arange(E)
+      gen_ics = gen_ics[idx, :]
+      gen_states = gen_states[idx, :]
+      factors = factors[idx, :, :]
+      out_dist_params = out_dist_params[idx, :, :]
+      if self.hps.ic_dim > 0:
+        prior_g0_mean = prior_g0_mean[idx, :]
+        prior_g0_logvar = prior_g0_logvar[idx, :]
+        post_g0_mean = post_g0_mean[idx, :]
+        post_g0_logvar = post_g0_logvar[idx, :]
+      if self.hps.co_dim > 0:
+        controller_outputs = controller_outputs[idx, :, :]
+
     if do_average_batch:
       gen_ics = np.mean(gen_ics, axis=0)
       gen_states = np.mean(gen_states, axis=0)
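The slicing above is the inverse of the earlier padding: results computed for the zero-filled filler trials are simply dropped before any averaging. A tiny sketch of the round trip (shapes hypothetical):

import numpy as np

batch_size, E, T, D = 8, 5, 20, 3
outputs = np.random.randn(batch_size, T, D)  # model ran on the padded batch

idx = np.arange(E)            # indices of the real trials
outputs = outputs[idx, :, :]  # discard the padding trials' outputs
print(outputs.shape)          # (5, 20, 3)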
@@ -1806,7 +1841,121 @@ class LFADS(object):
       model_runs['train_steps'] = train_steps
     return model_runs
 
-  def write_model_runs(self, datasets, output_fname=None):
+  def eval_model_runs_push_mean(self, data_name, data_extxd,
+                                ext_input_extxi=None):
+    """Returns all the goodies for the entire model, using posterior means.
+
+    The expected value is taken over the hidden (z) variables, namely the
+    initial conditions and the control inputs, by pushing the mean values of
+    both through the model rather than by sampling (as in
+    eval_model_runs_avg_epoch). A total of batch_size trials are run at a time.
+
+    Args:
+      data_name: The name of the data dict, to select which in/out matrices
+        to use.
+      data_extxd: Numpy array training data with shape:
+        # examples x # time steps x # dimensions
+      ext_input_extxi (optional): Numpy array training external input with
+        shape: # examples x # time steps x # external input dims
+
+    Returns:
+      A dictionary with the outputs of the model decoder, namely:
+        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
+        posterior variance, the generator initial conditions, the control
+        inputs (if enabled), the state of the generator, the factors, and the
+        output distribution parameters (e.g. rates, or means and variances).
+    """
+    hps = self.hps
+    batch_size = hps.batch_size
+    E, T, D = data_extxd.shape
+    E_to_process = hps.ps_nexamples_to_process
+    if E_to_process > E:
+      print("Setting number of posterior samples to process to: ", E)
+      E_to_process = E
+
+    if hps.ic_dim > 0:
+      prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+      prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+      post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
+      post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
+
+    if hps.co_dim > 0:
+      controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
+    gen_ics = np.zeros([E_to_process, hps.gen_dim])
+    gen_states = np.zeros([E_to_process, T, hps.gen_dim])
+    factors = np.zeros([E_to_process, T, hps.factors_dim])
+
+    if hps.output_dist == 'poisson':
+      out_dist_params = np.zeros([E_to_process, T, D])
+    elif hps.output_dist == 'gaussian':
+      out_dist_params = np.zeros([E_to_process, T, D+D])
+    else:
+      assert False, "NIY"
+
+    costs = np.zeros(E_to_process)
+    nll_bound_vaes = np.zeros(E_to_process)
+    nll_bound_iwaes = np.zeros(E_to_process)
+    train_steps = np.zeros(E_to_process)
+
+    # Yield successive chunks of trial indices, each at most `per` long.
+    def trial_batches(N, per):
+      for i in range(0, N, per):
+        yield np.arange(i, min(i+per, N), dtype=np.int32)
+
+    for batch_idx, es_idx in enumerate(trial_batches(E_to_process,
+                                                     hps.batch_size)):
+      print("Running trial batch %d with %d trials" % (batch_idx+1,
+                                                       len(es_idx)))
+      data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
+                                                   ext_input_extxi,
+                                                   batch_size=batch_size,
+                                                   example_idxs=es_idx)
+      model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
+                                                ext_input_bxtxi,
+                                                do_eval_cost=True,
+                                                do_average_batch=False)
+
+      if self.hps.ic_dim > 0:
+        prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
+        prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
+        post_g0_mean[es_idx,:] = model_values['post_g0_mean']
+        post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
+      gen_ics[es_idx,:] = model_values['gen_ics']
+
+      if self.hps.co_dim > 0:
+        controller_outputs[es_idx,:,:] = model_values['controller_outputs']
+      gen_states[es_idx,:,:] = model_values['gen_states']
+      factors[es_idx,:,:] = model_values['factors']
+      out_dist_params[es_idx,:,:] = model_values['output_dist_params']
+
+      # TODO: model_values['costs'] and the other costs come out as scalars,
+      # summed over all the trials in the batch; what we want is the
+      # per-trial costs.
+      costs[es_idx] = model_values['costs']
+      nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
+      nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
+      train_steps[es_idx] = model_values['train_steps']
+
+    model_runs = {}
+    if self.hps.ic_dim > 0:
+      model_runs['prior_g0_mean'] = prior_g0_mean
+      model_runs['prior_g0_logvar'] = prior_g0_logvar
+      model_runs['post_g0_mean'] = post_g0_mean
+      model_runs['post_g0_logvar'] = post_g0_logvar
+    model_runs['gen_ics'] = gen_ics
+
+    if self.hps.co_dim > 0:
+      model_runs['controller_outputs'] = controller_outputs
+    model_runs['gen_states'] = gen_states
+    model_runs['factors'] = factors
+    model_runs['output_dist_params'] = out_dist_params
+
+    model_runs['costs'] = costs
+    model_runs['nll_bound_vaes'] = nll_bound_vaes
+    model_runs['nll_bound_iwaes'] = nll_bound_iwaes
+    model_runs['train_steps'] = train_steps
+    return model_runs
+
+  def write_model_runs(self, datasets, output_fname=None, push_mean=False):
     """Run the model on the data in data_dict, and save the computed values.
 
     LFADS generates a number of outputs for each examples, and these are all
@@ -1822,6 +1971,11 @@ class LFADS(object):
     Args:
       datasets: a dictionary of named data_dictionaries, see top of lfads.py
       output_fname: a file name stem for the output files.
+      push_mean: if False (default), generates batch_size samples for each
+        trial and averages the results. If True, runs each trial once without
+        noise, pushing the posterior mean initial conditions and control
+        inputs through the trained model. False is used for
+        posterior_sample_and_average, True is used for posterior_push_mean.
     """
     hps = self.hps
     kind = hps.kind
@@ -1838,8 +1992,12 @@ class LFADS(object):
       fname = output_fname + data_name + '_' + data_kind + '_' + kind
       print("Writing data for %s data and kind %s." % (data_name, data_kind))
-      model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
-                                                  ext_input_extxi)
+      if push_mean:
+        model_runs = self.eval_model_runs_push_mean(data_name, data_extxd,
+                                                    ext_input_extxi)
+      else:
+        model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
+                                                    ext_input_extxi)
       full_fname = os.path.join(hps.lfads_save_dir, fname)
       write_data(full_fname, model_runs, compression='gzip')
       print("Done.")
@@ -99,6 +99,7 @@ flags = tf.app.flags
 flags.DEFINE_string("kind", "train",
                     "Type of model to build {train, \
                     posterior_sample_and_average, \
+                    posterior_push_mean, \
                     prior_sample, write_model_params")
 flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION,
                     "Type of output distribution, 'poisson' or 'gaussian'")
@@ -318,11 +319,10 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
 # This flag is used for an experiment where one wants to know if the dynamics
 # learned by the generator generalize across conditions. In that case, you might
-# train up a model on one set of data, and then only further train the encoder on
-# another set of data (the conditions to be tested) so that the model is forced
-# to use the same dynamics to describe that data.
-# If you don't care about that particular experiment, this flag should always be
-# false.
+# train up a model on one set of data, and then only further train the encoder
+# on another set of data (the conditions to be tested) so that the model is
+# forced to use the same dynamics to describe that data. If you don't care about
+# that particular experiment, this flag should always be false.
 flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
                      "Train only the encoder weights.")
@@ -449,11 +449,11 @@ def build_model(hps, kind="train", datasets=None):
     saver.restore(session, ckpt.model_checkpoint_path)
   else:
     print("Created model with fresh parameters.")
-    if kind in ["posterior_sample_and_average", "prior_sample",
-                "write_model_params"]:
+    if kind in ["posterior_sample_and_average", "posterior_push_mean",
+                "prior_sample", "write_model_params"]:
       print("Possible error!!! You are running ", kind, " on a newly \
             initialized model!")
-      # cant print ckpt.model_check_point path if no ckpt
+      # cannot print ckpt.model_checkpoint_path if there is no ckpt
       print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir,
             " exists?")
@@ -609,7 +609,7 @@ def train(hps, datasets):
   model.train_model(datasets)
 
-def write_model_runs(hps, datasets, output_fname=None):
+def write_model_runs(hps, datasets, output_fname=None, push_mean=False):
   """Run the model on the data in data_dict, and save the computed values.
 
   LFADS generates a number of outputs for each examples, and these are all
@@ -627,9 +627,14 @@ def write_model_runs(hps, datasets, output_fname=None):
     datasets: A dictionary of data dictionaries. The dataset dict is simply a
       name(string)-> data dictionary mapping (See top of lfads.py).
     output_fname (optional): output filename stem to write the model runs.
+    push_mean: if False (default), generates batch_size samples for each trial
+      and averages the results. If True, runs each trial once without noise,
+      pushing the posterior mean initial conditions and control inputs through
+      the trained model. False is used for posterior_sample_and_average, True
+      is used for posterior_push_mean.
   """
   model = build_model(hps, kind=hps.kind, datasets=datasets)
-  model.write_model_runs(datasets, output_fname)
+  model.write_model_runs(datasets, output_fname, push_mean)
 
 def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
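For reference, a hedged usage sketch of the updated wrapper (assumes hps and datasets were already built as in main() below; not part of the commit itself):

# Deterministic run: push posterior means through the trained model.
write_model_runs(hps, datasets, hps.output_filename_stem, push_mean=True)

# Stochastic run: batch_size posterior samples per trial, averaged.
write_model_runs(hps, datasets, hps.output_filename_stem, push_mean=False)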
@@ -759,8 +764,8 @@ def main(_):
   # Read the data, if necessary.
   train_set = valid_set = None
-  if kind in ["train", "posterior_sample_and_average", "prior_sample",
-              "write_model_params"]:
+  if kind in ["train", "posterior_sample_and_average", "posterior_push_mean",
+              "prior_sample", "write_model_params"]:
     datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
   else:
     raise ValueError('Kind {} is not supported.'.format(kind))
@@ -792,7 +797,11 @@ def main(_):
   if kind == "train":
     train(hps, datasets)
   elif kind == "posterior_sample_and_average":
-    write_model_runs(hps, datasets, hps.output_filename_stem)
+    write_model_runs(hps, datasets, hps.output_filename_stem,
+                     push_mean=False)
+  elif kind == "posterior_push_mean":
+    write_model_runs(hps, datasets, hps.output_filename_stem,
+                     push_mean=True)
   elif kind == "prior_sample":
     write_model_samples(hps, datasets, hps.output_filename_stem)
   elif kind == "write_model_params":