"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "a697002cfbb4e7d2fb8a1a4646d958bc57a2d973"
Commit 44bdf29f authored by Allen Lavoie's avatar Allen Lavoie
Browse files

Shape fix for LFADS tf.case statements

parent a57a00f6
...@@ -432,11 +432,15 @@ class LFADS(object): ...@@ -432,11 +432,15 @@ class LFADS(object):
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws) pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs) pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
case_default = lambda: tf.constant([-8675309.0]) def _case_with_no_default(pairs):
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, case_default, exclusive=True) def _default_value_fn():
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, case_default, exclusive=True) with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, case_default, exclusive=True) return tf.identity(pairs[0][1]())
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, case_default, exclusive=True) return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W = _case_with_no_default(pf_pairs_in_fac_Ws)
this_in_fac_b = _case_with_no_default(pf_pairs_in_fac_bs)
this_out_fac_W = _case_with_no_default(pf_pairs_out_fac_Ws)
this_out_fac_b = _case_with_no_default(pf_pairs_out_fac_bs)
# External inputs (not changing by dataset, by definition). # External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0: if hps.ext_input_dim > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment