Commit 3ea57553 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398757677
parent da3f16f9
...@@ -199,7 +199,7 @@ class AnchorLabeler(object): ...@@ -199,7 +199,7 @@ class AnchorLabeler(object):
for k, v in gt_attributes.items(): for k, v in gt_attributes.items():
att_size = v.get_shape().as_list()[-1] att_size = v.get_shape().as_list()[-1]
att_mask = tf.tile(cls_mask, [1, att_size]) att_mask = tf.tile(cls_mask, [1, att_size])
att_targets[k] = self.target_gather(v, match_indices, att_mask, -1) att_targets[k] = self.target_gather(v, match_indices, att_mask, 0.0)
weights = tf.squeeze(tf.ones_like(gt_labels, dtype=tf.float32), -1) weights = tf.squeeze(tf.ones_like(gt_labels, dtype=tf.float32), -1)
box_weights = self.target_gather(weights, match_indices, mask) box_weights = self.target_gather(weights, match_indices, mask)
......
...@@ -155,7 +155,7 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase): ...@@ -155,7 +155,7 @@ class AnchorTest(parameterized.TestCase, tf.test.TestCase):
att_targets[attribute_name][k] = v.numpy() att_targets[attribute_name][k] = v.numpy()
anchor_locations = np.vstack( anchor_locations = np.vstack(
np.where( np.where(
att_targets[attribute_name][str(min_level)] > -1)).transpose() att_targets[attribute_name][str(min_level)] > 0.0)).transpose()
self.assertAllClose(expected_anchor_locations, anchor_locations) self.assertAllClose(expected_anchor_locations, anchor_locations)
else: else:
self.assertEmpty(att_targets) self.assertEmpty(att_targets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment