Commit 8c5c60ca authored by Dan O'Shea's avatar Dan O'Shea
Browse files

Removing defaultless case statement

parent 32daecc9
......@@ -490,15 +490,15 @@ class LFADS(object):
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
def _case_with_no_default(pairs):
def _default_value_fn():
with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
return tf.identity(pairs[0][1]())
return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W = _case_with_no_default(pf_pairs_in_fac_Ws)
this_in_fac_b = _case_with_no_default(pf_pairs_in_fac_bs)
this_out_fac_W = _case_with_no_default(pf_pairs_out_fac_Ws)
this_out_fac_b = _case_with_no_default(pf_pairs_out_fac_bs)
# def _case_with_no_default(pairs):
# def _default_value_fn():
# with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
# return tf.identity(pairs[0][1]())
# return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True)
# External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment