"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "7a77abd9b6267dc0020a60b424b4748fc22790bb"
Commit 0826ef88 authored by David Sussillo


Made PCA bias work for input alignment. Fixed output matrix to pinv, added gaussian example, fixed output gaussian param inits, updated README
parent 09f32cea
@@ -36,10 +36,10 @@ These synthetic datasets are provided 1. to gain insight into how the LFADS algo

## Train an LFADS model

-Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 8) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters.
+Now that we have our example datasets, we can train some models! To spin up an LFADS model on the synthetic data, run any of the following commands. For the examples that are in the paper, the important hyperparameters are roughly replicated. Most hyperparameters are insensitive to small changes or won't ever be changed unless you want a very fine level of control. In the first example, all hyperparameter flags are enumerated for easy copy-pasting, but for the rest of the examples only the most important flags (~the first 9) are specified for brevity. For a full list of flags, their descriptions, and their default values, refer to the top of `run_lfads.py`. Please see Table 1 in the Online Methods of the associated paper for definitions of the most important hyperparameters.

```sh
-# Run LFADS on chaotic rnn data with no input pulses (g = 1.5)
+# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with spiking noise
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_no_inputs \
@@ -106,14 +106,16 @@ $ python run_lfads.py --kind=train \
--data_filename_stem=chaotic_rnn_inputs_g2p5 \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g2p5 \
--co_dim=1 \
---factors_dim=20
+--factors_dim=20 \
+--output_dist=poisson

# Run LFADS on multi-session RNN data
$ python run_lfads.py --kind=train \
--data_dir=/tmp/rnn_synth_data_v1.0/ \
--data_filename_stem=chaotic_rnn_multisession \
--lfads_save_dir=/tmp/lfads_chaotic_rnn_multisession \
---factors_dim=10
+--factors_dim=10 \
+--output_dist=poisson

# Run LFADS on integration to bound model data
$ python run_lfads.py --kind=train \
@@ -122,7 +124,8 @@ $ python run_lfads.py --kind=train \
--lfads_save_dir=/tmp/lfads_itb_rnn \
--co_dim=1 \
--factors_dim=20 \
---controller_input_lag=0
+--controller_input_lag=0 \
+--output_dist=poisson

# Run LFADS on chaotic RNN data with labels
$ python run_lfads.py --kind=train \
@@ -132,7 +135,20 @@ $ python run_lfads.py --kind=train \
--co_dim=0 \
--factors_dim=20 \
--controller_input_lag=0 \
---ext_input_dim=1
+--ext_input_dim=1 \
+--output_dist=poisson
+
+# Run LFADS on chaotic rnn data with no input pulses (g = 1.5) with Gaussian noise
+$ python run_lfads.py --kind=train \
+--data_dir=/tmp/rnn_synth_data_v1.0/ \
+--data_filename_stem=chaotic_rnn_no_inputs \
+--lfads_save_dir=/tmp/lfads_chaotic_rnn_no_inputs \
+--co_dim=0 \
+--factors_dim=20 \
+--ext_input_dim=0 \
+--controller_input_lag=1 \
+--output_dist=gaussian
```
...
@@ -43,6 +43,10 @@ The nested dictionary is the DATA DICTIONARY, which has the following keys:
output adapter for each dataset. These matrices, if provided, must be of
size [data_dim x factors] where data_dim is the number of neurons recorded
on that day, and factors is chosen and set through the '--factors' flag.
+'alignment_bias_c' - See alignment_matrix_cxf. This bias will be used as
+the offset for the alignment transformation. It will *subtract* the bias
+from the data, so PCA-style inits can align factors across sessions.

If one runs LFADS on data where the true rates are known for some trials,
(say simulated, testing data, as in the example shipped with the paper), then
@@ -356,18 +360,36 @@ class LFADS(object):
    for d, name in enumerate(dataset_names):
      data_dim = hps.dataset_dims[name]
      in_mat_cxf = None
+     in_bias_1xf = None
+     align_bias_1xc = None
      if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
        dataset = datasets[name]
        print("Using alignment matrix provided for dataset:", name)
        in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
        if in_mat_cxf.shape != (data_dim, factors_dim):
          raise ValueError("""Alignment matrix must have dimensions %d x %d
          (data_dim x factors_dim), but currently has %d x %d."""%
          (data_dim, factors_dim, in_mat_cxf.shape[0],
           in_mat_cxf.shape[1]))
+     if datasets and 'alignment_bias_c' in datasets[name].keys():
+       dataset = datasets[name]
+       print("Using alignment bias provided for dataset:", name)
+       align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
+       align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
+       if align_bias_1xc.shape[1] != data_dim:
+         raise ValueError("""Alignment bias must have dimensions %d
+         (data_dim), but currently has %d."""%
+         (data_dim, align_bias_1xc.shape[1]))
+
+     if in_mat_cxf is not None and align_bias_1xc is not None:
+       # (data - alignment_bias) * W_in
+       #   = data * W_in - alignment_bias * W_in
+       # So b = -alignment_bias * W_in to accommodate PCA style offset.
+       in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
      in_fac_lin = init_linear(data_dim, used_in_factors_dim, do_bias=True,
                               mat_init_value=in_mat_cxf,
+                              bias_init_value=in_bias_1xf,
                               identity_if_possible=in_identity_if_poss,
                               normalized=False, name="x_2_infac_"+name,
                               collections=['IO_transformations'])
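The comment in the added block above compresses a small piece of algebra: subtracting a per-channel bias before projecting is the same as projecting first and then adding a bias of `-alignment_bias * W_in`. A minimal numpy check of that identity (illustration only, not code from this commit; the toy names mirror the variables used above):

```python
import numpy as np

# Verify (x - bias) @ W == x @ W + b, with b = -bias @ W.
rng = np.random.RandomState(0)
data_dim, factors_dim = 5, 3

x = rng.randn(10, data_dim)                      # toy activity, [n_samples x data_dim]
alignment_bias_c = x.mean(axis=0)                # per-channel mean, like 'alignment_bias_c'
in_mat_cxf = rng.randn(data_dim, factors_dim)    # stand-in for 'alignment_matrix_cxf'

centered = (x - alignment_bias_c).dot(in_mat_cxf)                    # project the centered data
in_bias_1xf = -np.dot(alignment_bias_c[np.newaxis, :], in_mat_cxf)   # fold the offset into the bias
affine = x.dot(in_mat_cxf) + in_bias_1xf                             # what the affine read-in computes

assert np.allclose(centered, affine)
```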
@@ -387,13 +409,17 @@ class LFADS(object):
        dataset = datasets[name]
        in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
-     out_mat_cxf = None
+     out_mat_fxc = None
+     out_bias_1xc = None
      if in_mat_cxf is not None:
-       out_mat_cxf = in_mat_cxf.T
+       out_mat_fxc = np.linalg.pinv(in_mat_cxf)
+     if align_bias_1xc is not None:
+       out_bias_1xc = align_bias_1xc

      if hps.output_dist == 'poisson':
        out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
-                                 mat_init_value=out_mat_cxf,
+                                 mat_init_value=out_mat_fxc,
+                                 bias_init_value=out_bias_1xc,
                                  identity_if_possible=out_identity_if_poss,
                                  normalized=False,
                                  name="fac_2_logrates_"+name,
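Switching the readout initialization from the transpose to the pseudo-inverse matters whenever the alignment matrix is not orthonormal (for instance when it comes from a least-squares fit rather than pure PCA): `pinv` yields a factors-to-data map that actually undoes the read-in on data lying in the aligned subspace, while the transpose does not. A small numpy check (illustration only, not from the commit):

```python
import numpy as np

rng = np.random.RandomState(1)
data_dim, factors_dim = 6, 3

in_mat_cxf = rng.randn(data_dim, factors_dim)       # general alignment matrix, columns not orthonormal
out_mat_fxc = np.linalg.pinv(in_mat_cxf)            # the new readout init

f_true = rng.randn(20, factors_dim)
x = f_true.dot(out_mat_fxc)                         # toy data consistent with the factors

factors = x.dot(in_mat_cxf)                         # read-in recovers the factors
print(np.allclose(factors, f_true))                 # True
print(np.allclose(factors.dot(out_mat_fxc), x))     # True: pinv init reconstructs the data
print(np.allclose(factors.dot(in_mat_cxf.T), x))    # False in general: the old transpose init does not
```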
@@ -403,13 +429,19 @@ class LFADS(object):
      elif hps.output_dist == 'gaussian':
        out_fac_lin_mean = \
            init_linear(factors_dim, data_dim, do_bias=True,
-                       mat_init_value=out_mat_cxf,
+                       mat_init_value=out_mat_fxc,
+                       bias_init_value=out_bias_1xc,
                        normalized=False,
                        name="fac_2_means_"+name,
                        collections=['IO_transformations'])
+       out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
+
+       mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
+       bias_init_value = np.ones([1, data_dim]).astype(np.float32)
        out_fac_lin_logvar = \
            init_linear(factors_dim, data_dim, do_bias=True,
-                       mat_init_value=out_mat_cxf,
+                       mat_init_value=mat_init_value,
+                       bias_init_value=bias_init_value,
                        normalized=False,
                        name="fac_2_logvars_"+name,
                        collections=['IO_transformations'])
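The asymmetry in the Gaussian branch is worth spelling out: the mean readout inherits the alignment-based init, while the log-variance readout starts with zero weights and a bias of ones, so every output channel begins with the same log-variance regardless of the factors. A tiny numpy illustration of that starting point (a sketch of the initialization above, not code from the commit):

```python
import numpy as np

factors_dim, data_dim = 4, 7
W_logvar = np.zeros([factors_dim, data_dim], dtype=np.float32)   # zero weights at init
b_logvar = np.ones([1, data_dim], dtype=np.float32)              # bias of ones at init

factors = np.random.randn(10, factors_dim).astype(np.float32)
logvars = factors.dot(W_logvar) + b_logvar   # constant: every entry is 1.0 before training
assert np.allclose(logvars, 1.0)
```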
...
@@ -24,7 +24,7 @@ from utils import write_datasets
from synthetic_data_utils import add_alignment_projections, generate_data
from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
from synthetic_data_utils import nparray_and_transpose
-from synthetic_data_utils import spikify_data, split_list_by_inds
+from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
@@ -37,6 +37,7 @@ flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data",
                    "Name of data file for input case.")
+flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions")
@@ -45,8 +46,8 @@ flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Percentage of train vs validation trials")
-flags.DEFINE_integer("nspikifications", 40,
-                     "Number of spikifications of the same underlying rates.")
+flags.DEFINE_integer("nreplications", 40,
+                     "Number of noise replications of the same underlying rates.")
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
flags.DEFINE_float("x0_std", 1.0,
                   "Volume from which to pull initial conditions (affects diversity of dynamics.")
@@ -73,8 +74,8 @@ C = FLAGS.C
N = FLAGS.N
S = FLAGS.S
input_magnitude = FLAGS.input_magnitude
-nspikifications = FLAGS.nspikifications
-E = nspikifications * C  # total number of trials
+nreplications = FLAGS.nreplications
+E = nreplications * C  # total number of trials
# S is the number of measurements in each datasets, w/ each
# dataset having a different set of observations.
ndatasets = N/S  # ok if rounded down
@@ -87,9 +88,9 @@ rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
# Check to make sure the RNN is the one we used in the paper.
if N == 50:
  assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'

-rem_check = nspikifications * train_percentage
+rem_check = nreplications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \
-    'Train percentage * nspikifications should be integral number.'
+    'Train percentage * nreplications should be integral number.'

# Initial condition generation, and condition label generation. This
@@ -100,9 +101,9 @@ x0s = []
condition_labels = []
for c in range(C):
  x0 = FLAGS.x0_std * rng.randn(N, 1)
-  x0s.append(np.tile(x0, nspikifications))  # replicate x0 nspikifications times
-  # replicate the condition label nspikifications times
-  for ns in range(nspikifications):
+  x0s.append(np.tile(x0, nreplications))  # replicate x0 nreplications times
+  # replicate the condition label nreplications times
+  for ns in range(nreplications):
    condition_labels.append(condition_number)
  condition_number += 1
x0s = np.concatenate(x0s, axis=1)
@@ -113,8 +114,8 @@ for n in range(ndatasets):
  print(n+1, " of ", ndatasets)
  # First generate all firing rates. in the next loop, generate all
-  # spikifications this allows the random state for rate generation to be
-  # independent of n_spikifications.
+  # replications this allows the random state for rate generation to be
+  # independent of n_replications.
  dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
  if S < N:
    dataset_name += '_n' + str(n+1)
@@ -136,17 +137,23 @@ for n in range(ndatasets):
    generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
                  input_magnitude=input_magnitude,
                  input_times=input_times)
-  spikes = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])

+  if FLAGS.noise_type == "poisson":
+    noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
+  elif FLAGS.noise_type == "gaussian":
+    noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
+  else:
+    raise ValueError("Only noise types supported are poisson or gaussian")

  # split into train and validation sets
  train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
-                                                  nspikifications)
+                                                  nreplications)

  # Split the data, inputs, labels and times into train vs. validation.
  rates_train, rates_valid = \
    split_list_by_inds(rates, train_inds, valid_inds)
-  spikes_train, spikes_valid = \
-    split_list_by_inds(spikes, train_inds, valid_inds)
+  noisy_data_train, noisy_data_valid = \
+    split_list_by_inds(noisy_data, train_inds, valid_inds)
  input_train, inputs_valid = \
    split_list_by_inds(inputs, train_inds, valid_inds)
  condition_labels_train, condition_labels_valid = \
@@ -154,25 +161,25 @@ for n in range(ndatasets):
  input_times_train, input_times_valid = \
    split_list_by_inds(input_times, train_inds, valid_inds)

-  # Turn rates, spikes, and input into numpy arrays.
+  # Turn rates, noisy_data, and input into numpy arrays.
  rates_train = nparray_and_transpose(rates_train)
  rates_valid = nparray_and_transpose(rates_valid)
-  spikes_train = nparray_and_transpose(spikes_train)
-  spikes_valid = nparray_and_transpose(spikes_valid)
+  noisy_data_train = nparray_and_transpose(noisy_data_train)
+  noisy_data_valid = nparray_and_transpose(noisy_data_valid)
  input_train = nparray_and_transpose(input_train)
  inputs_valid = nparray_and_transpose(inputs_valid)

  # Note that we put these 'truth' rates and input into this
-  # structure, the only data that is used in LFADS are the spike
-  # trains. The rest is either for printing or posterity.
+  # structure, the only data that is used in LFADS are the noisy
+  # data e.g. spike trains. The rest is either for printing or posterity.
  data = {'train_truth': rates_train,
          'valid_truth': rates_valid,
          'input_train_truth' : input_train,
          'input_valid_truth' : inputs_valid,
-         'train_data' : spikes_train,
-         'valid_data' : spikes_valid,
+         'train_data' : noisy_data_train,
+         'valid_data' : noisy_data_valid,
          'train_percentage' : train_percentage,
-         'nspikifications' : nspikifications,
+         'nreplications' : nreplications,
          'dt' : rnn['dt'],
          'input_magnitude' : input_magnitude,
          'input_times_train' : input_times_train,
...
@@ -18,20 +18,23 @@
SYNTH_PATH=/tmp/rnn_synth_data_v1.0/

-echo "Generating chaotic rnn data with no input pulses (g=1.5)"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0
+echo "Generating chaotic rnn data with no input pulses (g=1.5) with spiking noise"
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'
+
+echo "Generating chaotic rnn data with no input pulses (g=1.5) with Gaussian noise"
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs_gaussian --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='gaussian'

echo "Generating chaotic rnn data with input pulses (g=1.5)"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g1p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'

echo "Generating chaotic rnn data with input pulses (g=2.5)"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nspikifications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_inputs_g2p5 --synth_data_seed=5 --T=1.0 --C=400 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=2.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=20.0 --max_firing_rate=30.0 --noise_type='poisson'

echo "Generate the multi-session RNN data (no multi-session synth example in paper)"
-python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nspikifications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0
+python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_multisession --synth_data_seed=5 --T=1.0 --C=150 --N=100 --S=20 --npcs=10 --train_percentage=0.8 --nreplications=40 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --input_magnitude=0.0 --max_firing_rate=30.0 --noise_type='poisson'

echo "Generating Integration-to-bound RNN data"
python generate_itb_data.py --save_dir=$SYNTH_PATH --datafile_name=itb_rnn --u_std=0.25 --checkpoint_path=SAMPLE_CHECKPOINT --synth_data_seed=5 --T=1.0 --C=800 --N=50 --train_percentage=0.8 --nspikifications=5 --tau=0.025 --dt=0.01 --max_firing_rate=30.0

echo "Generating chaotic rnn data with external input labels (no external input labels example in paper)"
python generate_labeled_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnns_labeled --synth_data_seed=5 --T=1.0 --C=400 --N=50 --train_percentage=0.8 --nspikifications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --max_firing_rate=30.0
...

@@ -136,7 +136,6 @@ def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
    sampled from the underlying poisson process.
  """
-  spikifies_data_e = []
  E = len(data_e)
  spikes_e = []
  for e in range(E):
+def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
+  """Apply Gaussian noise to a continuous dataset whose values are between
+  0.0 and 1.0.
+
+  Args:
+    data_e: nexamples length list of NxT trials
+    rng: numpy random state used to draw the noise
+    dt: how often the data are sampled
+    max_firing_rate: the firing rate that is associated with a value of 1.0
+
+  Returns:
+    gaussified_data_e: a list of length E of the data with additive Gaussian
+      noise, scaled by the maximal firing rate.
+  """
+  E = len(data_e)
+  mfr = max_firing_rate
+  gauss_e = []
+  for e in range(E):
+    data = data_e[e]
+    N, T = data.shape
+    noisy_data = data * mfr + rng.randn(N, T) * (5.0 * mfr) * np.sqrt(dt)
+    gauss_e.append(noisy_data)
+  return gauss_e
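A quick worked number for the noise scale in `gaussify_data`, using the values passed by the example scripts above (dt=0.01, max_firing_rate=30; those specific values are an assumption about usage, not part of this function):

```python
import numpy as np

# Noise standard deviation = 5.0 * max_firing_rate * sqrt(dt)
dt, mfr = 0.01, 30.0
noise_std = 5.0 * mfr * np.sqrt(dt)
print(noise_std)  # 15.0, i.e. half the maximal rate of 30
```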
def get_train_n_valid_inds(num_trials, train_fraction, nspikifications):
  """Split the numbers between 0 and num_trials-1 into two portions for
  training and validation, based on the train fraction.
@@ -295,6 +320,8 @@ def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None):
    W_chxp, _, _, _ = \
      np.linalg.lstsq(all_data_zm_chxtc.T, all_data_pca_pxtc.T)
    dataset['alignment_matrix_cxf'] = W_chxp
+    alignment_bias_cx1 = all_data_mean_nx1[cidx_s:cidx_f]
+    dataset['alignment_bias_c'] = np.squeeze(alignment_bias_cx1, axis=1)

    do_debug_plot = False
    if do_debug_plot:
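To make the two new dictionary keys concrete: `add_alignment_projections` fits each session's read-in by least squares against PCA on data pooled across sessions, and now also stores that session's slice of the pooled mean as `'alignment_bias_c'`. The sketch below builds the same two entries for a single toy session using plain per-session PCA; it is an illustration under that simplifying assumption, not the repo's actual fitting code:

```python
import numpy as np

rng = np.random.RandomState(2)
data_dim, npcs, ntime = 8, 3, 1000

data_cxt = rng.rand(data_dim, ntime)               # one session's rates, channels x time
mean_cx1 = data_cxt.mean(axis=1, keepdims=True)    # per-channel mean -> alignment bias
data_zm = data_cxt - mean_cx1                      # zero-mean data for PCA
U, S, Vt = np.linalg.svd(data_zm, full_matrices=False)
W_cxp = U[:, :npcs]                                # channels -> top-npcs projection

dataset = {}
dataset['alignment_matrix_cxf'] = W_cxp.astype(np.float32)
dataset['alignment_bias_c'] = np.squeeze(mean_cx1, axis=1).astype(np.float32)

assert dataset['alignment_matrix_cxf'].shape == (data_dim, npcs)
assert dataset['alignment_bias_c'].shape == (data_dim,)
```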
...
@@ -82,9 +82,9 @@ def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
  return tf.matmul(x, W)

-def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
-                identity_if_possible=False, normalized=False,
-                name=None, collections=None):
+def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
+                bias_init_value=None, alpha=1.0, identity_if_possible=False,
+                normalized=False, name=None, collections=None):
  """Linear (affine) transformation, y = x W + b, for a variety of
  configurations.
@@ -110,6 +110,9 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
  if mat_init_value is not None and mat_init_value.shape != (in_size, out_size):
    raise ValueError(
      'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
+  if bias_init_value is not None and bias_init_value.shape != (1, out_size):
+    raise ValueError(
+      'Provided bias_init_value must have shape [1, %d].'%(out_size))

  if mat_init_value is None:
    stddev = alpha/np.sqrt(float(in_size))

@@ -143,16 +146,20 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, alpha=1.0,
  w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
                      collections=w_collections)

+  b = None
  if do_bias:
    b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
    if collections:
      b_collections += collections
    bname = (name + "/b") if name else "/b"
-    b = tf.get_variable(bname, [1, out_size],
-                        initializer=tf.zeros_initializer(),
-                        collections=b_collections)
-  else:
-    b = None
+    if bias_init_value is None:
+      b = tf.get_variable(bname, [1, out_size],
+                          initializer=tf.zeros_initializer(),
+                          collections=b_collections)
+    else:
+      b = tf.Variable(bias_init_value, name=bname,
+                      collections=b_collections)
  return (w, b)
...