"profiler/include/profile_conv_bwd_data_impl.hpp" did not exist on "823657ed120144943b7db87c07fe3e647128db56"
Commit fb2b278d authored by Vishnu Banna's avatar Vishnu Banna
Browse files

nms ops used by detection generator

parent 0352c8f4
......@@ -27,10 +27,8 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
expected_shape = np.array([num_boxes, 4])
xywh_box = box_ops.yxyx_to_xcycwh(boxes)
yxyx_box = box_ops.xcycwh_to_yxyx(boxes)
xyxy_box = box_ops.xcycwh_to_xyxy(boxes)
self.assertAllEqual(tf.shape(xywh_box).numpy(), expected_shape)
self.assertAllEqual(tf.shape(yxyx_box).numpy(), expected_shape)
self.assertAllEqual(tf.shape(xyxy_box).numpy(), expected_shape)
@parameterized.parameters((1), (5), (7))
def test_ious(self, num_boxes):
......
"""A set of private math operations used to safely implement the yolo loss"""
import tensorflow as tf
import tensorflow.keras.backend as K
def rm_nan_inf(x, val=0.0):
"""remove nan and infinity
Args:
x: any `Tensor` of any type.
val: value to replace nan and infinity with.
Return:
a `Tensor` with nan and infinity removed.
"""
cond = tf.math.logical_or(tf.math.is_nan(x), tf.math.is_inf(x))
val = tf.cast(val, dtype=x.dtype)
x = tf.where(cond, val, x)
return x
def rm_nan(x, val=0.0):
"""remove nan and infinity.
Args:
x: any `Tensor` of any type.
val: value to replace nan.
Return:
a `Tensor` with nan removed.
"""
cond = tf.math.is_nan(x)
val = tf.cast(val, dtype=x.dtype)
x = tf.where(cond, val, x)
return x
def divide_no_nan(a, b):
"""Nan safe divide operation built to allow model compilation in tflite.
Args:
a: any `Tensor` of any type.
b: any `Tensor` of any type with the same shape as tensor a.
Return:
a `Tensor` representing a divided by b, with all nan values removed.
"""
zero = tf.cast(0.0, b.dtype)
return tf.where(b == zero, zero, a / b)
def mul_no_nan(x, y):
"""Nan safe multiply operation built to allow model compilation in tflite and
to allowing one tensor to mask another. Where ever x is zero the
multiplication is not computed and the value is replaced with a zero. This is
requred because 0 * nan = nan. This can make computation unstable in some
cases where the intended behavior is for zero to mean ignore.
Args:
x: any `Tensor` of any type.
y: any `Tensor` of any type with the same shape as tensor x.
Return:
a `Tensor` representing x times y, where x is used to safely mask the
tensor y.
"""
return tf.where(x == 0, tf.cast(0, x.dtype), x * y)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment