Unverified Commit 56e7526f authored by Yuhao Zhang's avatar Yuhao Zhang Committed by GitHub
Browse files

Guard the unsafe tf.exp to prevent Inf cost (#8221)

* Guard the unsafe tf.exp to prevent Inf cost

* Update run_lfads.py

Add the hyperparameter `_clip_value` to resolve the issue in the prototype example
parent 3f124e27
......@@ -792,7 +792,7 @@ class LFADS(object):
if hps.output_dist == 'poisson':
log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
log_rates_t.set_shape([None, None])
rates[t] = dist_params[t] = tf.exp(log_rates_t) # rates feed back
rates[t] = dist_params[t] = tf.exp(tf.clip_by_value(log_rates_t, -hps._clip_value, hps._clip_value)) # rates feed back
rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
......@@ -803,7 +803,7 @@ class LFADS(object):
value=mean_n_logvars)
rates[t] = means_t_bxd # rates feed back to controller
dist_params[t] = tf.concat(
axis=1, values=[means_t_bxd, tf.exp(logvars_t_bxd)])
axis=1, values=[means_t_bxd, tf.exp(tf.clip_by_value(logvars_t_bxd, -hps._clip_value, hps._clip_value))])
loglikelihood_t = \
diag_gaussian_log_likelihood(data_t_bxd,
means_t_bxd, logvars_t_bxd)
......
......@@ -577,6 +577,7 @@ def build_hyperparameter_dict(flags):
d['kl_increase_steps'] = flags.kl_increase_steps
d['l2_start_step'] = flags.l2_start_step
d['l2_increase_steps'] = flags.l2_increase_steps
d['_clip_value'] = 80 # bounds the tf.exp to avoid INF
return d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment