Commit 4b0cec67 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 317560325
parent eb355ec0
......@@ -173,3 +173,18 @@ def assert_rank(tensor, expected_rank, name=None):
"For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
"equal to the expected tensor rank `%s`" %
(name, actual_rank, str(tensor.shape), str(expected_rank)))
def safe_mean(losses):
"""Computes a safe mean of the losses.
Args:
losses: `Tensor` whose elements contain individual loss measurements.
Returns:
A scalar representing the mean of `losses`. If `num_present` is zero,
then zero is returned.
"""
total = tf.reduce_sum(losses)
num_elements = tf.cast(tf.size(losses), dtype=losses.dtype)
return tf.math.divide_no_nan(total, num_elements)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment