Commit f5830dbe authored by Vishnu Banna's avatar Vishnu Banna
Browse files

math ops update

parent 24ade5b8
...@@ -58,25 +58,4 @@ def divide_no_nan(a, b): ...@@ -58,25 +58,4 @@ def divide_no_nan(a, b):
Returns: Returns:
a `Tensor` representing a divided by b, with all nan values removed. a `Tensor` representing a divided by b, with all nan values removed.
""" """
zero = tf.cast(0.0, b.dtype) return a / (b + 1e-9)
return tf.where(b == zero, zero, a / b) \ No newline at end of file
def mul_no_nan(x, y):
"""Nan safe multiply operation.
Built to allow model compilation in tflite and
to allow one tensor to mask another. Where ever x is zero the
multiplication is not computed and the value is replaced with a zero. This is
required because 0 * nan = nan. This can make computation unstable in some
cases where the intended behavior is for zero to mean ignore.
Args:
x: any `Tensor` of any type.
y: any `Tensor` of any type with the same shape as tensor x.
Returns:
a `Tensor` representing x times y, where x is used to safely mask the
tensor y.
"""
return tf.where(x == 0, tf.cast(0, x.dtype), x * y)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment