"src/vscode:/vscode.git/clone" did not exist on "62c2c547dbc9eee39d4ddc310dbd477df20c754b"
Commit cee4b75e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 366120838
parent cdf815f8
......@@ -97,26 +97,31 @@ class RetinaNetModel(tf.keras.Model):
raw_scores, raw_boxes, raw_attributes = self.head(features)
if training:
return {
outputs = {
'cls_outputs': raw_scores,
'box_outputs': raw_boxes,
'att_outputs': raw_attributes,
}
if raw_attributes:
outputs.update({'att_outputs': raw_attributes})
return outputs
else:
# Post-processing.
final_results = self.detection_generator(raw_boxes, raw_scores,
anchor_boxes, image_shape,
raw_attributes)
return {
final_results = self.detection_generator(
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
outputs = {
'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'],
'detection_attributes': final_results['detection_attributes'],
'num_detections': final_results['num_detections'],
'cls_outputs': raw_scores,
'box_outputs': raw_boxes,
'att_outputs': raw_attributes,
}
if raw_attributes:
outputs.update({
'att_outputs': raw_attributes,
'detection_attributes': final_results['detection_attributes'],
})
return outputs
@property
def checkpoint_items(self):
......
......@@ -160,7 +160,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
if training:
cls_outputs = model_outputs['cls_outputs']
box_outputs = model_outputs['box_outputs']
att_outputs = model_outputs['att_outputs']
for level in range(min_level, max_level + 1):
self.assertIn(str(level), cls_outputs)
self.assertIn(str(level), box_outputs)
......@@ -177,6 +176,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
4 * num_anchors_per_location
], box_outputs[str(level)].numpy().shape)
if has_att_heads:
att_outputs = model_outputs['att_outputs']
for att in att_outputs.values():
self.assertAllEqual([
2, image_size[0] // 2**level, image_size[1] // 2**level,
......@@ -186,7 +186,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertIn('detection_boxes', model_outputs)
self.assertIn('detection_scores', model_outputs)
self.assertIn('detection_classes', model_outputs)
self.assertIn('detection_attributes', model_outputs)
self.assertIn('num_detections', model_outputs)
self.assertAllEqual(
[2, 10, 4], model_outputs['detection_boxes'].numpy().shape)
......@@ -197,6 +196,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(
[2,], model_outputs['num_detections'].numpy().shape)
if has_att_heads:
self.assertIn('detection_attributes', model_outputs)
self.assertAllEqual(
[2, 10, 1],
model_outputs['detection_attributes']['depth'].numpy().shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment