Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
730b778e
Unverified
Commit
730b778e
authored
Jul 16, 2022
by
Alexa Nguyen
Committed by
GitHub
Jul 15, 2022
Browse files
Fix typos and indentation in lfads.py (#10699)
parent
bf868b99
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
45 deletions
+45
-45
research/lfads/lfads.py
research/lfads/lfads.py
+45
-45
No files found.
research/lfads/lfads.py
View file @
730b778e
...
@@ -85,13 +85,13 @@ class GRU(object):
...
@@ -85,13 +85,13 @@ class GRU(object):
"""Create a GRU object.
"""Create a GRU object.
Args:
Args:
num_units: Number of units in the GRU
num_units: Number of units in the GRU
.
forget_bias (optional): Hack to help learning.
forget_bias (optional): Hack to help learning.
weight_scale (optional):
w
eights are scaled by ws/sqrt(#inputs), with
weight_scale (optional):
W
eights are scaled by ws/sqrt(#inputs), with
ws being the weight scale.
ws being the weight scale.
clip_value (optional):
i
f the recurrent values grow above this value,
clip_value (optional):
I
f the recurrent values grow above this value,
clip them.
clip them.
collections (optional): List of additonal collections variables should
collections (optional): List of addit
i
onal collections variables should
belong to.
belong to.
"""
"""
self
.
_num_units
=
num_units
self
.
_num_units
=
num_units
...
@@ -171,17 +171,17 @@ class GenGRU(object):
...
@@ -171,17 +171,17 @@ class GenGRU(object):
"""Create a GRU object.
"""Create a GRU object.
Args:
Args:
num_units: Number of units in the GRU
num_units: Number of units in the GRU
.
forget_bias (optional): Hack to help learning.
forget_bias (optional): Hack to help learning.
input_weight_scale (optional):
w
eights are scaled ws/sqrt(#inputs), with
input_weight_scale (optional):
W
eights are scaled ws/sqrt(#inputs), with
ws being the weight scale.
ws being the weight scale.
rec_weight_scale (optional):
w
eights are scaled ws/sqrt(#inputs),
rec_weight_scale (optional):
W
eights are scaled ws/sqrt(#inputs),
with ws being the weight scale.
with ws being the weight scale.
clip_value (optional):
i
f the recurrent values grow above this value,
clip_value (optional):
I
f the recurrent values grow above this value,
clip them.
clip them.
input_collections (optional): List of additonal collections variables
input_collections (optional): List of addit
i
onal collections variables
that input->rec weights should belong to.
that input->rec weights should belong to.
recurrent_collections (optional): List of additonal collections variables
recurrent_collections (optional): List of addit
i
onal collections variables
that rec->rec weights should belong to.
that rec->rec weights should belong to.
"""
"""
self
.
_num_units
=
num_units
self
.
_num_units
=
num_units
...
@@ -271,7 +271,7 @@ class LFADS(object):
...
@@ -271,7 +271,7 @@ class LFADS(object):
various factors, such as an initial condition, a generative
various factors, such as an initial condition, a generative
dynamical system, inferred inputs to that generator, and a low
dynamical system, inferred inputs to that generator, and a low
dimensional description of the observed data, called the factors.
dimensional description of the observed data, called the factors.
Addit
o
inally, the observations have a noise model (in this case
Additi
o
nally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
event counts).
...
@@ -291,8 +291,8 @@ class LFADS(object):
...
@@ -291,8 +291,8 @@ class LFADS(object):
Args:
Args:
hps: The dictionary of hyper parameters.
hps: The dictionary of hyper parameters.
kind:
t
he type of model to build (see above).
kind:
T
he type of model to build (see above).
datasets:
a
dictionary of named data_dictionaries, see top of lfads.py
datasets:
A
dictionary of named data_dictionaries, see top of lfads.py
"""
"""
print
(
"Building graph..."
)
print
(
"Building graph..."
)
all_kinds
=
[
'train'
,
'posterior_sample_and_average'
,
'posterior_push_mean'
,
all_kinds
=
[
'train'
,
'posterior_sample_and_average'
,
'posterior_push_mean'
,
...
@@ -1032,8 +1032,8 @@ class LFADS(object):
...
@@ -1032,8 +1032,8 @@ class LFADS(object):
Args:
Args:
train_name: The key into the datasets, to set the tf.case statement for
train_name: The key into the datasets, to set the tf.case statement for
the proper readin / readout matrices.
the proper readin / readout matrices.
data_bxtxd: The data tensor
data_bxtxd: The data tensor
.
ext_input_bxtxi (optional): The external input tensor
ext_input_bxtxi (optional): The external input tensor
.
keep_prob: The drop out keep probability.
keep_prob: The drop out keep probability.
Returns:
Returns:
...
@@ -1066,7 +1066,7 @@ class LFADS(object):
...
@@ -1066,7 +1066,7 @@ class LFADS(object):
# examples x # time steps x # dimensions
# examples x # time steps x # dimensions
ext_input_extxi (optional): The external inputs, numpy tensor with shape:
ext_input_extxi (optional): The external inputs, numpy tensor with shape:
# examples x # time steps x # external input dimensions
# examples x # time steps x # external input dimensions
batch_size:
The size of the batch to return
batch_size: The size of the batch to return
.
example_idxs (optional): The example indices used to select examples.
example_idxs (optional): The example indices used to select examples.
Returns:
Returns:
...
@@ -1123,8 +1123,8 @@ class LFADS(object):
...
@@ -1123,8 +1123,8 @@ class LFADS(object):
is managed by drawing randomly from 1:nexamples.
is managed by drawing randomly from 1:nexamples.
Args:
Args:
nexamples:
n
umber of examples to randomize
nexamples:
N
umber of examples to randomize
.
batch_size:
n
umber of elements in batch
batch_size:
N
umber of elements in batch
.
Returns:
Returns:
The randomized, properly shaped indicies.
The randomized, properly shaped indicies.
...
@@ -1148,7 +1148,7 @@ class LFADS(object):
...
@@ -1148,7 +1148,7 @@ class LFADS(object):
enough to pick up dynamics that you may not want.
enough to pick up dynamics that you may not want.
Args:
Args:
data_bxtxd:
n
umpy array of spike count data to be shuffled.
data_bxtxd:
N
umpy array of spike count data to be shuffled.
Returns:
Returns:
S_bxtxd, a numpy array with the same dimensions and contents as
S_bxtxd, a numpy array with the same dimensions and contents as
data_bxtxd, but shuffled appropriately.
data_bxtxd, but shuffled appropriately.
...
@@ -1231,7 +1231,7 @@ class LFADS(object):
...
@@ -1231,7 +1231,7 @@ class LFADS(object):
Args:
Args:
datasets: A dict of data dicts. The dataset dict is simply a
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
name(string)-> data dictionary mapping (See top of lfads.py).
batch_size (optional):
The batch_size to use
batch_size (optional): The batch_size to use
.
do_save_ckpt (optional): Should the routine save a checkpoint on this
do_save_ckpt (optional): Should the routine save a checkpoint on this
training epoch?
training epoch?
...
@@ -1283,7 +1283,7 @@ class LFADS(object):
...
@@ -1283,7 +1283,7 @@ class LFADS(object):
name(string)-> data dictionary mapping (See top of lfads.py).
name(string)-> data dictionary mapping (See top of lfads.py).
ops_to_eval: A list of tensorflow operations that will be evaluated in
ops_to_eval: A list of tensorflow operations that will be evaluated in
the tf.session.run() call.
the tf.session.run() call.
batch_size (optional):
The batch_size to use
batch_size (optional): The batch_size to use
.
do_collect (optional): Should the routine collect all session.run
do_collect (optional): Should the routine collect all session.run
output as a list, and return it?
output as a list, and return it?
keep_prob (optional): The dropout keep probability.
keep_prob (optional): The dropout keep probability.
...
@@ -1966,16 +1966,16 @@ class LFADS(object):
...
@@ -1966,16 +1966,16 @@ class LFADS(object):
saved. They are:
saved. They are:
The mean and variance of the prior of g0.
The mean and variance of the prior of g0.
The mean and variance of approximate posterior of g0.
The mean and variance of approximate posterior of g0.
The control inputs (if enabled)
The control inputs (if enabled)
.
The initial conditions, g0, for all examples.
The initial conditions, g0, for all examples.
The generator states for all time.
The generator states for all time.
The factors for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
Args:
datasets:
a
dictionary of named data_dictionaries, see top of lfads.py
datasets:
A
dictionary of named data_dictionaries, see top of lfads.py
output_fname: a file name stem for the output files.
output_fname: a file name stem for the output files.
push_mean:
i
f False (default), generates batch_size samples for each trial
push_mean:
I
f False (default), generates batch_size samples for each trial
and averages the results. if True, runs each trial once without noise,
and averages the results. if True, runs each trial once without noise,
pushing the posterior mean initial conditions and control inputs through
pushing the posterior mean initial conditions and control inputs through
the trained model. False is used for posterior_sample_and_average, True
the trained model. False is used for posterior_sample_and_average, True
...
@@ -2013,7 +2013,7 @@ class LFADS(object):
...
@@ -2013,7 +2013,7 @@ class LFADS(object):
LFADS generates a number of outputs for each sample, and these are all
LFADS generates a number of outputs for each sample, and these are all
saved. They are:
saved. They are:
The mean and variance of the prior of g0.
The mean and variance of the prior of g0.
The control inputs (if enabled)
The control inputs (if enabled)
.
The initial conditions, g0, for all examples.
The initial conditions, g0, for all examples.
The generator states for all time.
The generator states for all time.
The factors for all time.
The factors for all time.
...
@@ -2148,7 +2148,7 @@ class LFADS(object):
...
@@ -2148,7 +2148,7 @@ class LFADS(object):
"""Randomly spikify underlying rates according a Poisson distribution
"""Randomly spikify underlying rates according a Poisson distribution
Args:
Args:
rates_bxtxd:
a
numpy tensor with shape:
rates_bxtxd:
A
numpy tensor with shape:
Returns:
Returns:
A numpy array with the same shape as rates_bxtxd, but with the event
A numpy array with the same shape as rates_bxtxd, but with the event
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment