Commit 2a5f2a95 authored by David Sussillo's avatar David Sussillo Committed by GitHub
Browse files

Merge pull request #1984 from allenlavoie/master

Shape fix for LFADS tf.case statements
parents 504bc193 fd6da687
...@@ -7,9 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v ...@@ -7,9 +7,7 @@ This code implements the model from the paper "[LFADS - Latent Factor Analysis v
The code is written in Python 2.7.6. You will also need: The code is written in Python 2.7.6. You will also need:
* **TensorFlow** version 1.1 ([install](http://tflearn.org/installation/)) - * **TensorFlow** version 1.2.1 ([install](https://www.tensorflow.org/install/)) -
there is an incompatibility with LFADS and TF v1.2, which we are in the
process of resolving
* **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them) * **NumPy, SciPy, Matplotlib** ([install SciPy stack](https://www.scipy.org/install.html), contains all of them)
* **h5py** ([install](https://pypi.python.org/pypi/h5py)) * **h5py** ([install](https://pypi.python.org/pypi/h5py))
......
...@@ -432,11 +432,15 @@ class LFADS(object): ...@@ -432,11 +432,15 @@ class LFADS(object):
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws) pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs) pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
case_default = lambda: tf.constant([-8675309.0]) def _case_with_no_default(pairs):
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, case_default, exclusive=True) def _default_value_fn():
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, case_default, exclusive=True) with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, case_default, exclusive=True) return tf.identity(pairs[0][1]())
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, case_default, exclusive=True) return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W = _case_with_no_default(pf_pairs_in_fac_Ws)
this_in_fac_b = _case_with_no_default(pf_pairs_in_fac_bs)
this_out_fac_W = _case_with_no_default(pf_pairs_out_fac_Ws)
this_out_fac_b = _case_with_no_default(pf_pairs_out_fac_bs)
# External inputs (not changing by dataset, by definition). # External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0: if hps.ext_input_dim > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment