Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
05b7b7ee
Commit
05b7b7ee
authored
Feb 09, 2018
by
David Sussillo
Browse files
Merge pull request #2898 from cpandar/master
change to lfads to allow training of encoder weights only
parents
99400da5
9b3a7754
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
63 additions
and
28 deletions
+63
-28
research/lfads/README.md
research/lfads/README.md
+13
-2
research/lfads/lfads.py
research/lfads/lfads.py
+15
-3
research/lfads/run_lfads.py
research/lfads/run_lfads.py
+12
-0
research/lfads/synth_data/generate_itb_data.py
research/lfads/synth_data/generate_itb_data.py
+6
-6
research/lfads/synth_data/generate_labeled_rnn_data.py
research/lfads/synth_data/generate_labeled_rnn_data.py
+7
-7
research/lfads/synth_data/run_generate_synth_data.sh
research/lfads/synth_data/run_generate_synth_data.sh
+7
-7
research/lfads/synth_data/synthetic_data_utils.py
research/lfads/synth_data/synthetic_data_utils.py
+3
-3
No files found.
research/lfads/README.md
View file @
05b7b7ee
...
@@ -7,7 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v
...
@@ -7,7 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v
The code is written in Python 2.7.6. You will also need:
The code is written in Python 2.7.6. You will also need:
*
**TensorFlow**
version 1.
2.1
(
[
install
](
https://www.tensorflow.org/install/
)
) -
*
**TensorFlow**
version 1.
5
(
[
install
](
https://www.tensorflow.org/install/
)
) -
*
**NumPy, SciPy, Matplotlib**
(
[
install SciPy stack
](
https://www.scipy.org/install.html
)
, contains all of them)
*
**NumPy, SciPy, Matplotlib**
(
[
install SciPy stack
](
https://www.scipy.org/install.html
)
, contains all of them)
*
**h5py**
(
[
install
](
https://pypi.python.org/pypi/h5py
)
)
*
**h5py**
(
[
install
](
https://pypi.python.org/pypi/h5py
)
)
...
@@ -98,7 +98,18 @@ $ python run_lfads.py --kind=train \
...
@@ -98,7 +98,18 @@ $ python run_lfads.py --kind=train \
--output_filename_stem
=
""
\
--output_filename_stem
=
""
\
--ic_prior_var_max
=
0.1
\
--ic_prior_var_max
=
0.1
\
--prior_ar_atau
=
10.0
\
--prior_ar_atau
=
10.0
\
--do_train_io_only
=
false
--do_train_io_only
=
false
\
--do_train_encoder_only
=
false
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
$
python run_lfads.py
--kind
=
train
\
--data_dir
=
/tmp/rnn_synth_data_v1.0/
\
--data_filename_stem
=
gaussian_chaotic_rnn_no_inputs
\
--lfads_save_dir
=
/tmp/lfads_chaotic_rnn_inputs_g2p5
\
--co_dim
=
1
\
--factors_dim
=
20
\
--output_dist
=
gaussian
# Run LFADS on chaotic rnn data with input pulses (g = 2.5)
# Run LFADS on chaotic rnn data with input pulses (g = 2.5)
$
python run_lfads.py
--kind
=
train
\
$
python run_lfads.py
--kind
=
train
\
...
...
research/lfads/lfads.py
View file @
05b7b7ee
...
@@ -915,13 +915,25 @@ class LFADS(object):
...
@@ -915,13 +915,25 @@ class LFADS(object):
return
return
# OPTIMIZATION
# OPTIMIZATION
if
not
self
.
hps
.
do_train_io_only
:
# train the io matrices only
if
self
.
hps
.
do_train_io_only
:
self
.
train_vars
=
tvars
=
\
self
.
train_vars
=
tvars
=
\
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
get_collection
(
'IO_transformations'
,
scope
=
tf
.
get_variable_scope
().
name
)
scope
=
tf
.
get_variable_scope
().
name
)
# train the encoder only
elif
self
.
hps
.
do_train_encoder_only
:
tvars1
=
\
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
scope
=
'LFADS/ic_enc_*'
)
tvars2
=
\
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
scope
=
'LFADS/z/ic_enc_*'
)
self
.
train_vars
=
tvars
=
tvars1
+
tvars2
# train all variables
else
:
else
:
self
.
train_vars
=
tvars
=
\
self
.
train_vars
=
tvars
=
\
tf
.
get_collection
(
'IO_transformations'
,
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
scope
=
tf
.
get_variable_scope
().
name
)
scope
=
tf
.
get_variable_scope
().
name
)
print
(
"done."
)
print
(
"done."
)
print
(
"Model Variables (to be optimized): "
)
print
(
"Model Variables (to be optimized): "
)
...
...
research/lfads/run_lfads.py
View file @
05b7b7ee
...
@@ -53,6 +53,7 @@ LEARNING_RATE_STOP = 0.00001
...
@@ -53,6 +53,7 @@ LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE
=
6
LEARNING_RATE_N_TO_COMPARE
=
6
INJECT_EXT_INPUT_TO_GEN
=
False
INJECT_EXT_INPUT_TO_GEN
=
False
DO_TRAIN_IO_ONLY
=
False
DO_TRAIN_IO_ONLY
=
False
DO_TRAIN_ENCODER_ONLY
=
False
DO_RESET_LEARNING_RATE
=
False
DO_RESET_LEARNING_RATE
=
False
FEEDBACK_FACTORS_OR_RATES
=
"factors"
FEEDBACK_FACTORS_OR_RATES
=
"factors"
DO_TRAIN_READIN
=
True
DO_TRAIN_READIN
=
True
...
@@ -315,6 +316,16 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
...
@@ -315,6 +316,16 @@ flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
"Train only the input (readin) and output (readout)
\
"Train only the input (readin) and output (readout)
\
affine functions."
)
affine functions."
)
# This flag is used for an experiment where one wants to know if the dynamics
# learned by the generator generalize across conditions. In that case, you might
# train up a model on one set of data, and then only further train the encoder on
# another set of data (the conditions to be tested) so that the model is forced
# to use the same dynamics to describe that data.
# If you don't care about that particular experiment, this flag should always be
# false.
flags
.
DEFINE_boolean
(
"do_train_encoder_only"
,
DO_TRAIN_ENCODER_ONLY
,
"Train only the encoder weights."
)
flags
.
DEFINE_boolean
(
"do_reset_learning_rate"
,
DO_RESET_LEARNING_RATE
,
flags
.
DEFINE_boolean
(
"do_reset_learning_rate"
,
DO_RESET_LEARNING_RATE
,
"Reset the learning rate to initial value."
)
"Reset the learning rate to initial value."
)
...
@@ -550,6 +561,7 @@ def build_hyperparameter_dict(flags):
...
@@ -550,6 +561,7 @@ def build_hyperparameter_dict(flags):
d
[
'max_grad_norm'
]
=
flags
.
max_grad_norm
d
[
'max_grad_norm'
]
=
flags
.
max_grad_norm
d
[
'cell_clip_value'
]
=
flags
.
cell_clip_value
d
[
'cell_clip_value'
]
=
flags
.
cell_clip_value
d
[
'do_train_io_only'
]
=
flags
.
do_train_io_only
d
[
'do_train_io_only'
]
=
flags
.
do_train_io_only
d
[
'do_train_encoder_only'
]
=
flags
.
do_train_encoder_only
d
[
'do_reset_learning_rate'
]
=
flags
.
do_reset_learning_rate
d
[
'do_reset_learning_rate'
]
=
flags
.
do_reset_learning_rate
d
[
'do_train_readin'
]
=
flags
.
do_train_readin
d
[
'do_train_readin'
]
=
flags
.
do_train_readin
...
...
research/lfads/synth_data/generate_itb_data.py
View file @
05b7b7ee
...
@@ -39,7 +39,7 @@ flags.DEFINE_integer("C", 800, "Number of conditions")
...
@@ -39,7 +39,7 @@ flags.DEFINE_integer("C", 800, "Number of conditions")
flags
.
DEFINE_integer
(
"N"
,
50
,
"Number of units for the RNN"
)
flags
.
DEFINE_integer
(
"N"
,
50
,
"Number of units for the RNN"
)
flags
.
DEFINE_float
(
"train_percentage"
,
4.0
/
5.0
,
flags
.
DEFINE_float
(
"train_percentage"
,
4.0
/
5.0
,
"Percentage of train vs validation trials"
)
"Percentage of train vs validation trials"
)
flags
.
DEFINE_integer
(
"n
spikif
ications"
,
5
,
flags
.
DEFINE_integer
(
"n
repl
ications"
,
5
,
"Number of spikifications of the same underlying rates."
)
"Number of spikifications of the same underlying rates."
)
flags
.
DEFINE_float
(
"tau"
,
0.025
,
"Time constant of RNN"
)
flags
.
DEFINE_float
(
"tau"
,
0.025
,
"Time constant of RNN"
)
flags
.
DEFINE_float
(
"dt"
,
0.010
,
"Time bin"
)
flags
.
DEFINE_float
(
"dt"
,
0.010
,
"Time bin"
)
...
@@ -90,8 +90,8 @@ u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
...
@@ -90,8 +90,8 @@ u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T
=
FLAGS
.
T
T
=
FLAGS
.
T
C
=
FLAGS
.
C
C
=
FLAGS
.
C
N
=
FLAGS
.
N
# must be same N as in trained model (provided example is N = 50)
N
=
FLAGS
.
N
# must be same N as in trained model (provided example is N = 50)
n
spikif
ications
=
FLAGS
.
n
spikif
ications
n
repl
ications
=
FLAGS
.
n
repl
ications
E
=
n
spikif
ications
*
C
# total number of trials
E
=
n
repl
ications
*
C
# total number of trials
train_percentage
=
FLAGS
.
train_percentage
train_percentage
=
FLAGS
.
train_percentage
ntimesteps
=
int
(
T
/
FLAGS
.
dt
)
ntimesteps
=
int
(
T
/
FLAGS
.
dt
)
batch_size
=
1
# gives one example per ntrial
batch_size
=
1
# gives one example per ntrial
...
@@ -144,7 +144,7 @@ with tf.Session() as sess:
...
@@ -144,7 +144,7 @@ with tf.Session() as sess:
outputs_t_bxn
=
np
.
squeeze
(
np
.
asarray
(
outputs_t_bxn
))
outputs_t_bxn
=
np
.
squeeze
(
np
.
asarray
(
outputs_t_bxn
))
r_sxt
=
np
.
dot
(
P_nxn
,
states_nxt
)
r_sxt
=
np
.
dot
(
P_nxn
,
states_nxt
)
for
s
in
xrange
(
n
spikif
ications
):
for
s
in
xrange
(
n
repl
ications
):
data_e
.
append
(
r_sxt
)
data_e
.
append
(
r_sxt
)
u_e
.
append
(
u_1xt
)
u_e
.
append
(
u_1xt
)
outs_e
.
append
(
outputs_t_bxn
)
outs_e
.
append
(
outputs_t_bxn
)
...
@@ -154,7 +154,7 @@ with tf.Session() as sess:
...
@@ -154,7 +154,7 @@ with tf.Session() as sess:
spiking_data_e
=
spikify_data
(
truth_data_e
,
rng
,
dt
=
FLAGS
.
dt
,
spiking_data_e
=
spikify_data
(
truth_data_e
,
rng
,
dt
=
FLAGS
.
dt
,
max_firing_rate
=
FLAGS
.
max_firing_rate
)
max_firing_rate
=
FLAGS
.
max_firing_rate
)
train_inds
,
valid_inds
=
get_train_n_valid_inds
(
E
,
train_percentage
,
train_inds
,
valid_inds
=
get_train_n_valid_inds
(
E
,
train_percentage
,
n
spikif
ications
)
n
repl
ications
)
data_train_truth
,
data_valid_truth
=
split_list_by_inds
(
truth_data_e
,
data_train_truth
,
data_valid_truth
=
split_list_by_inds
(
truth_data_e
,
train_inds
,
train_inds
,
...
@@ -188,7 +188,7 @@ data = { 'train_truth': data_train_truth,
...
@@ -188,7 +188,7 @@ data = { 'train_truth': data_train_truth,
'train_data'
:
data_train_spiking
,
'train_data'
:
data_train_spiking
,
'valid_data'
:
data_valid_spiking
,
'valid_data'
:
data_valid_spiking
,
'train_percentage'
:
train_percentage
,
'train_percentage'
:
train_percentage
,
'n
spikif
ications'
:
n
spikif
ications
,
'n
repl
ications'
:
n
repl
ications
,
'dt'
:
FLAGS
.
dt
,
'dt'
:
FLAGS
.
dt
,
'u_std'
:
FLAGS
.
u_std
,
'u_std'
:
FLAGS
.
u_std
,
'max_firing_rate'
:
FLAGS
.
max_firing_rate
,
'max_firing_rate'
:
FLAGS
.
max_firing_rate
,
...
...
research/lfads/synth_data/generate_labeled_rnn_data.py
View file @
05b7b7ee
...
@@ -40,7 +40,7 @@ flags.DEFINE_integer("C", 400, "Number of conditions")
...
@@ -40,7 +40,7 @@ flags.DEFINE_integer("C", 400, "Number of conditions")
flags
.
DEFINE_integer
(
"N"
,
50
,
"Number of units for the RNN"
)
flags
.
DEFINE_integer
(
"N"
,
50
,
"Number of units for the RNN"
)
flags
.
DEFINE_float
(
"train_percentage"
,
4.0
/
5.0
,
flags
.
DEFINE_float
(
"train_percentage"
,
4.0
/
5.0
,
"Percentage of train vs validation trials"
)
"Percentage of train vs validation trials"
)
flags
.
DEFINE_integer
(
"n
spikif
ications"
,
10
,
flags
.
DEFINE_integer
(
"n
repl
ications"
,
10
,
"Number of spikifications of the same underlying rates."
)
"Number of spikifications of the same underlying rates."
)
flags
.
DEFINE_float
(
"g"
,
1.5
,
"Complexity of dynamics"
)
flags
.
DEFINE_float
(
"g"
,
1.5
,
"Complexity of dynamics"
)
flags
.
DEFINE_float
(
"x0_std"
,
1.0
,
flags
.
DEFINE_float
(
"x0_std"
,
1.0
,
...
@@ -56,8 +56,8 @@ rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
...
@@ -56,8 +56,8 @@ rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
T
=
FLAGS
.
T
T
=
FLAGS
.
T
C
=
FLAGS
.
C
C
=
FLAGS
.
C
N
=
FLAGS
.
N
N
=
FLAGS
.
N
n
spikif
ications
=
FLAGS
.
n
spikif
ications
n
repl
ications
=
FLAGS
.
n
repl
ications
E
=
n
spikif
ications
*
C
E
=
n
repl
ications
*
C
train_percentage
=
FLAGS
.
train_percentage
train_percentage
=
FLAGS
.
train_percentage
ntimesteps
=
int
(
T
/
FLAGS
.
dt
)
ntimesteps
=
int
(
T
/
FLAGS
.
dt
)
...
@@ -77,8 +77,8 @@ condition_labels = []
...
@@ -77,8 +77,8 @@ condition_labels = []
condition_number
=
0
condition_number
=
0
for
c
in
range
(
C
):
for
c
in
range
(
C
):
x0
=
FLAGS
.
x0_std
*
rng
.
randn
(
N
,
1
)
x0
=
FLAGS
.
x0_std
*
rng
.
randn
(
N
,
1
)
x0s
.
append
(
np
.
tile
(
x0
,
n
spikif
ications
))
x0s
.
append
(
np
.
tile
(
x0
,
n
repl
ications
))
for
ns
in
range
(
n
spikif
ications
):
for
ns
in
range
(
n
repl
ications
):
condition_labels
.
append
(
condition_number
)
condition_labels
.
append
(
condition_number
)
condition_number
+=
1
condition_number
+=
1
x0s
=
np
.
concatenate
(
x0s
,
axis
=
1
)
x0s
=
np
.
concatenate
(
x0s
,
axis
=
1
)
...
@@ -107,7 +107,7 @@ for trial in xrange(E):
...
@@ -107,7 +107,7 @@ for trial in xrange(E):
# split into train and validation sets
# split into train and validation sets
train_inds
,
valid_inds
=
get_train_n_valid_inds
(
E
,
train_percentage
,
train_inds
,
valid_inds
=
get_train_n_valid_inds
(
E
,
train_percentage
,
n
spikif
ications
)
n
repl
ications
)
rates_train
,
rates_valid
=
split_list_by_inds
(
rates
,
train_inds
,
valid_inds
)
rates_train
,
rates_valid
=
split_list_by_inds
(
rates
,
train_inds
,
valid_inds
)
spikes_train
,
spikes_valid
=
split_list_by_inds
(
spikes
,
train_inds
,
valid_inds
)
spikes_train
,
spikes_valid
=
split_list_by_inds
(
spikes
,
train_inds
,
valid_inds
)
...
@@ -129,7 +129,7 @@ data = {'train_truth': rates_train,
...
@@ -129,7 +129,7 @@ data = {'train_truth': rates_train,
'train_ext_input'
:
np
.
array
(
ext_input_train
),
'train_ext_input'
:
np
.
array
(
ext_input_train
),
'valid_ext_input'
:
np
.
array
(
ext_input_valid
),
'valid_ext_input'
:
np
.
array
(
ext_input_valid
),
'train_percentage'
:
train_percentage
,
'train_percentage'
:
train_percentage
,
'n
spikif
ications'
:
n
spikif
ications
,
'n
repl
ications'
:
n
repl
ications
,
'dt'
:
FLAGS
.
dt
,
'dt'
:
FLAGS
.
dt
,
'P_sxn'
:
P_nxn
,
'P_sxn'
:
P_nxn
,
'condition_labels_train'
:
condition_labels_train
,
'condition_labels_train'
:
condition_labels_train
,
...
...
research/lfads/synth_data/run_generate_synth_data.sh
View file @
05b7b7ee
...
@@ -19,22 +19,22 @@
...
@@ -19,22 +19,22 @@
SYNTH_PATH
=
/tmp/rnn_synth_data_v1.0/
SYNTH_PATH
=
/tmp/rnn_synth_data_v1.0/
echo
"Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
echo
"Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_no_inputs
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
spikif
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
0.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_no_inputs
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
repl
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
0.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
echo
"Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
echo
"Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_no_inputs
_gaussian
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
spikif
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
0.0
--max_firing_rate
=
30.0
--noise_type
=
'gaussian'
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
gaussian_
chaotic_rnn_no_inputs
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
repl
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
0.0
--max_firing_rate
=
30.0
--noise_type
=
'gaussian'
echo
"Generating chaotic rnn data with input pulses (g=1.5)"
echo
"Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_inputs_g1p5
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
spikif
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
20.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_inputs_g1p5
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
repl
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
20.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
echo
"Generating chaotic rnn data with input pulses (g=2.5)"
echo
"Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_inputs_g2p5
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
spikif
ications
=
10
--g
=
2.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
20.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_inputs_g2p5
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--S
=
50
--train_percentage
=
0.8
--n
repl
ications
=
10
--g
=
2.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
20.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
echo
"Generate the multi-session RNN data (no multi-session synth example in paper)"
echo
"Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_multisession
--synth_data_seed
=
5
--T
=
1.0
--C
=
150
--N
=
100
--S
=
20
--npcs
=
10
--train_percentage
=
0.8
--n
spikif
ications
=
40
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
0.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
python generate_chaotic_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnn_multisession
--synth_data_seed
=
5
--T
=
1.0
--C
=
150
--N
=
100
--S
=
20
--npcs
=
10
--train_percentage
=
0.8
--n
repl
ications
=
40
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--input_magnitude
=
0.0
--max_firing_rate
=
30.0
--noise_type
=
'poisson'
echo
"Generating Integration-to-bound RNN data"
echo
"Generating Integration-to-bound RNN data"
python generate_itb_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
itb_rnn
--u_std
=
0.25
--checkpoint_path
=
SAMPLE_CHECKPOINT
--synth_data_seed
=
5
--T
=
1.0
--C
=
800
--N
=
50
--train_percentage
=
0.8
--n
spikif
ications
=
5
--tau
=
0.025
--dt
=
0.01
--max_firing_rate
=
30.0
python generate_itb_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
itb_rnn
--u_std
=
0.25
--checkpoint_path
=
SAMPLE_CHECKPOINT
--synth_data_seed
=
5
--T
=
1.0
--C
=
800
--N
=
50
--train_percentage
=
0.8
--n
repl
ications
=
5
--tau
=
0.025
--dt
=
0.01
--max_firing_rate
=
30.0
echo
"Generating chaotic rnn data with external input labels (no external input labels example in paper)"
echo
"Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnns_labeled
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--train_percentage
=
0.8
--n
spikif
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--max_firing_rate
=
30.0
python generate_labeled_rnn_data.py
--save_dir
=
$SYNTH_PATH
--datafile_name
=
chaotic_rnns_labeled
--synth_data_seed
=
5
--T
=
1.0
--C
=
400
--N
=
50
--train_percentage
=
0.8
--n
repl
ications
=
10
--g
=
1.5
--x0_std
=
1.0
--tau
=
0.025
--dt
=
0.01
--max_firing_rate
=
30.0
research/lfads/synth_data/synthetic_data_utils.py
View file @
05b7b7ee
...
@@ -176,13 +176,13 @@ def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
...
@@ -176,13 +176,13 @@ def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
def
get_train_n_valid_inds
(
num_trials
,
train_fraction
,
n
spikif
ications
):
def
get_train_n_valid_inds
(
num_trials
,
train_fraction
,
n
repl
ications
):
"""Split the numbers between 0 and num_trials-1 into two portions for
"""Split the numbers between 0 and num_trials-1 into two portions for
training and validation, based on the train fraction.
training and validation, based on the train fraction.
Args:
Args:
num_trials: the number of trials
num_trials: the number of trials
train_fraction: (e.g. .80)
train_fraction: (e.g. .80)
n
spikif
ications: the number of spiking trials per initial condition
n
repl
ications: the number of spiking trials per initial condition
Returns:
Returns:
a 2-tuple of two lists: the training indices and validation indices
a 2-tuple of two lists: the training indices and validation indices
"""
"""
...
@@ -192,7 +192,7 @@ def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
...
@@ -192,7 +192,7 @@ def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
# This line divides up the trials so that within one initial condition,
# This line divides up the trials so that within one initial condition,
# the randomness of spikifying the condition is shared among both
# the randomness of spikifying the condition is shared among both
# training and validation data splits.
# training and validation data splits.
if
(
i
%
n
spikif
ications
)
+
1
>
train_fraction
*
n
spikif
ications
:
if
(
i
%
n
repl
ications
)
+
1
>
train_fraction
*
n
repl
ications
:
valid_inds
.
append
(
i
)
valid_inds
.
append
(
i
)
else
:
else
:
train_inds
.
append
(
i
)
train_inds
.
append
(
i
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment