ModelZoo / ResNet50_tensorflow / Commits / aa2b818e

Commit aa2b818e (unverified)
Authored Feb 09, 2018 by David Sussillo; committed by GitHub, Feb 09, 2018

Merge pull request #3354 from sussillo/master

Fix simple errors in synthetic examples in lfads.

Parents: 3022f945, 7a49266b
Showing 5 changed files with 36 additions and 25 deletions (+36 -25)

  research/lfads/README.md (+13 -2)
  research/lfads/synth_data/generate_itb_data.py (+6 -6)
  research/lfads/synth_data/generate_labeled_rnn_data.py (+7 -7)
  research/lfads/synth_data/run_generate_synth_data.sh (+7 -7)
  research/lfads/synth_data/synthetic_data_utils.py (+3 -3)
research/lfads/README.md  (view file @ aa2b818e)

@@ -7,7 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v
 The code is written in Python 2.7.6. You will also need:
-* **TensorFlow** version 1.2.1 ([install](https://www.tensorflow.org/install/)) -
+* **TensorFlow** version 1.5 ([install](https://www.tensorflow.org/install/)) -
 * **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
 * **h5py** ([install](https://pypi.python.org/pypi/h5py))

@@ -98,7 +98,18 @@ $ python run_lfads.py --kind=train \
 --output_filename_stem="" \
 --ic_prior_var_max=0.1 \
 --prior_ar_atau=10.0 \
---do_train_io_only=false
+--do_train_io_only=false \
+--do_train_encoder_only=false
+
+# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=gaussian_chaotic_rnn_no_inputs \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
+--co_dim=1 \
+--factors_dim=20 \
+--output_dist=gaussian

 # Run LFADS on chaotic rnn data with input pulses (g = 2.5)
 $ python run_lfads.py --kind=train \
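Two fixes are visible in this hunk: the backslash added after --do_train_io_only=false continues the shell command, so the new --do_train_encoder_only=false flag reaches the same run_lfads.py invocation instead of being parsed as a separate command; and the new Gaussian-noise example reads --data_filename_stem=gaussian_chaotic_rnn_no_inputs, the dataset name introduced by the rename in run_generate_synth_data.sh below.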
research/lfads/synth_data/generate_itb_data.py  (view file @ aa2b818e)

@@ -39,7 +39,7 @@ flags.DEFINE_integer("C", 800, "Number of conditions")
 flags.DEFINE_integer("N", 50, "Number of units for the RNN")
 flags.DEFINE_float("train_percentage", 4.0/5.0,
                    "Percentage of train vs validation trials")
-flags.DEFINE_integer("nspikifications", 5,
+flags.DEFINE_integer("nreplications", 5,
                      "Number of spikifications of the same underlying rates.")
 flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
 flags.DEFINE_float("dt", 0.010, "Time bin")

@@ -90,8 +90,8 @@ u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
 T = FLAGS.T
 C = FLAGS.C
 N = FLAGS.N  # must be same N as in trained model (provided example is N = 50)
-nspikifications = FLAGS.nspikifications
-E = nspikifications * C  # total number of trials
+nreplications = FLAGS.nreplications
+E = nreplications * C  # total number of trials
 train_percentage = FLAGS.train_percentage
 ntimesteps = int(T / FLAGS.dt)
 batch_size = 1  # gives one example per ntrial

@@ -144,7 +144,7 @@ with tf.Session() as sess:
   outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
   r_sxt = np.dot(P_nxn, states_nxt)

-  for s in xrange(nspikifications):
+  for s in xrange(nreplications):
     data_e.append(r_sxt)
     u_e.append(u_1xt)
     outs_e.append(outputs_t_bxn)

@@ -154,7 +154,7 @@ with tf.Session() as sess:
 spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                               max_firing_rate=FLAGS.max_firing_rate)
 train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
-                                                nspikifications)
+                                                nreplications)
 data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
                                                         train_inds,

@@ -188,7 +188,7 @@ data = { 'train_truth': data_train_truth,
         'train_data': data_train_spiking,
         'valid_data': data_valid_spiking,
         'train_percentage': train_percentage,
-        'nspikifications': nspikifications,
+        'nreplications': nreplications,
         'dt': FLAGS.dt,
         'u_std': FLAGS.u_std,
         'max_firing_rate': FLAGS.max_firing_rate,
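With this file's default flags (C = 800 conditions, nreplications = 5), E = 5 * 800 = 4000 total trials. Anything that reads the saved dataset back must also use the renamed key. A minimal read-back sketch, assuming the data dict above is written to a single HDF5 file (h5py is a declared dependency in the README); the path below is hypothetical, not taken from the diff:

import h5py

# Hypothetical path: whichever file generate_itb_data.py writes under --save_dir.
data_path = "/tmp/rnn_synth_data_v1.0/itb_rnn"

with h5py.File(data_path, "r") as hf:
    train_data = hf["train_data"][:]   # spikified training trials
    valid_data = hf["valid_data"][:]   # spikified validation trials
    dt = hf["dt"][()]                  # time bin; 0.010 by default
    nreps = hf["nreplications"][()]    # renamed in this commit from 'nspikifications'
    print(train_data.shape, valid_data.shape, dt, nreps)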
research/lfads/synth_data/generate_labeled_rnn_data.py  (view file @ aa2b818e)

@@ -40,7 +40,7 @@ flags.DEFINE_integer("C", 400, "Number of conditions")
 flags.DEFINE_integer("N", 50, "Number of units for the RNN")
 flags.DEFINE_float("train_percentage", 4.0/5.0,
                    "Percentage of train vs validation trials")
-flags.DEFINE_integer("nspikifications", 10,
+flags.DEFINE_integer("nreplications", 10,
                      "Number of spikifications of the same underlying rates.")
 flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
 flags.DEFINE_float("x0_std", 1.0,

@@ -56,8 +56,8 @@ rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
 T = FLAGS.T
 C = FLAGS.C
 N = FLAGS.N
-nspikifications = FLAGS.nspikifications
-E = nspikifications * C
+nreplications = FLAGS.nreplications
+E = nreplications * C
 train_percentage = FLAGS.train_percentage
 ntimesteps = int(T / FLAGS.dt)

@@ -77,8 +77,8 @@ condition_labels = []
 condition_number = 0
 for c in range(C):
   x0 = FLAGS.x0_std * rng.randn(N, 1)
-  x0s.append(np.tile(x0, nspikifications))
-  for ns in range(nspikifications):
+  x0s.append(np.tile(x0, nreplications))
+  for ns in range(nreplications):
     condition_labels.append(condition_number)
   condition_number += 1
 x0s = np.concatenate(x0s, axis=1)

@@ -107,7 +107,7 @@ for trial in xrange(E):
 # split into train and validation sets
 train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
-                                                nspikifications)
+                                                nreplications)
 rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
 spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)

@@ -129,7 +129,7 @@ data = {'train_truth': rates_train,
        'train_ext_input': np.array(ext_input_train),
        'valid_ext_input': np.array(ext_input_valid),
        'train_percentage': train_percentage,
-       'nspikifications': nspikifications,
+       'nreplications': nreplications,
        'dt': FLAGS.dt,
        'P_sxn': P_nxn,
        'condition_labels_train': condition_labels_train,
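The loop in the @@ -77,8 hunk tiles each condition's initial state nreplications times and appends one label per replication, so trials that share an initial condition land in one contiguous block with one condition label. A minimal sketch of that bookkeeping, with hypothetical toy sizes (N = 3, C = 2, nreplications = 4):

import numpy as np

N, C, nreplications = 3, 2, 4  # toy sizes, for illustration only
rng = np.random.RandomState(0)

x0s, condition_labels = [], []
condition_number = 0
for c in range(C):
    x0 = rng.randn(N, 1)
    # np.tile repeats the (N, 1) column nreplications times -> shape (N, nreplications)
    x0s.append(np.tile(x0, nreplications))
    for ns in range(nreplications):
        condition_labels.append(condition_number)
    condition_number += 1
x0s = np.concatenate(x0s, axis=1)

print(x0s.shape)         # (3, 8): N x (C * nreplications)
print(condition_labels)  # [0, 0, 0, 0, 1, 1, 1, 1]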
research/lfads/synth_data/run_generate_synth_data.sh  (view file @ aa2b818e)

@@ -19,22 +19,22 @@
 SYNTH_PATH=/tmp/rnn_synth_data_v1.0/

 echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'

 echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs_gaussian --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=gaussian_chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'

 echo "Generating chaotic rnn data with input pulses (g=1.5)"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'

 echo "Generating chaotic rnn data with input pulses (g=2.5)"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'

 echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nreplications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'

 echo "Generating Integration-to-bound RNN data"
-python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
+python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nreplications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0

 echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
-python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
+python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
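Besides the flag rename, the only substantive change in this script is on the Gaussian-noise line: the output dataset moves from chaotic_rnn_no_inputs_gaussian to gaussian_chaotic_rnn_no_inputs, matching the --data_filename_stem used by the new README example above.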
research/lfads/synth_data/synthetic_data_utils.py  (view file @ aa2b818e)

@@ -176,13 +176,13 @@ def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
-def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
+def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
   """Split the numbers between 0 and num_trials-1 into two portions for
   training and validation, based on the train fraction.

   Args:
     num_trials: the number of trials
     train_fraction: (e.g. .80)
-    nspikifications: the number of spiking trials per initial condition
+    nreplications: the number of spiking trials per initial condition

   Returns:
     a 2-tuple of two lists: the training indices and validation indices
   """

@@ -192,7 +192,7 @@ def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
     # This line divides up the trials so that within one initial condition,
     # the randomness of spikifying the condition is shared among both
     # training and validation data splits.
-    if (i % nspikifications)+1 > train_fraction * nspikifications:
+    if (i % nreplications)+1 > train_fraction * nreplications:
       valid_inds.append(i)
     else:
       train_inds.append(i)
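The modulo test above keeps every initial condition represented on both sides of the split: within each block of nreplications trials sharing one set of underlying rates, the first train_fraction * nreplications trials go to training and the remainder to validation. A runnable sketch of the renamed function, reconstructed from the two hunks above; the for-loop header and the list initializations are assumptions, since they fall outside the diff context:

def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
  """Split 0..num_trials-1 into training and validation indices."""
  train_inds, valid_inds = [], []  # assumed; not shown in the diff
  for i in range(num_trials):      # assumed loop header
    # Within one initial condition (a block of nreplications trials),
    # send the tail of the block to validation.
    if (i % nreplications) + 1 > train_fraction * nreplications:
      valid_inds.append(i)
    else:
      train_inds.append(i)
  return train_inds, valid_inds

# Worked example: 10 trials = 2 conditions x 5 replications, 80% train.
# (i % 5) + 1 > 4.0 only when i % 5 == 4, so trials 4 and 9 validate:
print(get_train_n_valid_inds(10, 0.8, 5))
# -> ([0, 1, 2, 3, 5, 6, 7, 8], [4, 9])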