Unverified Commit 730b778e authored by Alexa Nguyen's avatar Alexa Nguyen Committed by GitHub
Browse files

Fix typos and indentation in lfads.py (#10699)

parent bf868b99
...@@ -37,15 +37,15 @@ The nested dictionary is the DATA DICTIONARY, which has the following keys: ...@@ -37,15 +37,15 @@ The nested dictionary is the DATA DICTIONARY, which has the following keys:
'train_ext_input' and 'valid_ext_input', if there are know external inputs 'train_ext_input' and 'valid_ext_input', if there are know external inputs
to the system being modeled, these take on dimensions: to the system being modeled, these take on dimensions:
ExTxI, E - # examples, T - # time steps, I = # dimensions in input. ExTxI, E - # examples, T - # time steps, I = # dimensions in input.
'alignment_matrix_cxf' - If you are using multiple days data, it's possible 'alignment_matrix_cxf' - If you are using multiple days data, it's possible
that one can align the channels (see manuscript). If so each dataset will that one can align the channels (see manuscript). If so each dataset will
contain this matrix, which will be used for both the input adapter and the contain this matrix, which will be used for both the input adapter and the
output adapter for each dataset. These matrices, if provided, must be of output adapter for each dataset. These matrices, if provided, must be of
size [data_dim x factors] where data_dim is the number of neurons recorded size [data_dim x factors] where data_dim is the number of neurons recorded
on that day, and factors is chosen and set through the '--factors' flag. on that day, and factors is chosen and set through the '--factors' flag.
'alignment_bias_c' - See alignment_matrix_cxf. This bias will used to 'alignment_bias_c' - See alignment_matrix_cxf. This bias will used to
the offset for the alignment transformation. It will *subtract* off the the offset for the alignment transformation. It will *subtract* off the
bias from the data, so pca style inits can align factors across sessions. bias from the data, so pca style inits can align factors across sessions.
If one runs LFADS on data where the true rates are known for some trials, If one runs LFADS on data where the true rates are known for some trials,
...@@ -85,13 +85,13 @@ class GRU(object): ...@@ -85,13 +85,13 @@ class GRU(object):
"""Create a GRU object. """Create a GRU object.
Args: Args:
num_units: Number of units in the GRU num_units: Number of units in the GRU.
forget_bias (optional): Hack to help learning. forget_bias (optional): Hack to help learning.
weight_scale (optional): weights are scaled by ws/sqrt(#inputs), with weight_scale (optional): Weights are scaled by ws/sqrt(#inputs), with
ws being the weight scale. ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value, clip_value (optional): If the recurrent values grow above this value,
clip them. clip them.
collections (optional): List of additonal collections variables should collections (optional): List of additional collections variables should
belong to. belong to.
""" """
self._num_units = num_units self._num_units = num_units
...@@ -171,17 +171,17 @@ class GenGRU(object): ...@@ -171,17 +171,17 @@ class GenGRU(object):
"""Create a GRU object. """Create a GRU object.
Args: Args:
num_units: Number of units in the GRU num_units: Number of units in the GRU.
forget_bias (optional): Hack to help learning. forget_bias (optional): Hack to help learning.
input_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with input_weight_scale (optional): Weights are scaled ws/sqrt(#inputs), with
ws being the weight scale. ws being the weight scale.
rec_weight_scale (optional): weights are scaled ws/sqrt(#inputs), rec_weight_scale (optional): Weights are scaled ws/sqrt(#inputs),
with ws being the weight scale. with ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value, clip_value (optional): If the recurrent values grow above this value,
clip them. clip them.
input_collections (optional): List of additonal collections variables input_collections (optional): List of additional collections variables
that input->rec weights should belong to. that input->rec weights should belong to.
recurrent_collections (optional): List of additonal collections variables recurrent_collections (optional): List of additional collections variables
that rec->rec weights should belong to. that rec->rec weights should belong to.
""" """
self._num_units = num_units self._num_units = num_units
...@@ -271,7 +271,7 @@ class LFADS(object): ...@@ -271,7 +271,7 @@ class LFADS(object):
various factors, such as an initial condition, a generative various factors, such as an initial condition, a generative
dynamical system, inferred inputs to that generator, and a low dynamical system, inferred inputs to that generator, and a low
dimensional description of the observed data, called the factors. dimensional description of the observed data, called the factors.
Additoinally, the observations have a noise model (in this case Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed (e.g. underlying rates of a Poisson distribution given the observed
event counts). event counts).
...@@ -291,8 +291,8 @@ class LFADS(object): ...@@ -291,8 +291,8 @@ class LFADS(object):
Args: Args:
hps: The dictionary of hyper parameters. hps: The dictionary of hyper parameters.
kind: the type of model to build (see above). kind: The type of model to build (see above).
datasets: a dictionary of named data_dictionaries, see top of lfads.py datasets: A dictionary of named data_dictionaries, see top of lfads.py
""" """
print("Building graph...") print("Building graph...")
all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean', all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
...@@ -905,7 +905,7 @@ class LFADS(object): ...@@ -905,7 +905,7 @@ class LFADS(object):
if kind != "train": if kind != "train":
# save every so often # save every so often
self.seso_saver = tf.train.Saver(tf.global_variables(), self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep) max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error # lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(), self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep_lve) max_to_keep=hps.max_ckpt_to_keep_lve)
...@@ -952,7 +952,7 @@ class LFADS(object): ...@@ -952,7 +952,7 @@ class LFADS(object):
zip(grads, tvars), global_step=self.train_step) zip(grads, tvars), global_step=self.train_step)
self.seso_saver = tf.train.Saver(tf.global_variables(), self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep) max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error # lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(), self.lve_saver = tf.train.Saver(tf.global_variables(),
...@@ -963,7 +963,7 @@ class LFADS(object): ...@@ -963,7 +963,7 @@ class LFADS(object):
self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3], self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3],
name='image_tensor') name='image_tensor')
self.example_summ = tf.summary.image("LFADS example", self.example_image, self.example_summ = tf.summary.image("LFADS example", self.example_image,
collections=["example_summaries"]) collections=["example_summaries"])
# general training summaries # general training summaries
self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate) self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
...@@ -1032,8 +1032,8 @@ class LFADS(object): ...@@ -1032,8 +1032,8 @@ class LFADS(object):
Args: Args:
train_name: The key into the datasets, to set the tf.case statement for train_name: The key into the datasets, to set the tf.case statement for
the proper readin / readout matrices. the proper readin / readout matrices.
data_bxtxd: The data tensor data_bxtxd: The data tensor.
ext_input_bxtxi (optional): The external input tensor ext_input_bxtxi (optional): The external input tensor.
keep_prob: The drop out keep probability. keep_prob: The drop out keep probability.
Returns: Returns:
...@@ -1066,7 +1066,7 @@ class LFADS(object): ...@@ -1066,7 +1066,7 @@ class LFADS(object):
# examples x # time steps x # dimensions # examples x # time steps x # dimensions
ext_input_extxi (optional): The external inputs, numpy tensor with shape: ext_input_extxi (optional): The external inputs, numpy tensor with shape:
# examples x # time steps x # external input dimensions # examples x # time steps x # external input dimensions
batch_size: The size of the batch to return batch_size: The size of the batch to return.
example_idxs (optional): The example indices used to select examples. example_idxs (optional): The example indices used to select examples.
Returns: Returns:
...@@ -1123,8 +1123,8 @@ class LFADS(object): ...@@ -1123,8 +1123,8 @@ class LFADS(object):
is managed by drawing randomly from 1:nexamples. is managed by drawing randomly from 1:nexamples.
Args: Args:
nexamples: number of examples to randomize nexamples: Number of examples to randomize.
batch_size: number of elements in batch batch_size: Number of elements in batch.
Returns: Returns:
The randomized, properly shaped indicies. The randomized, properly shaped indicies.
...@@ -1148,7 +1148,7 @@ class LFADS(object): ...@@ -1148,7 +1148,7 @@ class LFADS(object):
enough to pick up dynamics that you may not want. enough to pick up dynamics that you may not want.
Args: Args:
data_bxtxd: numpy array of spike count data to be shuffled. data_bxtxd: Numpy array of spike count data to be shuffled.
Returns: Returns:
S_bxtxd, a numpy array with the same dimensions and contents as S_bxtxd, a numpy array with the same dimensions and contents as
data_bxtxd, but shuffled appropriately. data_bxtxd, but shuffled appropriately.
...@@ -1231,7 +1231,7 @@ class LFADS(object): ...@@ -1231,7 +1231,7 @@ class LFADS(object):
Args: Args:
datasets: A dict of data dicts. The dataset dict is simply a datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py). name(string)-> data dictionary mapping (See top of lfads.py).
batch_size (optional): The batch_size to use batch_size (optional): The batch_size to use.
do_save_ckpt (optional): Should the routine save a checkpoint on this do_save_ckpt (optional): Should the routine save a checkpoint on this
training epoch? training epoch?
...@@ -1283,7 +1283,7 @@ class LFADS(object): ...@@ -1283,7 +1283,7 @@ class LFADS(object):
name(string)-> data dictionary mapping (See top of lfads.py). name(string)-> data dictionary mapping (See top of lfads.py).
ops_to_eval: A list of tensorflow operations that will be evaluated in ops_to_eval: A list of tensorflow operations that will be evaluated in
the tf.session.run() call. the tf.session.run() call.
batch_size (optional): The batch_size to use batch_size (optional): The batch_size to use.
do_collect (optional): Should the routine collect all session.run do_collect (optional): Should the routine collect all session.run
output as a list, and return it? output as a list, and return it?
keep_prob (optional): The dropout keep probability. keep_prob (optional): The dropout keep probability.
...@@ -1330,7 +1330,7 @@ class LFADS(object): ...@@ -1330,7 +1330,7 @@ class LFADS(object):
Args: Args:
datasets, the dictionary of datasets used in the study. datasets, the dictionary of datasets used in the study.
summary_values: These summary values are created from the training loop, summary_values: These summary values are created from the training loop,
and so summarize the entire set of datasets. and so summarize the entire set of datasets.
""" """
hps = self.hps hps = self.hps
...@@ -1599,12 +1599,12 @@ class LFADS(object): ...@@ -1599,12 +1599,12 @@ class LFADS(object):
Args: Args:
data_name: The name of the data dict, to select which in/out matrices data_name: The name of the data dict, to select which in/out matrices
to use. to use.
data_bxtxd: Numpy array training data with shape: data_bxtxd: Numpy array training data with shape:
batch_size x # time steps x # dimensions batch_size x # time steps x # dimensions
ext_input_bxtxi: Numpy array training external input with shape: ext_input_bxtxi: Numpy array training external input with shape:
batch_size x # time steps x # external input dims batch_size x # time steps x # external input dims
do_eval_cost (optional): If true, the IWAE (Importance Weighted do_eval_cost (optional): If true, the IWAE (Importance Weighted
Autoencoder) log likeihood bound, instead of the VAE version. Autoencoder) log likeihood bound, instead of the VAE version.
do_average_batch (optional): average over the batch, useful for getting do_average_batch (optional): average over the batch, useful for getting
good IWAE costs, and model outputs for a single data point. good IWAE costs, and model outputs for a single data point.
...@@ -1743,7 +1743,7 @@ class LFADS(object): ...@@ -1743,7 +1743,7 @@ class LFADS(object):
Args: Args:
data_name: The name of the data dict, to select which in/out matrices data_name: The name of the data dict, to select which in/out matrices
to use. to use.
data_extxd: Numpy array training data with shape: data_extxd: Numpy array training data with shape:
# examples x # time steps x # dimensions # examples x # time steps x # dimensions
ext_input_extxi (optional): Numpy array training external input with ext_input_extxi (optional): Numpy array training external input with
shape: # examples x # time steps x # external input dims shape: # examples x # time steps x # external input dims
...@@ -1837,7 +1837,7 @@ class LFADS(object): ...@@ -1837,7 +1837,7 @@ class LFADS(object):
def eval_model_runs_push_mean(self, data_name, data_extxd, def eval_model_runs_push_mean(self, data_name, data_extxd,
ext_input_extxi=None): ext_input_extxi=None):
"""Returns values of interest for the model by pushing the means through """Returns values of interest for the model by pushing the means through
The mean values for both initial conditions and the control inputs are The mean values for both initial conditions and the control inputs are
pushed through the model instead of sampling (as is done in pushed through the model instead of sampling (as is done in
...@@ -1851,7 +1851,7 @@ class LFADS(object): ...@@ -1851,7 +1851,7 @@ class LFADS(object):
Args: Args:
data_name: The name of the data dict, to select which in/out matrices data_name: The name of the data dict, to select which in/out matrices
to use. to use.
data_extxd: Numpy array training data with shape: data_extxd: Numpy array training data with shape:
# examples x # time steps x # dimensions # examples x # time steps x # dimensions
ext_input_extxi (optional): Numpy array training external input with ext_input_extxi (optional): Numpy array training external input with
shape: # examples x # time steps x # external input dims shape: # examples x # time steps x # external input dims
...@@ -1966,16 +1966,16 @@ class LFADS(object): ...@@ -1966,16 +1966,16 @@ class LFADS(object):
saved. They are: saved. They are:
The mean and variance of the prior of g0. The mean and variance of the prior of g0.
The mean and variance of approximate posterior of g0. The mean and variance of approximate posterior of g0.
The control inputs (if enabled) The control inputs (if enabled).
The initial conditions, g0, for all examples. The initial conditions, g0, for all examples.
The generator states for all time. The generator states for all time.
The factors for all time. The factors for all time.
The output distribution parameters (e.g. rates) for all time. The output distribution parameters (e.g. rates) for all time.
Args: Args:
datasets: a dictionary of named data_dictionaries, see top of lfads.py datasets: A dictionary of named data_dictionaries, see top of lfads.py
output_fname: a file name stem for the output files. output_fname: a file name stem for the output files.
push_mean: if False (default), generates batch_size samples for each trial push_mean: If False (default), generates batch_size samples for each trial
and averages the results. if True, runs each trial once without noise, and averages the results. if True, runs each trial once without noise,
pushing the posterior mean initial conditions and control inputs through pushing the posterior mean initial conditions and control inputs through
the trained model. False is used for posterior_sample_and_average, True the trained model. False is used for posterior_sample_and_average, True
...@@ -2013,7 +2013,7 @@ class LFADS(object): ...@@ -2013,7 +2013,7 @@ class LFADS(object):
LFADS generates a number of outputs for each sample, and these are all LFADS generates a number of outputs for each sample, and these are all
saved. They are: saved. They are:
The mean and variance of the prior of g0. The mean and variance of the prior of g0.
The control inputs (if enabled) The control inputs (if enabled).
The initial conditions, g0, for all examples. The initial conditions, g0, for all examples.
The generator states for all time. The generator states for all time.
The factors for all time. The factors for all time.
...@@ -2148,7 +2148,7 @@ class LFADS(object): ...@@ -2148,7 +2148,7 @@ class LFADS(object):
"""Randomly spikify underlying rates according a Poisson distribution """Randomly spikify underlying rates according a Poisson distribution
Args: Args:
rates_bxtxd: a numpy tensor with shape: rates_bxtxd: A numpy tensor with shape:
Returns: Returns:
A numpy array with the same shape as rates_bxtxd, but with the event A numpy array with the same shape as rates_bxtxd, but with the event
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment