Commit d5ae12e1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Modify CenterNet to output extracted features.

PiperOrigin-RevId: 400090738
parent e06d2c3a
...@@ -3805,6 +3805,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3805,6 +3805,7 @@ class CenterNetMetaArch(model.DetectionModel):
prediction_dict: a dictionary holding predicted tensors with prediction_dict: a dictionary holding predicted tensors with
'preprocessed_inputs' - The input image after being resized and 'preprocessed_inputs' - The input image after being resized and
preprocessed by the feature extractor. preprocessed by the feature extractor.
'extracted_features' - The output of the feature extractor.
'object_center' - A list of size num_feature_outputs containing 'object_center' - A list of size num_feature_outputs containing
float tensors of size [batch_size, output_height, output_width, float tensors of size [batch_size, output_height, output_width,
num_classes] representing the predicted object center heatmap logits. num_classes] representing the predicted object center heatmap logits.
...@@ -3848,6 +3849,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3848,6 +3849,7 @@ class CenterNetMetaArch(model.DetectionModel):
predictions[head_name] = [ predictions[head_name] = [
head(feature) for (feature, head) in zip(features_list, heads) head(feature) for (feature, head) in zip(features_list, heads)
] ]
predictions['extracted_features'] = features_list
predictions['preprocessed_inputs'] = preprocessed_inputs predictions['preprocessed_inputs'] = preprocessed_inputs
self._batched_prediction_tensor_names = predictions.keys() self._batched_prediction_tensor_names = predictions.keys()
......
...@@ -3460,6 +3460,7 @@ class CenterNetMetaArch1dTest(test_case.TestCase, parameterized.TestCase): ...@@ -3460,6 +3460,7 @@ class CenterNetMetaArch1dTest(test_case.TestCase, parameterized.TestCase):
postprocess_output = arch.postprocess(predictions, true_shapes) postprocess_output = arch.postprocess(predictions, true_shapes)
losses_output = arch.loss(predictions, true_shapes) losses_output = arch.loss(predictions, true_shapes)
self.assertIn('extracted_features', predictions)
self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER), self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER),
losses_output) losses_output)
self.assertEqual((), losses_output['%s/%s' % ( self.assertEqual((), losses_output['%s/%s' % (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment