Unverified Commit bb6f092a authored by Yuhao Zhang's avatar Yuhao Zhang Committed by GitHub
Browse files

Guard the unsafe tf.log to prevent NAN (#8223)

parent 396fa8e8
......@@ -33,6 +33,9 @@ VAL_INTERVAL = 200
# How often to save a model checkpoint
SAVE_INTERVAL = 2000
# EPSILON to avoid NAN
EPSILON = 1e-9
# tf record data location:
DATA_DIR = 'push/push_train'
......@@ -81,7 +84,7 @@ def peak_signal_to_noise_ratio(true, pred):
Returns:
peak signal to noise ratio (PSNR)
"""
return 10.0 * tf.log(1.0 / mean_squared_error(true, pred)) / tf.log(10.0)
return 10.0 * (- tf.log(tf.maximum(mean_squared_error(true, pred), EPSILON))) / tf.log(10.0)
def mean_squared_error(true, pred):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment