Unverified Commit bb6f092a authored by Yuhao Zhang's avatar Yuhao Zhang Committed by GitHub
Browse files

Guard the unsafe tf.log to prevent NAN (#8223)

parent 396fa8e8
...@@ -33,6 +33,9 @@ VAL_INTERVAL = 200 ...@@ -33,6 +33,9 @@ VAL_INTERVAL = 200
# How often to save a model checkpoint # How often to save a model checkpoint
SAVE_INTERVAL = 2000 SAVE_INTERVAL = 2000
# EPSILON to avoid NAN
EPSILON = 1e-9
# tf record data location: # tf record data location:
DATA_DIR = 'push/push_train' DATA_DIR = 'push/push_train'
...@@ -81,7 +84,7 @@ def peak_signal_to_noise_ratio(true, pred): ...@@ -81,7 +84,7 @@ def peak_signal_to_noise_ratio(true, pred):
Returns: Returns:
peak signal to noise ratio (PSNR) peak signal to noise ratio (PSNR)
""" """
return 10.0 * tf.log(1.0 / mean_squared_error(true, pred)) / tf.log(10.0) return 10.0 * (- tf.log(tf.maximum(mean_squared_error(true, pred), EPSILON))) / tf.log(10.0)
def mean_squared_error(true, pred): def mean_squared_error(true, pred):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment