"tests/vscode:/vscode.git/clone" did not exist on "7ac6e286ee994270e737b70c904ea50049d53567"
Unverified Commit aa2b818e authored by David Sussillo's avatar David Sussillo Committed by GitHub
Browse files

Merge pull request #3354 from sussillo/master

Fix simple errors in synthetic examples in lfads.
parents 3022f945 7a49266b
......@@ -7,7 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v
The code is written in Python 2.7.6. You will also need:
* **TensorFlow** version 1.2.1 ([install](https://www.tensorflow.org/install/)) -
* **TensorFlow** version 1.5 ([install](https://www.tensorflow.org/install/)) -
* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
* **h5py** ([install](https://pypi.python.org/pypi/h5py))
......@@ -98,7 +98,18 @@ $ python run_lfads.py --kind=train \
--output_filename_stem="" \
--ic_prior_var_max=0.1 \
--prior_ar_atau=10.0 \
--do_train_io_only=false
--do_train_io_only=false \
--do_train_encoder_only=false
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=gaussian_chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--output_dist=gaussian
# Run LFADS on chaotic rnn data with input pulses (g = 2.5)
$ python run_lfads.py --kind=train \
......
......@@ -39,7 +39,7 @@ flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 5,
flags.DEFINE_integer("nreplications", 5,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
......@@ -90,8 +90,8 @@ u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N # must be same N as in trained model (provided example is N = 50)
nspikifications = FLAGS.nspikifications
E = nspikifications * C # total number of trials
nreplications = FLAGS.nreplications
E = nreplications * C # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
batch_size = 1 # gives one example per ntrial
......@@ -144,7 +144,7 @@ with tf.Session() as sess:
outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
r_sxt = np.dot(P_nxn, states_nxt)
for s in xrange(nspikifications):
for s in xrange(nreplications):
data_e.append(r_sxt)
u_e.append(u_1xt)
outs_e.append(outputs_t_bxn)
......@@ -154,7 +154,7 @@ with tf.Session() as sess:
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
nreplications)
data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
train_inds,
......@@ -188,7 +188,7 @@ data = { 'train_truth': data_train_truth,
'train_data' : data_train_spiking,
'valid_data' : data_valid_spiking,
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'nreplications' : nreplications,
'dt' : FLAGS.dt,
'u_std' : FLAGS.u_std,
'max_firing_rate': FLAGS.max_firing_rate,
......
......@@ -40,7 +40,7 @@ flags.DEFINE_integer("C", 400, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 10,
flags.DEFINE_integer("nreplications", 10,
"Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
......@@ -56,8 +56,8 @@ rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N
nspikifications = FLAGS.nspikifications
E = nspikifications * C
nreplications = FLAGS.nreplications
E = nreplications * C
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
......@@ -77,8 +77,8 @@ condition_labels = []
condition_number = 0
for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications))
for ns in range(nspikifications):
x0s.append(np.tile(x0, nreplications))
for ns in range(nreplications):
condition_labels.append(condition_number)
condition_number += 1
x0s = np.concatenate(x0s, axis=1)
......@@ -107,7 +107,7 @@ for trial in xrange(E):
# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications)
nreplications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
......@@ -129,7 +129,7 @@ data = {'train_truth': rates_train,
'train_ext_input' : np.array(ext_input_train),
'valid_ext_input': np.array(ext_input_valid),
'train_percentage' : train_percentage,
'nspikifications' : nspikifications,
'nreplications' : nreplications,
'dt' : FLAGS.dt,
'P_sxn' : P_nxn,
'condition_labels_train' : condition_labels_train,
......
......@@ -19,22 +19,22 @@
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/
echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs_gaussian --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=gaussian_chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'
echo "Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nreplications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating Integration-to-bound RNN data"
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nreplications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
......@@ -176,13 +176,13 @@ def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
"""Split the numbers between 0 and num_trials-1 into two portions for
training and validation, based on the train fraction.
Args:
num_trials: the number of trials
train_fraction: (e.g. .80)
nspikifications: the number of spiking trials per initial condition
nreplications: the number of spiking trials per initial condition
Returns:
a 2-tuple of two lists: the training indices and validation indices
"""
......@@ -192,7 +192,7 @@ def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
# This line divides up the trials so that within one initial condition,
# the randomness of spikifying the condition is shared among both
# training and validation data splits.
if (i % nspikifications)+1 > train_fraction * nspikifications:
if (i % nreplications)+1 > train_fraction * nreplications:
valid_inds.append(i)
else:
train_inds.append(i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment