Commit cee4b75e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 366120838
parent cdf815f8
...@@ -97,26 +97,31 @@ class RetinaNetModel(tf.keras.Model): ...@@ -97,26 +97,31 @@ class RetinaNetModel(tf.keras.Model):
raw_scores, raw_boxes, raw_attributes = self.head(features) raw_scores, raw_boxes, raw_attributes = self.head(features)
if training: if training:
return { outputs = {
'cls_outputs': raw_scores, 'cls_outputs': raw_scores,
'box_outputs': raw_boxes, 'box_outputs': raw_boxes,
'att_outputs': raw_attributes,
} }
if raw_attributes:
outputs.update({'att_outputs': raw_attributes})
return outputs
else: else:
# Post-processing. # Post-processing.
final_results = self.detection_generator(raw_boxes, raw_scores, final_results = self.detection_generator(
anchor_boxes, image_shape, raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
raw_attributes) outputs = {
return {
'detection_boxes': final_results['detection_boxes'], 'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'], 'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'], 'detection_classes': final_results['detection_classes'],
'detection_attributes': final_results['detection_attributes'],
'num_detections': final_results['num_detections'], 'num_detections': final_results['num_detections'],
'cls_outputs': raw_scores, 'cls_outputs': raw_scores,
'box_outputs': raw_boxes, 'box_outputs': raw_boxes,
'att_outputs': raw_attributes,
} }
if raw_attributes:
outputs.update({
'att_outputs': raw_attributes,
'detection_attributes': final_results['detection_attributes'],
})
return outputs
@property @property
def checkpoint_items(self): def checkpoint_items(self):
......
...@@ -160,7 +160,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -160,7 +160,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
if training: if training:
cls_outputs = model_outputs['cls_outputs'] cls_outputs = model_outputs['cls_outputs']
box_outputs = model_outputs['box_outputs'] box_outputs = model_outputs['box_outputs']
att_outputs = model_outputs['att_outputs']
for level in range(min_level, max_level + 1): for level in range(min_level, max_level + 1):
self.assertIn(str(level), cls_outputs) self.assertIn(str(level), cls_outputs)
self.assertIn(str(level), box_outputs) self.assertIn(str(level), box_outputs)
...@@ -177,6 +176,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -177,6 +176,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
4 * num_anchors_per_location 4 * num_anchors_per_location
], box_outputs[str(level)].numpy().shape) ], box_outputs[str(level)].numpy().shape)
if has_att_heads: if has_att_heads:
att_outputs = model_outputs['att_outputs']
for att in att_outputs.values(): for att in att_outputs.values():
self.assertAllEqual([ self.assertAllEqual([
2, image_size[0] // 2**level, image_size[1] // 2**level, 2, image_size[0] // 2**level, image_size[1] // 2**level,
...@@ -186,7 +186,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -186,7 +186,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertIn('detection_boxes', model_outputs) self.assertIn('detection_boxes', model_outputs)
self.assertIn('detection_scores', model_outputs) self.assertIn('detection_scores', model_outputs)
self.assertIn('detection_classes', model_outputs) self.assertIn('detection_classes', model_outputs)
self.assertIn('detection_attributes', model_outputs)
self.assertIn('num_detections', model_outputs) self.assertIn('num_detections', model_outputs)
self.assertAllEqual( self.assertAllEqual(
[2, 10, 4], model_outputs['detection_boxes'].numpy().shape) [2, 10, 4], model_outputs['detection_boxes'].numpy().shape)
...@@ -197,6 +196,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -197,6 +196,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual( self.assertAllEqual(
[2,], model_outputs['num_detections'].numpy().shape) [2,], model_outputs['num_detections'].numpy().shape)
if has_att_heads: if has_att_heads:
self.assertIn('detection_attributes', model_outputs)
self.assertAllEqual( self.assertAllEqual(
[2, 10, 1], [2, 10, 1],
model_outputs['detection_attributes']['depth'].numpy().shape) model_outputs['detection_attributes']['depth'].numpy().shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment