"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "021822dd01b0ade6690ad358a46a4829de55ec84"
Unverified Commit aa2b818e authored by David Sussillo's avatar David Sussillo Committed by GitHub
Browse files

Merge pull request #3354 from sussillo/master

Fix simple errors in synthetic examples in lfads.
parents 3022f945 7a49266b
...@@ -7,7 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v ...@@ -7,7 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v
The code is written in Python 2.7.6. You will also need: The code is written in Python 2.7.6. You will also need:
* **TensorFlow** version 1.2.1 ([install](https://www.tensorflow.org/install/)) - * **TensorFlow** version 1.5 ([install](https://www.tensorflow.org/install/)) -
* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them) * **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
* **h5py** ([install](https://pypi.python.org/pypi/h5py)) * **h5py** ([install](https://pypi.python.org/pypi/h5py))
...@@ -98,7 +98,18 @@ $ python run_lfads.py --kind=train \ ...@@ -98,7 +98,18 @@ $ python run_lfads.py --kind=train \
--output_filename_stem="" \ --output_filename_stem="" \
--ic_prior_var_max=0.1 \ --ic_prior_var_max=0.1 \
--prior_ar_atau=10.0 \ --prior_ar_atau=10.0 \
--do_train_io_only=false --do_train_io_only=false \
--do_train_encoder_only=false
# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=gaussian_chaotic_rnn_no_inputs \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
--factors_dim=20 \
--output_dist=gaussian
# Run LFADS on chaotic rnn data with input pulses (g = 2.5) # Run LFADS on chaotic rnn data with input pulses (g = 2.5)
$ python run_lfads.py --kind=train \ $ python run_lfads.py --kind=train \
......
...@@ -39,7 +39,7 @@ flags.DEFINE_integer("C", 800, "Number of conditions") ...@@ -39,7 +39,7 @@ flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN") flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0, flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials") "Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 5, flags.DEFINE_integer("nreplications", 5,
"Number of spikifications of the same underlying rates.") "Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN") flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin") flags.DEFINE_float("dt", 0.010, "Time bin")
...@@ -90,8 +90,8 @@ u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1) ...@@ -90,8 +90,8 @@ u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T T = FLAGS.T
C = FLAGS.C C = FLAGS.C
N = FLAGS.N # must be same N as in trained model (provided example is N = 50) N = FLAGS.N # must be same N as in trained model (provided example is N = 50)
nspikifications = FLAGS.nspikifications nreplications = FLAGS.nreplications
E = nspikifications * C # total number of trials E = nreplications * C # total number of trials
train_percentage = FLAGS.train_percentage train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt) ntimesteps = int(T / FLAGS.dt)
batch_size = 1 # gives one example per ntrial batch_size = 1 # gives one example per ntrial
...@@ -144,7 +144,7 @@ with tf.Session() as sess: ...@@ -144,7 +144,7 @@ with tf.Session() as sess:
outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn)) outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
r_sxt = np.dot(P_nxn, states_nxt) r_sxt = np.dot(P_nxn, states_nxt)
for s in xrange(nspikifications): for s in xrange(nreplications):
data_e.append(r_sxt) data_e.append(r_sxt)
u_e.append(u_1xt) u_e.append(u_1xt)
outs_e.append(outputs_t_bxn) outs_e.append(outputs_t_bxn)
...@@ -154,7 +154,7 @@ with tf.Session() as sess: ...@@ -154,7 +154,7 @@ with tf.Session() as sess:
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt, spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
max_firing_rate=FLAGS.max_firing_rate) max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage, train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications) nreplications)
data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e, data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
train_inds, train_inds,
...@@ -188,7 +188,7 @@ data = { 'train_truth': data_train_truth, ...@@ -188,7 +188,7 @@ data = { 'train_truth': data_train_truth,
'train_data' : data_train_spiking, 'train_data' : data_train_spiking,
'valid_data' : data_valid_spiking, 'valid_data' : data_valid_spiking,
'train_percentage' : train_percentage, 'train_percentage' : train_percentage,
'nspikifications' : nspikifications, 'nreplications' : nreplications,
'dt' : FLAGS.dt, 'dt' : FLAGS.dt,
'u_std' : FLAGS.u_std, 'u_std' : FLAGS.u_std,
'max_firing_rate': FLAGS.max_firing_rate, 'max_firing_rate': FLAGS.max_firing_rate,
......
...@@ -40,7 +40,7 @@ flags.DEFINE_integer("C", 400, "Number of conditions") ...@@ -40,7 +40,7 @@ flags.DEFINE_integer("C", 400, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN") flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0, flags.DEFINE_float("train_percentage", 4.0/5.0,
"Percentage of train vs validation trials") "Percentage of train vs validation trials")
flags.DEFINE_integer("nspikifications", 10, flags.DEFINE_integer("nreplications", 10,
"Number of spikifications of the same underlying rates.") "Number of spikifications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics") flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0, flags.DEFINE_float("x0_std", 1.0,
...@@ -56,8 +56,8 @@ rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1), ...@@ -56,8 +56,8 @@ rnn_rngs = [np.random.RandomState(seed=FLAGS.synth_data_seed+1),
T = FLAGS.T T = FLAGS.T
C = FLAGS.C C = FLAGS.C
N = FLAGS.N N = FLAGS.N
nspikifications = FLAGS.nspikifications nreplications = FLAGS.nreplications
E = nspikifications * C E = nreplications * C
train_percentage = FLAGS.train_percentage train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt) ntimesteps = int(T / FLAGS.dt)
...@@ -77,8 +77,8 @@ condition_labels = [] ...@@ -77,8 +77,8 @@ condition_labels = []
condition_number = 0 condition_number = 0
for c in range(C): for c in range(C):
x0 = FLAGS.x0_std * rng.randn(N, 1) x0 = FLAGS.x0_std * rng.randn(N, 1)
x0s.append(np.tile(x0, nspikifications)) x0s.append(np.tile(x0, nreplications))
for ns in range(nspikifications): for ns in range(nreplications):
condition_labels.append(condition_number) condition_labels.append(condition_number)
condition_number += 1 condition_number += 1
x0s = np.concatenate(x0s, axis=1) x0s = np.concatenate(x0s, axis=1)
...@@ -107,7 +107,7 @@ for trial in xrange(E): ...@@ -107,7 +107,7 @@ for trial in xrange(E):
# split into train and validation sets # split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage, train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
nspikifications) nreplications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds) rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds) spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
...@@ -129,7 +129,7 @@ data = {'train_truth': rates_train, ...@@ -129,7 +129,7 @@ data = {'train_truth': rates_train,
'train_ext_input' : np.array(ext_input_train), 'train_ext_input' : np.array(ext_input_train),
'valid_ext_input': np.array(ext_input_valid), 'valid_ext_input': np.array(ext_input_valid),
'train_percentage' : train_percentage, 'train_percentage' : train_percentage,
'nspikifications' : nspikifications, 'nreplications' : nreplications,
'dt' : FLAGS.dt, 'dt' : FLAGS.dt,
'P_sxn' : P_nxn, 'P_sxn' : P_nxn,
'condition_labels_train' : condition_labels_train, 'condition_labels_train' : condition_labels_train,
......
...@@ -19,22 +19,22 @@ ...@@ -19,22 +19,22 @@
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/ SYNTH_PATH=/tmp/rnn_synth_data_v1.0/
echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise" echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson' python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise" echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs_gaussian --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian' python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=gaussian_chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'
echo "Generating chaotic rnn data with input pulses (g=1.5)" echo "Generating chaotic rnn data with input pulses (g=1.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson' python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating chaotic rnn data with input pulses (g=2.5)" echo "Generating chaotic rnn data with input pulses (g=2.5)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson' python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generate the multi-session RNN data (no multi-session synth example in paper)" echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson' python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nreplications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
echo "Generating Integration-to-bound RNN data" echo "Generating Integration-to-bound RNN data"
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0 python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nreplications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)" echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0 python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
...@@ -176,13 +176,13 @@ def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100): ...@@ -176,13 +176,13 @@ def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications): def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
"""Split the numbers between 0 and num_trials-1 into two portions for """Split the numbers between 0 and num_trials-1 into two portions for
training and validation, based on the train fraction. training and validation, based on the train fraction.
Args: Args:
num_trials: the number of trials num_trials: the number of trials
train_fraction: (e.g. .80) train_fraction: (e.g. .80)
nspikifications: the number of spiking trials per initial condition nreplications: the number of spiking trials per initial condition
Returns: Returns:
a 2-tuple of two lists: the training indices and validation indices a 2-tuple of two lists: the training indices and validation indices
""" """
...@@ -192,7 +192,7 @@ def get_train_n_valid_inds(num_trials, train_fraction, nspikifications): ...@@ -192,7 +192,7 @@ def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
# This line divides up the trials so that within one initial condition, # This line divides up the trials so that within one initial condition,
# the randomness of spikifying the condition is shared among both # the randomness of spikifying the condition is shared among both
# training and validation data splits. # training and validation data splits.
if (i % nspikifications)+1 > train_fraction * nspikifications: if (i % nreplications)+1 > train_fraction * nreplications:
valid_inds.append(i) valid_inds.append(i)
else: else:
train_inds.append(i) train_inds.append(i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment